Skip to content

Commit 2e76e3d

Browse files
authored
Merge pull request #15 from roboflow/feature/polygon-zone
feature/polygon-zone
2 parents ceabb83 + a0eb6c3 commit 2e76e3d

File tree

21 files changed

+420
-114
lines changed

21 files changed

+420
-114
lines changed

docs/detection_core.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Detections
2+
3+
:::supervision.detection.core.Detections

docs/detection_utils.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## generate_2d_mask
2+
3+
:::supervision.detection.utils.generate_2d_mask

docs/draw.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
Utilities for drawing on images.
2-
3-
## Draw Line
1+
## draw_line
42

53
:::supervision.draw.utils.draw_line
64

7-
## Draw Rectangle
5+
## draw_rectangle
86

97
:::supervision.draw.utils.draw_rectangle

docs/notebook.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
Utilities to help you build computer vision projects in notebook environments.
1+
## show_frame_in_notebook
22

33
:::supervision.notebook.utils.show_frame_in_notebook

docs/tools.md

Lines changed: 0 additions & 9 deletions
This file was deleted.

mkdocs.yml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@ extra:
2222
property: G-P7ZG0Y19G5
2323

2424
nav:
25-
- Home 🏠: index.md
26-
- Video 📷: video.md
27-
- Notebook Helpers 📓: notebook.md
28-
- Draw 🎨: draw.md
29-
- Geometry 📐: geometry.md
30-
- Tools 🛠: tools.md
25+
- Home: index.md
26+
- API reference:
27+
- Video: video.md
28+
- Detection:
29+
- Core: detection_core.md
30+
- Utils: detection_utils.md
31+
- Draw: draw.md
32+
- Geometry: geometry.md
33+
- Notebook: notebook.md
3134

