From 6993ca3d718030fa666ff0e9a23329aafd4c2bf2 Mon Sep 17 00:00:00 2001 From: Adonai Vera <45982251+AdonaiVera@users.noreply.github.com> Date: Wed, 31 Jan 2024 15:07:27 -0500 Subject: [PATCH 1/7] Reset function to process multiple videos --- .../ultralytics_example_multiple_videos.py | 102 ++++++++++++++++++ supervision/tracker/byte_tracker/basetrack.py | 10 ++ supervision/tracker/byte_tracker/core.py | 13 ++- 3 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 examples/tracking/ultralytics_example_multiple_videos.py diff --git a/examples/tracking/ultralytics_example_multiple_videos.py b/examples/tracking/ultralytics_example_multiple_videos.py new file mode 100644 index 000000000..9d71e5b27 --- /dev/null +++ b/examples/tracking/ultralytics_example_multiple_videos.py @@ -0,0 +1,102 @@ +import argparse + +from tqdm import tqdm +from ultralytics import YOLO + +import supervision as sv + + +def process_video( + source_weights_path: str, + source_video_path: str, + target_video_path: str, + tracker: sv.ByteTrack, + confidence_threshold: float = 0.3, + iou_threshold: float = 0.7, +) -> None: + model = YOLO(source_weights_path) + + box_annotator = sv.BoundingBoxAnnotator() + label_annotator = sv.LabelAnnotator() + frame_generator = sv.get_video_frames_generator(source_path=source_video_path) + video_info = sv.VideoInfo.from_video_path(video_path=source_video_path) + + with sv.VideoSink(target_path=target_video_path, video_info=video_info) as sink: + for frame in tqdm(frame_generator, total=video_info.total_frames): + results = model( + frame, verbose=False, conf=confidence_threshold, iou=iou_threshold + )[0] + detections = sv.Detections.from_ultralytics(results) + detections = tracker.update_with_detections(detections) + + labels = [ + f"#{tracker_id} {model.model.names[class_id]} {confidence:0.2f}" + for _, _, confidence, class_id, tracker_id, data + in detections + ] + + annotated_frame = box_annotator.annotate( + scene=frame.copy(), detections=detections + ) + + annotated_labeled_frame = label_annotator.annotate( + scene=annotated_frame, detections=detections, labels=labels + ) + + sink.write_frame(frame=annotated_labeled_frame) + # Reset the tracker after processing the video + tracker.reset() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Video Processing with YOLO and ByteTrack" + ) + parser.add_argument( + "--source_weights_path", + required=True, + help="Path to the source weights file", + type=str, + ) + parser.add_argument( + "--source_video_paths", + required=True, + help="Paths to the source video files", + nargs="+", + type=str, + ) + parser.add_argument( + "--target_video_paths", + required=True, + help="Paths to the target video files (output)", + nargs="+", + type=str, + ) + parser.add_argument( + "--confidence_threshold", + default=0.3, + help="Confidence threshold for the model", + type=float, + ) + parser.add_argument( + "--iou_threshold", default=0.7, help="IOU threshold for the model", type=float + ) + + args = parser.parse_args() + + source_video_paths = args.source_video_paths[0].split(',') + target_video_paths = args.target_video_paths[0].split(',') + + source_video_paths = [path.strip() for path in source_video_paths] + target_video_paths = [path.strip() for path in target_video_paths] + tracker = sv.ByteTrack() + + for source_video_path, target_video_path in zip(source_video_paths, target_video_paths): + process_video( + source_weights_path=args.source_weights_path, + source_video_path=source_video_path, + target_video_path=target_video_path, + confidence_threshold=args.confidence_threshold, + iou_threshold=args.iou_threshold, + tracker=tracker, + ) \ No newline at end of file diff --git a/supervision/tracker/byte_tracker/basetrack.py b/supervision/tracker/byte_tracker/basetrack.py index c8fadefa4..a62c412fe 100644 --- a/supervision/tracker/byte_tracker/basetrack.py +++ b/supervision/tracker/byte_tracker/basetrack.py @@ -38,6 +38,15 @@ def end_frame(self) -> int: def next_id() -> int: BaseTrack._count += 1 return BaseTrack._count + + @staticmethod + def reset_counter(): + BaseTrack._count = 0 + BaseTrack.track_id = 0 + BaseTrack.start_frame = 0 + BaseTrack.frame_id = 0 + BaseTrack.time_since_update = 0 + def activate(self, *args): raise NotImplementedError @@ -53,3 +62,4 @@ def mark_lost(self): def mark_removed(self): self.state = TrackState.Removed + \ No newline at end of file diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py index ffe1b1b5f..cceb44988 100644 --- a/supervision/tracker/byte_tracker/core.py +++ b/supervision/tracker/byte_tracker/core.py @@ -210,7 +210,6 @@ def update_with_detections(self, detections: Detections) -> Detections: ```python import supervision as sv from ultralytics import YOLO - import numpy as np model = YOLO() tracker = sv.ByteTrack() @@ -221,7 +220,7 @@ def update_with_detections(self, detections: Detections) -> Detections: def callback(frame: np.ndarray, index: int) -> np.ndarray: results = model(frame)[0] detections = sv.Detections.from_ultralytics(results) - detections = tracker.update_with_detections(detections) + detections = byte_tracker.update_with_detections(detections) labels = [f"#{tracker_id}" for tracker_id in detections.tracker_id] @@ -261,6 +260,13 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray: return detections + def reset(self): + self.frame_id = 0 + self.tracked_tracks: List[STrack] = [] + self.lost_tracks: List[STrack] = [] + self.removed_tracks: List[STrack] = [] + BaseTrack.reset_counter() + def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: """ Updates the tracker with the provided tensors and returns the updated tracks. @@ -306,6 +312,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: """ Add newly detected tracklets to tracked_stracks""" unconfirmed = [] tracked_stracks = [] # type: list[STrack] + for track in self.tracked_tracks: if not track.is_activated: unconfirmed.append(track) @@ -481,4 +488,4 @@ def remove_duplicate_tracks(tracks_a: List, tracks_b: List) -> Tuple[List, List] track for index, track in enumerate(tracks_b) if index not in duplicates_b ] - return result_a, result_b + return result_a, result_b \ No newline at end of file From 9ba0511610ed070127a2258f3457fc663366d3f5 Mon Sep 17 00:00:00 2001 From: Adonai Vera <45982251+AdonaiVera@users.noreply.github.com> Date: Wed, 31 Jan 2024 15:30:03 -0500 Subject: [PATCH 2/7] Change import name --- supervision/tracker/byte_tracker/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py index cceb44988..8c1778e93 100644 --- a/supervision/tracker/byte_tracker/core.py +++ b/supervision/tracker/byte_tracker/core.py @@ -220,7 +220,7 @@ def update_with_detections(self, detections: Detections) -> Detections: def callback(frame: np.ndarray, index: int) -> np.ndarray: results = model(frame)[0] detections = sv.Detections.from_ultralytics(results) - detections = byte_tracker.update_with_detections(detections) + detections = tracker.update_with_detections(detections) labels = [f"#{tracker_id}" for tracker_id in detections.tracker_id] From bef5fb02d0333022301a1f5f86b97f1f61df89a1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 31 Jan 2024 20:32:49 +0000 Subject: [PATCH 3/7] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ultralytics_example_multiple_videos.py | 15 ++++++++------- supervision/tracker/byte_tracker/basetrack.py | 4 +--- supervision/tracker/byte_tracker/core.py | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/examples/tracking/ultralytics_example_multiple_videos.py b/examples/tracking/ultralytics_example_multiple_videos.py index 9d71e5b27..77f076e11 100644 --- a/examples/tracking/ultralytics_example_multiple_videos.py +++ b/examples/tracking/ultralytics_example_multiple_videos.py @@ -30,9 +30,8 @@ def process_video( detections = tracker.update_with_detections(detections) labels = [ - f"#{tracker_id} {model.model.names[class_id]} {confidence:0.2f}" - for _, _, confidence, class_id, tracker_id, data - in detections + f"#{tracker_id} {model.model.names[class_id]} {confidence:0.2f}" + for _, _, confidence, class_id, tracker_id, data in detections ] annotated_frame = box_annotator.annotate( @@ -84,14 +83,16 @@ def process_video( args = parser.parse_args() - source_video_paths = args.source_video_paths[0].split(',') - target_video_paths = args.target_video_paths[0].split(',') + source_video_paths = args.source_video_paths[0].split(",") + target_video_paths = args.target_video_paths[0].split(",") source_video_paths = [path.strip() for path in source_video_paths] target_video_paths = [path.strip() for path in target_video_paths] tracker = sv.ByteTrack() - for source_video_path, target_video_path in zip(source_video_paths, target_video_paths): + for source_video_path, target_video_path in zip( + source_video_paths, target_video_paths + ): process_video( source_weights_path=args.source_weights_path, source_video_path=source_video_path, @@ -99,4 +100,4 @@ def process_video( confidence_threshold=args.confidence_threshold, iou_threshold=args.iou_threshold, tracker=tracker, - ) \ No newline at end of file + ) diff --git a/supervision/tracker/byte_tracker/basetrack.py b/supervision/tracker/byte_tracker/basetrack.py index a62c412fe..806f75384 100644 --- a/supervision/tracker/byte_tracker/basetrack.py +++ b/supervision/tracker/byte_tracker/basetrack.py @@ -38,7 +38,7 @@ def end_frame(self) -> int: def next_id() -> int: BaseTrack._count += 1 return BaseTrack._count - + @staticmethod def reset_counter(): BaseTrack._count = 0 @@ -47,7 +47,6 @@ def reset_counter(): BaseTrack.frame_id = 0 BaseTrack.time_since_update = 0 - def activate(self, *args): raise NotImplementedError @@ -62,4 +61,3 @@ def mark_lost(self): def mark_removed(self): self.state = TrackState.Removed - \ No newline at end of file diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py index 8c1778e93..f8e88f0b6 100644 --- a/supervision/tracker/byte_tracker/core.py +++ b/supervision/tracker/byte_tracker/core.py @@ -488,4 +488,4 @@ def remove_duplicate_tracks(tracks_a: List, tracks_b: List) -> Tuple[List, List] track for index, track in enumerate(tracks_b) if index not in duplicates_b ] - return result_a, result_b \ No newline at end of file + return result_a, result_b From 9da2d9f761ae30285453f42fc408fcb39b40f05f Mon Sep 17 00:00:00 2001 From: Adonai Vera <45982251+AdonaiVera@users.noreply.github.com> Date: Wed, 31 Jan 2024 16:26:49 -0500 Subject: [PATCH 4/7] Remove ultralytics example multiple video --- .../ultralytics_example_multiple_videos.py | 103 ------------------ 1 file changed, 103 deletions(-) delete mode 100644 examples/tracking/ultralytics_example_multiple_videos.py diff --git a/examples/tracking/ultralytics_example_multiple_videos.py b/examples/tracking/ultralytics_example_multiple_videos.py deleted file mode 100644 index 77f076e11..000000000 --- a/examples/tracking/ultralytics_example_multiple_videos.py +++ /dev/null @@ -1,103 +0,0 @@ -import argparse - -from tqdm import tqdm -from ultralytics import YOLO - -import supervision as sv - - -def process_video( - source_weights_path: str, - source_video_path: str, - target_video_path: str, - tracker: sv.ByteTrack, - confidence_threshold: float = 0.3, - iou_threshold: float = 0.7, -) -> None: - model = YOLO(source_weights_path) - - box_annotator = sv.BoundingBoxAnnotator() - label_annotator = sv.LabelAnnotator() - frame_generator = sv.get_video_frames_generator(source_path=source_video_path) - video_info = sv.VideoInfo.from_video_path(video_path=source_video_path) - - with sv.VideoSink(target_path=target_video_path, video_info=video_info) as sink: - for frame in tqdm(frame_generator, total=video_info.total_frames): - results = model( - frame, verbose=False, conf=confidence_threshold, iou=iou_threshold - )[0] - detections = sv.Detections.from_ultralytics(results) - detections = tracker.update_with_detections(detections) - - labels = [ - f"#{tracker_id} {model.model.names[class_id]} {confidence:0.2f}" - for _, _, confidence, class_id, tracker_id, data in detections - ] - - annotated_frame = box_annotator.annotate( - scene=frame.copy(), detections=detections - ) - - annotated_labeled_frame = label_annotator.annotate( - scene=annotated_frame, detections=detections, labels=labels - ) - - sink.write_frame(frame=annotated_labeled_frame) - # Reset the tracker after processing the video - tracker.reset() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Video Processing with YOLO and ByteTrack" - ) - parser.add_argument( - "--source_weights_path", - required=True, - help="Path to the source weights file", - type=str, - ) - parser.add_argument( - "--source_video_paths", - required=True, - help="Paths to the source video files", - nargs="+", - type=str, - ) - parser.add_argument( - "--target_video_paths", - required=True, - help="Paths to the target video files (output)", - nargs="+", - type=str, - ) - parser.add_argument( - "--confidence_threshold", - default=0.3, - help="Confidence threshold for the model", - type=float, - ) - parser.add_argument( - "--iou_threshold", default=0.7, help="IOU threshold for the model", type=float - ) - - args = parser.parse_args() - - source_video_paths = args.source_video_paths[0].split(",") - target_video_paths = args.target_video_paths[0].split(",") - - source_video_paths = [path.strip() for path in source_video_paths] - target_video_paths = [path.strip() for path in target_video_paths] - tracker = sv.ByteTrack() - - for source_video_path, target_video_path in zip( - source_video_paths, target_video_paths - ): - process_video( - source_weights_path=args.source_weights_path, - source_video_path=source_video_path, - target_video_path=target_video_path, - confidence_threshold=args.confidence_threshold, - iou_threshold=args.iou_threshold, - tracker=tracker, - ) From 5f8127a1a553660145f4ebbea51d831cf89ef07d Mon Sep 17 00:00:00 2001 From: Adonai Vera <45982251+AdonaiVera@users.noreply.github.com> Date: Thu, 1 Feb 2024 00:30:48 -0500 Subject: [PATCH 5/7] Add docstring to the reset function --- supervision/tracker/byte_tracker/core.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py index f8e88f0b6..a64ab0a4b 100644 --- a/supervision/tracker/byte_tracker/core.py +++ b/supervision/tracker/byte_tracker/core.py @@ -261,6 +261,20 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray: return detections def reset(self): + """ + Resets the internal state of the ByteTrack tracker. + + This method is designed to clear the tracking data, including tracked, lost, and removed tracks, + as well as resetting the frame counter. It is particularly useful when processing multiple videos + sequentially, as it ensures the tracker starts with a clean state for each new video. + + Example: + tracker = ByteTrack() + tracker.reset() # Call this method before processing a new video + + No parameters are required for this method, and it does not return any value. It simply reinitializes + the internal state variables of the tracker to their default values. + """ self.frame_id = 0 self.tracked_tracks: List[STrack] = [] self.lost_tracks: List[STrack] = [] From 889044cf060e27ab50715ada44e8b94c45f2f941 Mon Sep 17 00:00:00 2001 From: Adonai Vera <45982251+AdonaiVera@users.noreply.github.com> Date: Thu, 1 Feb 2024 00:37:43 -0500 Subject: [PATCH 6/7] Add docs to reset function and reduce the length --- supervision/tracker/byte_tracker/core.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py index a64ab0a4b..1da9a70d4 100644 --- a/supervision/tracker/byte_tracker/core.py +++ b/supervision/tracker/byte_tracker/core.py @@ -264,16 +264,18 @@ def reset(self): """ Resets the internal state of the ByteTrack tracker. - This method is designed to clear the tracking data, including tracked, lost, and removed tracks, - as well as resetting the frame counter. It is particularly useful when processing multiple videos - sequentially, as it ensures the tracker starts with a clean state for each new video. + This method clears the tracking data, including tracked, lost, + and removed tracks, as well as resetting the frame counter. It's + particularly useful when processing multiple videos sequentially, + ensuring the tracker starts with a clean state for each new video. Example: tracker = ByteTrack() - tracker.reset() # Call this method before processing a new video + tracker.reset() # Call before processing a new video - No parameters are required for this method, and it does not return any value. It simply reinitializes - the internal state variables of the tracker to their default values. + This method requires no parameters and does not return any value. + It reinitializes the internal state variables of the tracker to + their default values. """ self.frame_id = 0 self.tracked_tracks: List[STrack] = [] From ee58e28d25f269c2994ac86f9ec8f9f063046f9b Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Thu, 1 Feb 2024 08:42:36 +0100 Subject: [PATCH 7/7] Small update in `reset` docstring. --- supervision/tracker/byte_tracker/core.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py index 1da9a70d4..801c58f7e 100644 --- a/supervision/tracker/byte_tracker/core.py +++ b/supervision/tracker/byte_tracker/core.py @@ -268,14 +268,6 @@ def reset(self): and removed tracks, as well as resetting the frame counter. It's particularly useful when processing multiple videos sequentially, ensuring the tracker starts with a clean state for each new video. - - Example: - tracker = ByteTrack() - tracker.reset() # Call before processing a new video - - This method requires no parameters and does not return any value. - It reinitializes the internal state variables of the tracker to - their default values. """ self.frame_id = 0 self.tracked_tracks: List[STrack] = []