1
1
from __future__ import annotations
2
2
3
3
from dataclasses import dataclass
4
- from typing import List , Optional , Union
4
+ from typing import Iterator , List , Optional , Tuple , Union
5
5
6
6
import cv2
7
7
import numpy as np
8
8
9
+ from supervision .detection .utils import non_max_suppression
9
10
from supervision .draw .color import Color , ColorPalette
10
11
from supervision .geometry .core import Position
11
12
@@ -17,22 +18,26 @@ class Detections:
17
18
18
19
Attributes:
19
20
xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`
20
- confidence (np.ndarray): An array of shape `(n,)` containing the confidence scores of the detections.
21
+ confidence (Optional[ np.ndarray] ): An array of shape `(n,)` containing the confidence scores of the detections.
21
22
class_id (np.ndarray): An array of shape `(n,)` containing the class ids of the detections.
22
23
tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections.
23
24
"""
24
25
25
26
xyxy : np .ndarray
26
- confidence : np .ndarray
27
27
class_id : np .ndarray
28
+ confidence : Optional [np .ndarray ] = None
28
29
tracker_id : Optional [np .ndarray ] = None
29
30
30
31
def __post_init__ (self ):
31
32
n = len (self .xyxy )
32
33
validators = [
33
34
(isinstance (self .xyxy , np .ndarray ) and self .xyxy .shape == (n , 4 )),
34
- (isinstance (self .confidence , np .ndarray ) and self .confidence .shape == (n ,)),
35
35
(isinstance (self .class_id , np .ndarray ) and self .class_id .shape == (n ,)),
36
+ self .confidence is None
37
+ or (
38
+ isinstance (self .confidence , np .ndarray )
39
+ and self .confidence .shape == (n ,)
40
+ ),
36
41
self .tracker_id is None
37
42
or (
38
43
isinstance (self .tracker_id , np .ndarray )
@@ -42,7 +47,7 @@ def __post_init__(self):
42
47
if not all (validators ):
43
48
raise ValueError (
44
49
"xyxy must be 2d np.ndarray with (n, 4) shape, "
45
- "confidence must be 1d np.ndarray with (n,) shape, "
50
+ "confidence must be None or 1d np.ndarray with (n,) shape, "
46
51
"class_id must be 1d np.ndarray with (n,) shape, "
47
52
"tracker_id must be None or 1d np.ndarray with (n,) shape"
48
53
)
@@ -53,14 +58,16 @@ def __len__(self):
53
58
"""
54
59
return len (self .xyxy )
55
60
56
- def __iter__ (self ):
61
+ def __iter__ (
62
+ self ,
63
+ ) -> Iterator [Tuple [np .ndarray , Optional [float ], int , Optional [Union [str , int ]]]]:
57
64
"""
58
65
Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection.
59
66
"""
60
67
for i in range (len (self .xyxy )):
61
68
yield (
62
69
self .xyxy [i ],
63
- self .confidence [i ],
70
+ self .confidence [i ] if self . confidence is not None else None ,
64
71
self .class_id [i ],
65
72
self .tracker_id [i ] if self .tracker_id is not None else None ,
66
73
)
@@ -69,11 +76,17 @@ def __eq__(self, other: Detections):
69
76
return all (
70
77
[
71
78
np .array_equal (self .xyxy , other .xyxy ),
72
- np .array_equal (self .confidence , other .confidence ),
79
+ any (
80
+ [
81
+ self .confidence is None and other .confidence is None ,
82
+ np .array_equal (self .confidence , other .confidence ),
83
+ ]
84
+ ),
73
85
np .array_equal (self .class_id , other .class_id ),
74
86
any (
75
87
[
76
88
self .tracker_id is None and other .tracker_id is None ,
89
+ np .array_equal (self .tracker_id , other .tracker_id ),
77
90
]
78
91
),
79
92
]
@@ -122,7 +135,7 @@ def from_yolov8(cls, yolov8_results):
122
135
>>> from supervision import Detections
123
136
124
137
>>> model = YOLO('yolov8s.pt')
125
- >>> results = model(frame)
138
+ >>> results = model(frame)[0]
126
139
>>> detections = Detections.from_yolov8(results)
127
140
```
128
141
"""
@@ -132,6 +145,36 @@ def from_yolov8(cls, yolov8_results):
132
145
class_id = yolov8_results .boxes .cls .cpu ().numpy ().astype (int ),
133
146
)
134
147
148
+ @classmethod
149
+ def from_transformers (cls , transformers_results : dict ):
150
+ return cls (
151
+ xyxy = transformers_results ["boxes" ].cpu ().numpy (),
152
+ confidence = transformers_results ["scores" ].cpu ().numpy (),
153
+ class_id = transformers_results ["labels" ].cpu ().numpy ().astype (int ),
154
+ )
155
+
156
+ @classmethod
157
+ def from_detectron2 (cls , detectron2_results ):
158
+ return cls (
159
+ xyxy = detectron2_results ["instances" ].pred_boxes .tensor .cpu ().numpy (),
160
+ confidence = detectron2_results ["instances" ].scores .cpu ().numpy (),
161
+ class_id = detectron2_results ["instances" ]
162
+ .pred_classes .cpu ()
163
+ .numpy ()
164
+ .astype (int ),
165
+ )
166
+
167
+ @classmethod
168
+ def from_coco_annotations (cls , coco_annotation : dict ):
169
+ xyxy , class_id = [], []
170
+
171
+ for annotation in coco_annotation :
172
+ x_min , y_min , width , height = annotation ["bbox" ]
173
+ xyxy .append ([x_min , y_min , x_min + width , y_min + height ])
174
+ class_id .append (annotation ["category_id" ])
175
+
176
+ return cls (xyxy = np .array (xyxy ), class_id = np .array (class_id ))
177
+
135
178
def filter (self , mask : np .ndarray , inplace : bool = False ) -> Optional [Detections ]:
136
179
"""
137
180
Filter the detections by applying a mask.
@@ -186,7 +229,9 @@ def get_anchor_coordinates(self, anchor: Position) -> np.ndarray:
186
229
raise ValueError (f"{ anchor } is not supported." )
187
230
188
231
def __getitem__ (self , index : np .ndarray ) -> Detections :
189
- if isinstance (index , np .ndarray ) and index .dtype == bool :
232
+ if isinstance (index , np .ndarray ) and (
233
+ index .dtype == bool or index .dtype == int
234
+ ):
190
235
return Detections (
191
236
xyxy = self .xyxy [index ],
192
237
confidence = self .confidence [index ],
@@ -199,6 +244,17 @@ def __getitem__(self, index: np.ndarray) -> Detections:
199
244
f"Detections.__getitem__ not supported for index of type { type (index )} ."
200
245
)
201
246
247
+ @property
248
+ def area (self ) -> np .ndarray :
249
+ return (self .xyxy [:, 3 ] - self .xyxy [:, 1 ]) * (self .xyxy [:, 2 ] - self .xyxy [:, 0 ])
250
+
251
+ def with_nms (self , threshold : float = 0.5 ) -> Detections :
252
+ assert (
253
+ self .confidence is not None
254
+ ), f"Detections confidence must be given for NMS to be executed."
255
+ indices = non_max_suppression (self .xyxy , self .confidence , threshold = threshold )
256
+ return self [indices ]
257
+
202
258
203
259
class BoxAnnotator :
204
260
def __init__ (
@@ -266,7 +322,7 @@ def annotate(
266
322
continue
267
323
268
324
text = (
269
- f"{ confidence :0.2f } "
325
+ f"{ class_id } "
270
326
if (labels is None or len (detections ) != len (labels ))
271
327
else labels [i ]
272
328
)
0 commit comments