Skip to content

Commit 81d0c3a

Browse files
authored
Merge pull request #827 from AdonaiVera/add_method_reset_state_tracker
[ByteTrack] - add a method to reset the state of the tracker
2 parents 46426f8 + ee58e28 commit 81d0c3a

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

supervision/tracker/byte_tracker/basetrack.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ def next_id() -> int:
3939
BaseTrack._count += 1
4040
return BaseTrack._count
4141

42+
@staticmethod
43+
def reset_counter():
44+
BaseTrack._count = 0
45+
BaseTrack.track_id = 0
46+
BaseTrack.start_frame = 0
47+
BaseTrack.frame_id = 0
48+
BaseTrack.time_since_update = 0
49+
4250
def activate(self, *args):
4351
raise NotImplementedError
4452

supervision/tracker/byte_tracker/core.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ def update_with_detections(self, detections: Detections) -> Detections:
210210
```python
211211
import supervision as sv
212212
from ultralytics import YOLO
213-
import numpy as np
214213
215214
model = YOLO(<MODEL_PATH>)
216215
tracker = sv.ByteTrack()
@@ -261,6 +260,21 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray:
261260

262261
return detections
263262

263+
def reset(self):
264+
"""
265+
Resets the internal state of the ByteTrack tracker.
266+
267+
This method clears the tracking data, including tracked, lost,
268+
and removed tracks, as well as resetting the frame counter. It's
269+
particularly useful when processing multiple videos sequentially,
270+
ensuring the tracker starts with a clean state for each new video.
271+
"""
272+
self.frame_id = 0
273+
self.tracked_tracks: List[STrack] = []
274+
self.lost_tracks: List[STrack] = []
275+
self.removed_tracks: List[STrack] = []
276+
BaseTrack.reset_counter()
277+
264278
def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
265279
"""
266280
Updates the tracker with the provided tensors and returns the updated tracks.
@@ -306,6 +320,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
306320
""" Add newly detected tracklets to tracked_stracks"""
307321
unconfirmed = []
308322
tracked_stracks = [] # type: list[STrack]
323+
309324
for track in self.tracked_tracks:
310325
if not track.is_activated:
311326
unconfirmed.append(track)

0 commit comments

Comments
 (0)