Skip to content

Commit 4f35db0

Browse files
authored
[Fix] Fix Pose3dInferencer keypoint shape bug (#2543)
1 parent 9782368 commit 4f35db0

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

mmpose/apis/inferencers/pose3d_inferencer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def preprocess_single(self,
272272
),
273273
dtype=np.float32)
274274
data_info['lifting_target'] = np.zeros((1, K, 3), dtype=np.float32)
275+
data_info['factor'] = np.zeros((T, ), dtype=np.float32)
275276
data_info['lifting_target_visible'] = np.ones((1, K, 1),
276277
dtype=np.float32)
277278
data_info['camera_param'] = dict(w=width, h=height)
@@ -299,7 +300,6 @@ def forward(self,
299300
list: A list of data samples, each containing the model's output
300301
results.
301302
"""
302-
303303
pose_lift_results = self.model.test_step(inputs)
304304

305305
# Post-processing of pose estimation results
@@ -309,8 +309,16 @@ def forward(self,
309309
pose_lift_res.track_id = pose_est_results_converted[idx].get(
310310
'track_id', 1e4)
311311

312-
# Invert x and z values of the keypoints
312+
# align the shape of output keypoints coordinates and scores
313313
keypoints = pose_lift_res.pred_instances.keypoints
314+
keypoint_scores = pose_lift_res.pred_instances.keypoint_scores
315+
if keypoint_scores.ndim == 3:
316+
pose_lift_results[idx].pred_instances.keypoint_scores = \
317+
np.squeeze(keypoint_scores, axis=1)
318+
if keypoints.ndim == 4:
319+
keypoints = np.squeeze(keypoints, axis=1)
320+
321+
# Invert x and z values of the keypoints
314322
keypoints = keypoints[..., [0, 2, 1]]
315323
keypoints[..., 0] = -keypoints[..., 0]
316324
keypoints[..., 2] = -keypoints[..., 2]

0 commit comments

Comments
 (0)