1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
1
4
from typing import List , Optional , Union
2
5
3
6
import cv2
4
7
import numpy as np
5
8
6
9
from supervision .draw .color import Color , ColorPalette
10
+ from supervision .geometry .core import Position
7
11
8
12
13
+ @dataclass
9
14
class Detections :
10
- def __init__ (
11
- self ,
12
- xyxy : np .ndarray ,
13
- confidence : np .ndarray ,
14
- class_id : np .ndarray ,
15
- tracker_id : Optional [np .ndarray ] = None ,
16
- ):
17
- """
18
- Data class containing information about the detections in a video frame.
15
+ """
16
+ Data class containing information about the detections in a video frame.
19
17
20
- Attributes:
21
- xyxy (np.ndarray): An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2]
22
- confidence (np.ndarray): An array of shape (n,) containing the confidence scores of the detections.
23
- class_id (np.ndarray): An array of shape (n,) containing the class ids of the detections.
24
- tracker_id (Optional[np.ndarray]): An array of shape (n,) containing the tracker ids of the detections.
25
- """
26
- self .xyxy : np .ndarray = xyxy
27
- self .confidence : np .ndarray = confidence
28
- self .class_id : np .ndarray = class_id
29
- self .tracker_id : Optional [np .ndarray ] = tracker_id
18
+ Attributes:
19
+ xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`
20
+ confidence (np.ndarray): An array of shape `(n,)` containing the confidence scores of the detections.
21
+ class_id (np.ndarray): An array of shape `(n,)` containing the class ids of the detections.
22
+ tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections.
23
+ """
30
24
25
+ xyxy : np .ndarray
26
+ confidence : np .ndarray
27
+ class_id : np .ndarray
28
+ tracker_id : Optional [np .ndarray ] = None
29
+
30
+ def __post_init__ (self ):
31
31
n = len (self .xyxy )
32
32
validators = [
33
33
(isinstance (self .xyxy , np .ndarray ) and self .xyxy .shape == (n , 4 )),
@@ -55,7 +55,7 @@ def __len__(self):
55
55
56
56
def __iter__ (self ):
57
57
"""
58
- Iterates over the Detections object and yield a tuple of (xyxy, confidence, class_id, tracker_id) for each detection.
58
+ Iterates over the Detections object and yield a tuple of ` (xyxy, confidence, class_id, tracker_id)` for each detection.
59
59
"""
60
60
for i in range (len (self .xyxy )):
61
61
yield (
@@ -66,37 +66,68 @@ def __iter__(self):
66
66
)
67
67
68
68
@classmethod
69
- def from_yolov5 (cls , yolov5_output : np . ndarray ):
69
+ def from_yolov5 (cls , yolov5_detections ):
70
70
"""
71
- Creates a Detections instance from a YOLOv5 output tensor
71
+ Creates a Detections instance from a YOLOv5 output Detections
72
72
73
73
Attributes:
74
- yolov5_output (np.ndarray ): The output tensor from YOLOv5
74
+ yolov5_detections (yolov5.models.common.Detections ): The output Detections instance from YOLOv5
75
75
76
76
Returns:
77
77
78
78
Example:
79
79
```python
80
- >>> from supervision.tools.detections import Detections
80
+ >>> import torch
81
+ >>> from supervision import Detections
81
82
82
- >>> detections = Detections.from_yolov5(yolov5_output)
83
+ >>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
84
+ >>> results = model(frame)
85
+ >>> detections = Detections.from_yolov5(results)
83
86
```
84
87
"""
85
- xyxy = yolov5_output [:, :4 ]
86
- confidence = yolov5_output [:, 4 ]
87
- class_id = yolov5_output [:, 5 ].astype (int )
88
- return cls (xyxy , confidence , class_id )
88
+ yolov5_detections_predictions = yolov5_detections .pred [0 ].cpu ().cpu ().numpy ()
89
+ return cls (
90
+ xyxy = yolov5_detections_predictions [:, :4 ],
91
+ confidence = yolov5_detections_predictions [:, 4 ],
92
+ class_id = yolov5_detections_predictions [:, 5 ].astype (int ),
93
+ )
89
94
90
- def filter (self , mask : np .ndarray , inplace : bool = False ) -> Optional [np .ndarray ]:
95
+ @classmethod
96
+ def from_yolov8 (cls , yolov8_results ):
97
+ """
98
+ Creates a Detections instance from a YOLOv8 output Results
99
+
100
+ Attributes:
101
+ yolov8_results (ultralytics.yolo.engine.results.Results): The output Results instance from YOLOv8
102
+
103
+ Returns:
104
+
105
+ Example:
106
+ ```python
107
+ >>> from ultralytics import YOLO
108
+ >>> from supervision import Detections
109
+
110
+ >>> model = YOLO('yolov8s.pt')
111
+ >>> results = model(frame)
112
+ >>> detections = Detections.from_yolov8(results)
113
+ ```
114
+ """
115
+ return cls (
116
+ xyxy = yolov8_results .boxes .xyxy .cpu ().numpy (),
117
+ confidence = yolov8_results .boxes .conf .cpu ().numpy (),
118
+ class_id = yolov8_results .boxes .cls .cpu ().numpy ().astype (int ),
119
+ )
120
+
121
+ def filter (self , mask : np .ndarray , inplace : bool = False ) -> Optional [Detections ]:
91
122
"""
92
123
Filter the detections by applying a mask.
93
124
94
125
Attributes:
95
- mask (np.ndarray): A mask of shape (n,) containing a boolean value for each detection indicating if it should be included in the filtered detections
126
+ mask (np.ndarray): A mask of shape ` (n,)` containing a boolean value for each detection indicating if it should be included in the filtered detections
96
127
inplace (bool): If True, the original data will be modified and self will be returned.
97
128
98
129
Returns:
99
- Optional[np.ndarray]: A new instance of Detections with the filtered detections, if inplace is set to False. None otherwise.
130
+ Optional[np.ndarray]: A new instance of Detections with the filtered detections, if inplace is set to ` False`. ` None` otherwise.
100
131
"""
101
132
if inplace :
102
133
self .xyxy = self .xyxy [mask ]
@@ -116,11 +147,49 @@ def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray
116
147
else None ,
117
148
)
118
149
150
+ def get_anchor_coordinates (self , anchor : Position ) -> np .ndarray :
151
+ """
152
+ Returns the bounding box coordinates for a specific anchor.
153
+
154
+ Properties:
155
+ anchor (Position): Position of bounding box anchor for which to return the coordinates.
156
+
157
+ Returns:
158
+ np.ndarray: An array of shape `(n, 2)` containing the bounding box anchor coordinates in format `[x, y]`.
159
+ """
160
+ if anchor == Position .CENTER :
161
+ return np .array (
162
+ [
163
+ (self .xyxy [:, 0 ] + self .xyxy [:, 2 ]) / 2 ,
164
+ (self .xyxy [:, 1 ] + self .xyxy [:, 3 ]) / 2 ,
165
+ ]
166
+ ).transpose ()
167
+ elif anchor == Position .BOTTOM_CENTER :
168
+ return np .array (
169
+ [(self .xyxy [:, 0 ] + self .xyxy [:, 2 ]) / 2 , self .xyxy [:, 3 ]]
170
+ ).transpose ()
171
+
172
+ raise ValueError (f"{ anchor } is not supported." )
173
+
174
+ def __getitem__ (self , index : np .ndarray ) -> Detections :
175
+ if isinstance (index , np .ndarray ) and index .dtype == np .bool :
176
+ return Detections (
177
+ xyxy = self .xyxy [index ],
178
+ confidence = self .confidence [index ],
179
+ class_id = self .class_id [index ],
180
+ tracker_id = self .tracker_id [index ]
181
+ if self .tracker_id is not None
182
+ else None ,
183
+ )
184
+ raise TypeError (
185
+ f"Detections.__getitem__ not supported for index of type { type (index )} ."
186
+ )
187
+
119
188
120
189
class BoxAnnotator :
121
190
def __init__ (
122
191
self ,
123
- color : Union [Color , ColorPalette ],
192
+ color : Union [Color , ColorPalette ] = ColorPalette . default () ,
124
193
thickness : int = 2 ,
125
194
text_color : Color = Color .black (),
126
195
text_scale : float = 0.5 ,
@@ -148,35 +217,46 @@ def __init__(
148
217
149
218
def annotate (
150
219
self ,
151
- frame : np .ndarray ,
220
+ scene : np .ndarray ,
152
221
detections : Detections ,
153
222
labels : Optional [List [str ]] = None ,
223
+ skip_label : bool = False ,
154
224
) -> np .ndarray :
155
225
"""
156
226
Draws bounding boxes on the frame using the detections provided.
157
227
158
- Attributes :
159
- frame (np.ndarray): The image on which the bounding boxes will be drawn
228
+ Parameters :
229
+ scene (np.ndarray): The image on which the bounding boxes will be drawn
160
230
detections (Detections): The detections for which the bounding boxes will be drawn
161
231
labels (Optional[List[str]]): An optional list of labels corresponding to each detection. If labels is provided, the confidence score of the detection will be replaced with the label.
162
-
232
+ skip_label (bool): Is set to True, skips bounding box label annotation.
163
233
Returns:
164
234
np.ndarray: The image with the bounding boxes drawn on it
165
235
"""
166
236
font = cv2 .FONT_HERSHEY_SIMPLEX
167
237
for i , (xyxy , confidence , class_id , tracker_id ) in enumerate (detections ):
238
+ x1 , y1 , x2 , y2 = xyxy .astype (int )
168
239
color = (
169
240
self .color .by_idx (class_id )
170
241
if isinstance (self .color , ColorPalette )
171
242
else self .color
172
243
)
244
+ cv2 .rectangle (
245
+ img = scene ,
246
+ pt1 = (x1 , y1 ),
247
+ pt2 = (x2 , y2 ),
248
+ color = color .as_bgr (),
249
+ thickness = self .thickness ,
250
+ )
251
+ if skip_label :
252
+ continue
253
+
173
254
text = (
174
255
f"{ confidence :0.2f} "
175
256
if (labels is None or len (detections ) != len (labels ))
176
257
else labels [i ]
177
258
)
178
259
179
- x1 , y1 , x2 , y2 = xyxy .astype (int )
180
260
text_width , text_height = cv2 .getTextSize (
181
261
text = text ,
182
262
fontFace = font ,
@@ -194,21 +274,14 @@ def annotate(
194
274
text_background_y2 = y1
195
275
196
276
cv2 .rectangle (
197
- img = frame ,
198
- pt1 = (x1 , y1 ),
199
- pt2 = (x2 , y2 ),
200
- color = color .as_bgr (),
201
- thickness = self .thickness ,
202
- )
203
- cv2 .rectangle (
204
- img = frame ,
277
+ img = scene ,
205
278
pt1 = (text_background_x1 , text_background_y1 ),
206
279
pt2 = (text_background_x2 , text_background_y2 ),
207
280
color = color .as_bgr (),
208
281
thickness = cv2 .FILLED ,
209
282
)
210
283
cv2 .putText (
211
- img = frame ,
284
+ img = scene ,
212
285
text = text ,
213
286
org = (text_x , text_y ),
214
287
fontFace = font ,
@@ -217,4 +290,4 @@ def annotate(
217
290
thickness = self .text_thickness ,
218
291
lineType = cv2 .LINE_AA ,
219
292
)
220
- return frame
293
+ return scene
0 commit comments