@@ -181,16 +181,11 @@ def collate_pose_sequence(pose_results_2d,
181181 pose_sequences = []
182182 for idx in range (N ):
183183 pose_seq = PoseDataSample ()
184- gt_instances = InstanceData ()
185184 pred_instances = InstanceData ()
186185
187- for k in pose_results_2d [target_frame ][idx ].gt_instances .keys ():
188- gt_instances .set_field (
189- pose_results_2d [target_frame ][idx ].gt_instances [k ], k )
190- for k in pose_results_2d [target_frame ][idx ].pred_instances .keys ():
191- if k != 'keypoints' :
192- pred_instances .set_field (
193- pose_results_2d [target_frame ][idx ].pred_instances [k ], k )
186+ gt_instances = pose_results_2d [target_frame ][idx ].gt_instances .clone ()
187+ pred_instances = pose_results_2d [target_frame ][
188+ idx ].pred_instances .clone ()
194189 pose_seq .pred_instances = pred_instances
195190 pose_seq .gt_instances = gt_instances
196191
@@ -228,7 +223,7 @@ def collate_pose_sequence(pose_results_2d,
228223 # replicate the right most frame
229224 keypoints [:, frame_idx + 1 :] = keypoints [:, frame_idx ]
230225 break
231- pose_seq .pred_instances .keypoints = keypoints
226+ pose_seq .pred_instances .set_field ( keypoints , ' keypoints' )
232227 pose_sequences .append (pose_seq )
233228
234229 return pose_sequences
@@ -276,8 +271,15 @@ def inference_pose_lifter_model(model,
276271 bbox_center = None
277272 bbox_scale = None
278273
274+ pose_results_2d_copy = []
279275 for i , pose_res in enumerate (pose_results_2d ):
276+ pose_res_copy = []
280277 for j , data_sample in enumerate (pose_res ):
278+ data_sample_copy = PoseDataSample ()
279+ data_sample_copy .gt_instances = data_sample .gt_instances .clone ()
280+ data_sample_copy .pred_instances = data_sample .pred_instances .clone (
281+ )
282+ data_sample_copy .track_id = data_sample .track_id
281283 kpts = data_sample .pred_instances .keypoints
282284 bboxes = data_sample .pred_instances .bboxes
283285 keypoints = []
@@ -292,11 +294,13 @@ def inference_pose_lifter_model(model,
292294 bbox_scale + bbox_center )
293295 else :
294296 keypoints .append (kpt [:, :2 ])
295- pose_results_2d [i ][j ].pred_instances .keypoints = np .array (
296- keypoints )
297+ data_sample_copy .pred_instances .set_field (
298+ np .array (keypoints ), 'keypoints' )
299+ pose_res_copy .append (data_sample_copy )
300+ pose_results_2d_copy .append (pose_res_copy )
297301
298- pose_sequences_2d = collate_pose_sequence (pose_results_2d , with_track_id ,
299- target_idx )
302+ pose_sequences_2d = collate_pose_sequence (pose_results_2d_copy ,
303+ with_track_id , target_idx )
300304
301305 if not pose_sequences_2d :
302306 return []
0 commit comments