diff --git a/supervision/tracker/byte_tracker/basetrack.py b/supervision/tracker/byte_tracker/basetrack.py index c8fadefa4..806f75384 100644 --- a/supervision/tracker/byte_tracker/basetrack.py +++ b/supervision/tracker/byte_tracker/basetrack.py @@ -39,6 +39,14 @@ def next_id() -> int: BaseTrack._count += 1 return BaseTrack._count + @staticmethod + def reset_counter(): + BaseTrack._count = 0 + BaseTrack.track_id = 0 + BaseTrack.start_frame = 0 + BaseTrack.frame_id = 0 + BaseTrack.time_since_update = 0 + def activate(self, *args): raise NotImplementedError diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py index ffe1b1b5f..801c58f7e 100644 --- a/supervision/tracker/byte_tracker/core.py +++ b/supervision/tracker/byte_tracker/core.py @@ -210,7 +210,6 @@ def update_with_detections(self, detections: Detections) -> Detections: ```python import supervision as sv from ultralytics import YOLO - import numpy as np model = YOLO() tracker = sv.ByteTrack() @@ -261,6 +260,21 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray: return detections + def reset(self): + """ + Resets the internal state of the ByteTrack tracker. + + This method clears the tracking data, including tracked, lost, + and removed tracks, as well as resetting the frame counter. It's + particularly useful when processing multiple videos sequentially, + ensuring the tracker starts with a clean state for each new video. + """ + self.frame_id = 0 + self.tracked_tracks: List[STrack] = [] + self.lost_tracks: List[STrack] = [] + self.removed_tracks: List[STrack] = [] + BaseTrack.reset_counter() + def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: """ Updates the tracker with the provided tensors and returns the updated tracks. @@ -306,6 +320,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: """ Add newly detected tracklets to tracked_stracks""" unconfirmed = [] tracked_stracks = [] # type: list[STrack] + for track in self.tracked_tracks: if not track.is_activated: unconfirmed.append(track)