3235
theme:
3336
name: 'material'

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ def get_version():
2424
long_description_content_type='text/markdown',
2525
url='https://github.com/roboflow/supervision',
2626
install_requires=[
27-
'numpy',
28-
'opencv-python'
27+
'numpy',
28+
'opencv-python',
29+
'matplotlib'
2930
],
3031
packages=find_packages(exclude=("tests",)),
3132
extras_require={

supervision/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,16 @@
1-
__version__ = "0.1.0"
1+
__version__ = "0.2.0"
2+
3+
from supervision.detection.core import BoxAnnotator, Detections
4+
from supervision.detection.polygon_zone import PolygonZone, PolygonZoneAnnotator
5+
from supervision.detection.utils import generate_2d_mask
6+
from supervision.draw.color import Color, ColorPalette
7+
from supervision.draw.utils import draw_filled_rectangle, draw_polygon, draw_text
8+
from supervision.geometry.core import Point, Position, Rect
9+
from supervision.geometry.utils import get_polygon_center
10+
from supervision.notebook.utils import show_frame_in_notebook
11+
from supervision.video import (
12+
VideoInfo,
13+
VideoSink,
14+
get_video_frames_generator,
15+
process_video,
16+
)
File renamed without changes.

supervision/tools/detections.py renamed to supervision/detection/core.py

Lines changed: 121 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,33 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
14
from typing import List, Optional, Union
25

36
import cv2
47
import numpy as np
58

69
from supervision.draw.color import Color, ColorPalette
10+
from supervision.geometry.core import Position
711

812

13+
@dataclass
914
class Detections:
10-
def __init__(
11-
self,
12-
xyxy: np.ndarray,
13-
confidence: np.ndarray,
14-
class_id: np.ndarray,
15-
tracker_id: Optional[np.ndarray] = None,
16-
):
17-
"""
18-
Data class containing information about the detections in a video frame.
15+
"""
16+
Data class containing information about the detections in a video frame.
1917
20-
Attributes:
21-
xyxy (np.ndarray): An array of shape (n, 4) containing the bounding boxes coordinates in format [x1, y1, x2, y2]
22-
confidence (np.ndarray): An array of shape (n,) containing the confidence scores of the detections.
23-
class_id (np.ndarray): An array of shape (n,) containing the class ids of the detections.
24-
tracker_id (Optional[np.ndarray]): An array of shape (n,) containing the tracker ids of the detections.
25-
"""
26-
self.xyxy: np.ndarray = xyxy
27-
self.confidence: np.ndarray = confidence
28-
self.class_id: np.ndarray = class_id
29-
self.tracker_id: Optional[np.ndarray] = tracker_id
18+
Attributes:
19+
xyxy (np.ndarray): An array of shape `(n, 4)` containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`
20+
confidence (np.ndarray): An array of shape `(n,)` containing the confidence scores of the detections.
21+
class_id (np.ndarray): An array of shape `(n,)` containing the class ids of the detections.
22+
tracker_id (Optional[np.ndarray]): An array of shape `(n,)` containing the tracker ids of the detections.
23+
"""
3024

25+
xyxy: np.ndarray
26+
confidence: np.ndarray
27+
class_id: np.ndarray
28+
tracker_id: Optional[np.ndarray] = None
29+
30+
def __post_init__(self):
3131
n = len(self.xyxy)
3232
validators = [
3333
(isinstance(self.xyxy, np.ndarray) and self.xyxy.shape == (n, 4)),
@@ -55,7 +55,7 @@ def __len__(self):
5555

5656
def __iter__(self):
5757
"""
58-
Iterates over the Detections object and yield a tuple of (xyxy, confidence, class_id, tracker_id) for each detection.
58+
Iterates over the Detections object and yield a tuple of `(xyxy, confidence, class_id, tracker_id)` for each detection.
5959
"""
6060
for i in range(len(self.xyxy)):
6161
yield (
@@ -66,37 +66,68 @@ def __iter__(self):
6666
)
6767

6868
@classmethod
69-
def from_yolov5(cls, yolov5_output: np.ndarray):
69+
def from_yolov5(cls, yolov5_detections):
7070
"""
71-
Creates a Detections instance from a YOLOv5 output tensor
71+
Creates a Detections instance from a YOLOv5 output Detections
7272
7373
Attributes:
74-
yolov5_output (np.ndarray): The output tensor from YOLOv5
74+
yolov5_detections (yolov5.models.common.Detections): The output Detections instance from YOLOv5
7575
7676
Returns:
7777
7878
Example:
7979
```python
80-
>>> from supervision.tools.detections import Detections
80+
>>> import torch
81+
>>> from supervision import Detections
8182
82-
>>> detections = Detections.from_yolov5(yolov5_output)
83+
>>> model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
84+
>>> results = model(frame)
85+
>>> detections = Detections.from_yolov5(results)
8386
```
8487
"""
85-
xyxy = yolov5_output[:, :4]
86-
confidence = yolov5_output[:, 4]
87-
class_id = yolov5_output[:, 5].astype(int)
88-
return cls(xyxy, confidence, class_id)
88+
yolov5_detections_predictions = yolov5_detections.pred[0].cpu().cpu().numpy()
89+
return cls(
90+
xyxy=yolov5_detections_predictions[:, :4],
91+
confidence=yolov5_detections_predictions[:, 4],
92+
class_id=yolov5_detections_predictions[:, 5].astype(int),
93+
)
8994

90-
def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray]:
95+
@classmethod
96+
def from_yolov8(cls, yolov8_results):
97+
"""
98+
Creates a Detections instance from a YOLOv8 output Results
99+
100+
Attributes:
101+
yolov8_results (ultralytics.yolo.engine.results.Results): The output Results instance from YOLOv8
102+
103+
Returns:
104+
105+
Example:
106+
```python
107+
>>> from ultralytics import YOLO
108+
>>> from supervision import Detections
109+
110+
>>> model = YOLO('yolov8s.pt')
111+
>>> results = model(frame)
112+
>>> detections = Detections.from_yolov8(results)
113+
```
114+
"""
115+
return cls(
116+
xyxy=yolov8_results.boxes.xyxy.cpu().numpy(),
117+
confidence=yolov8_results.boxes.conf.cpu().numpy(),
118+
class_id=yolov8_results.boxes.cls.cpu().numpy().astype(int),
119+
)
120+
121+
def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[Detections]:
91122
"""
92123
Filter the detections by applying a mask.
93124
94125
Attributes:
95-
mask (np.ndarray): A mask of shape (n,) containing a boolean value for each detection indicating if it should be included in the filtered detections
126+
mask (np.ndarray): A mask of shape `(n,)` containing a boolean value for each detection indicating if it should be included in the filtered detections
96127
inplace (bool): If True, the original data will be modified and self will be returned.
97128
98129
Returns:
99-
Optional[np.ndarray]: A new instance of Detections with the filtered detections, if inplace is set to False. None otherwise.
130+
Optional[np.ndarray]: A new instance of Detections with the filtered detections, if inplace is set to `False`. `None` otherwise.
100131
"""
101132
if inplace:
102133
self.xyxy = self.xyxy[mask]
@@ -116,11 +147,49 @@ def filter(self, mask: np.ndarray, inplace: bool = False) -> Optional[np.ndarray
116147
else None,
117148
)
118149

150+
def get_anchor_coordinates(self, anchor: Position) -> np.ndarray:
151+
"""
152+
Returns the bounding box coordinates for a specific anchor.
153+
154+
Properties:
155+
anchor (Position): Position of bounding box anchor for which to return the coordinates.
156+
157+
Returns:
158+
np.ndarray: An array of shape `(n, 2)` containing the bounding box anchor coordinates in format `[x, y]`.
159+
"""
160+
if anchor == Position.CENTER:
161+
return np.array(
162+
[
163+
(self.xyxy[:, 0] + self.xyxy[:, 2]) / 2,
164+
(self.xyxy[:, 1] + self.xyxy[:, 3]) / 2,
165+
]
166+
).transpose()
167+
elif anchor == Position.BOTTOM_CENTER:
168+
return np.array(
169+
[(self.xyxy[:, 0] + self.xyxy[:, 2]) / 2, self.xyxy[:, 3]]
170+
).transpose()
171+
172+
raise ValueError(f"{anchor} is not supported.")
173+
174+
def __getitem__(self, index: np.ndarray) -> Detections:
175+
if isinstance(index, np.ndarray) and index.dtype == np.bool:
176+
return Detections(
177+
xyxy=self.xyxy[index],
178+
confidence=self.confidence[index],
179+
class_id=self.class_id[index],
180+
tracker_id=self.tracker_id[index]
181+
if self.tracker_id is not None
182+
else None,
183+
)
184+
raise TypeError(
185+
f"Detections.__getitem__ not supported for index of type {type(index)}."
186+
)
187+
119188

120189
class BoxAnnotator:
121190
def __init__(
122191
self,
123-
color: Union[Color, ColorPalette],
192+
color: Union[Color, ColorPalette] = ColorPalette.default(),
124193
thickness: int = 2,
125194
text_color: Color = Color.black(),
126195
text_scale: float = 0.5,
@@ -148,35 +217,46 @@ def __init__(
148217

149218
def annotate(
150219
self,
151-
frame: np.ndarray,
220+
scene: np.ndarray,
152221
detections: Detections,
153222
labels: Optional[List[str]] = None,
223+
skip_label: bool = False,
154224
) -> np.ndarray:
155225
"""
156226
Draws bounding boxes on the frame using the detections provided.
157227
158-
Attributes:
159-
frame (np.ndarray): The image on which the bounding boxes will be drawn
228+
Parameters:
229+
scene (np.ndarray): The image on which the bounding boxes will be drawn
160230
detections (Detections): The detections for which the bounding boxes will be drawn
161231
labels (Optional[List[str]]): An optional list of labels corresponding to each detection. If labels is provided, the confidence score of the detection will be replaced with the label.
162-
232+
skip_label (bool): Is set to True, skips bounding box label annotation.
163233
Returns:
164234
np.ndarray: The image with the bounding boxes drawn on it
165235
"""
166236
font = cv2.FONT_HERSHEY_SIMPLEX
167237
for i, (xyxy, confidence, class_id, tracker_id) in enumerate(detections):
238+
x1, y1, x2, y2 = xyxy.astype(int)
168239
color = (
169240
self.color.by_idx(class_id)
170241
if isinstance(self.color, ColorPalette)
171242
else self.color
172243
)
244+
cv2.rectangle(
245+
img=scene,
246+
pt1=(x1, y1),
247+
pt2=(x2, y2),
248+
color=color.as_bgr(),
249+
thickness=self.thickness,
250+
)
251+
if skip_label:
252+
continue
253+
173254
text = (
174255
f"{confidence:0.2f}"
175256
if (labels is None or len(detections) != len(labels))
176257
else labels[i]
177258
)
178259

179-
x1, y1, x2, y2 = xyxy.astype(int)
180260
text_width, text_height = cv2.getTextSize(
181261
text=text,
182262
fontFace=font,
@@ -194,21 +274,14 @@ def annotate(
194274
text_background_y2 = y1
195275

196276
cv2.rectangle(
197-
img=frame,
198-
pt1=(x1, y1),
199-
pt2=(x2, y2),
200-
color=color.as_bgr(),
201-
thickness=self.thickness,
202-
)
203-
cv2.rectangle(
204-
img=frame,
277+
img=scene,
205278
pt1=(text_background_x1, text_background_y1),
206279
pt2=(text_background_x2, text_background_y2),
207280
color=color.as_bgr(),
208281
thickness=cv2.FILLED,
209282
)
210283
cv2.putText(
211-
img=frame,
284+
img=scene,
212285
text=text,
213286
org=(text_x, text_y),
214287
fontFace=font,
@@ -217,4 +290,4 @@ def annotate(
217290
thickness=self.text_thickness,
218291
lineType=cv2.LINE_AA,
219292
)
220-
return frame
293+
return scene

0 commit comments

Comments
 (0)