Skip to content

Commit cd8a2be

Browse files
authored
Merge pull request #1178 from roboflow/feat/inference-slicer-segmentation
Feat/inference slicer segmentation
2 parents f41adca + 8901192 commit cd8a2be

File tree

5 files changed

+130
-10
lines changed

5 files changed

+130
-10
lines changed

docs/detection/utils.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ status: new
6565

6666
:::supervision.detection.utils.move_boxes
6767

68+
<div class="md-typeset">
69+
<h2><a href="#supervision.detection.utils.move_masks">move_masks</a></h2>
70+
</div>
71+
72+
:::supervision.detection.utils.move_masks
73+
6874
<div class="md-typeset">
6975
<h2><a href="#supervision.detection.utils.scale_boxes">scale_boxes</a></h2>
7076
</div>

docs/how_to/detect_small_objects.md

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,19 @@ size relative to the image resolution.
6868
import torch
6969
import supervision as sv
7070
from PIL import Image
71-
from transformers import DetrImageProcessor, DetrForObjectDetection
71+
from transformers import DetrImageProcessor, DetrForSegmentation
7272

7373
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
74-
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
74+
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50")
7575

7676
image = Image.open(<SOURCE_IMAGE_PATH>)
7777
inputs = processor(images=image, return_tensors="pt")
7878

7979
with torch.no_grad():
8080
outputs = model(**inputs)
8181

82-
width, height = image.size
83-
target_size = torch.tensor([[height, width]])
82+
width, height = image_slice.size
83+
target_size = torch.tensor([[width, height]])
8484
results = processor.post_process_object_detection(
8585
outputs=outputs, target_sizes=target_size)[0]
8686
detections = sv.Detections.from_transformers(results)
@@ -239,8 +239,8 @@ objects within each, and aggregating the results.
239239
with torch.no_grad():
240240
outputs = model(**inputs)
241241

242-
width, height = image.size
243-
target_size = torch.tensor([[height, width]])
242+
width, height = image_slice.size
243+
target_size = torch.tensor([[width, height]])
244244
results = processor.post_process_object_detection(
245245
outputs=outputs, target_sizes=target_size)[0]
246246
return sv.Detections.from_transformers(results)
@@ -264,3 +264,63 @@ objects within each, and aggregating the results.
264264
```
265265

266266
![detection-with-inference-slicer](https://media.roboflow.com/supervision_detect_small_objects_example_3.png)
267+
268+
## Small Object Segmentation
269+
270+
[`InferenceSlicer`](/latest/detection/tools/inference_slicer/#supervision.detection.tools.inference_slicer.InferenceSlicer) can perform segmentation tasks too.
271+
272+
=== "Inference"
273+
274+
```{ .py hl_lines="6 16 19-20" }
275+
import cv2
276+
import numpy as np
277+
import supervision as sv
278+
from inference import get_model
279+
280+
model = get_model(model_id="yolov8x-seg-640")
281+
image = cv2.imread(<SOURCE_IMAGE_PATH>)
282+
283+
def callback(image_slice: np.ndarray) -> sv.Detections:
284+
results = model.infer(image_slice)[0]
285+
detections = sv.Detections.from_inference(results)
286+
287+
slicer = sv.InferenceSlicer(callback = callback)
288+
detections = slicer(image)
289+
290+
mask_annotator = sv.MaskAnnotator()
291+
label_annotator = sv.LabelAnnotator()
292+
293+
annotated_image = mask_annotator.annotate(
294+
scene=image, detections=detections)
295+
annotated_image = label_annotator.annotate(
296+
scene=annotated_image, detections=detections)
297+
```
298+
299+
=== "Ultralytics"
300+
301+
```{ .py hl_lines="6 16 19-20" }
302+
import cv2
303+
import numpy as np
304+
import supervision as sv
305+
from ultralytics import YOLO
306+
307+
model = YOLO("yolov8x-seg.pt")
308+
image = cv2.imread(<SOURCE_IMAGE_PATH>)
309+
310+
def callback(image_slice: np.ndarray) -> sv.Detections:
311+
result = model(image_slice)[0]
312+
return sv.Detections.from_ultralytics(result)
313+
314+
slicer = sv.InferenceSlicer(callback = callback)
315+
detections = slicer(image)
316+
317+
mask_annotator = sv.MaskAnnotator()
318+
label_annotator = sv.LabelAnnotator()
319+
320+
annotated_image = mask_annotator.annotate(
321+
scene=image, detections=detections)
322+
annotated_image = label_annotator.annotate(
323+
scene=annotated_image, detections=detections)
324+
```
325+
326+
![detection-with-inference-slicer](https://media.roboflow.com/supervision-docs/inference-slicer-segmentation-example.png)

supervision/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
mask_to_polygons,
5454
mask_to_xyxy,
5555
move_boxes,
56+
move_masks,
5657
pad_boxes,
5758
polygon_to_mask,
5859
polygon_to_xyxy,

supervision/detection/tools/inference_slicer.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,36 @@
44
import numpy as np
55

66
from supervision.detection.core import Detections
7-
from supervision.detection.utils import move_boxes
7+
from supervision.detection.utils import move_boxes, move_masks
88
from supervision.utils.image import crop_image
99

1010

11-
def move_detections(detections: Detections, offset: np.array) -> Detections:
11+
def move_detections(
12+
detections: Detections,
13+
offset: np.ndarray,
14+
resolution_wh: Optional[Tuple[int, int]] = None,
15+
) -> Detections:
1216
"""
1317
Args:
1418
detections (sv.Detections): Detections object to be moved.
15-
offset (np.array): An array of shape `(2,)` containing offset values in format
19+
offset (np.ndarray): An array of shape `(2,)` containing offset values in format
1620
is `[dx, dy]`.
21+
resolution_wh (Tuple[int, int]): The width and height of the desired mask
22+
resolution. Required for segmentation detections.
23+
1724
Returns:
1825
(sv.Detections) repositioned Detections object.
1926
"""
2027
detections.xyxy = move_boxes(xyxy=detections.xyxy, offset=offset)
28+
if detections.mask is not None:
29+
if resolution_wh is None:
30+
raise ValueError(
31+
"Resolution width and height are required for moving segmentation "
32+
"detections. This should be the same as (width, height) of image shape."
33+
)
34+
detections.mask = move_masks(
35+
masks=detections.mask, offset=offset, resolution_wh=resolution_wh
36+
)
2137
return detections
2238

2339

@@ -126,7 +142,10 @@ def _run_callback(self, image, offset) -> Detections:
126142
"""
127143
image_slice = crop_image(image=image, xyxy=offset)
128144
detections = self.callback(image_slice)
129-
detections = move_detections(detections=detections, offset=offset[:2])
145+
resolution_wh = (image.shape[1], image.shape[0])
146+
detections = move_detections(
147+
detections=detections, offset=offset[:2], resolution_wh=resolution_wh
148+
)
130149

131150
return detections
132151

supervision/detection/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,40 @@ def move_boxes(xyxy: np.ndarray, offset: np.ndarray) -> np.ndarray:
621621
return xyxy + np.hstack([offset, offset])
622622

623623

624+
def move_masks(
625+
masks: np.ndarray,
626+
offset: np.ndarray,
627+
resolution_wh: Tuple[int, int] = None,
628+
) -> np.ndarray:
629+
"""
630+
Offset the masks in an array by the specified (x, y) amount.
631+
632+
Args:
633+
masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
634+
Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the
635+
dimensions of each mask.
636+
offset (np.ndarray): An array of shape `(2,)` containing non-negative int values
637+
`[dx, dy]`.
638+
resolution_wh (Tuple[int, int]): The width and height of the desired mask
639+
resolution.
640+
641+
Returns:
642+
(np.ndarray) repositioned masks, optionally padded to the specified shape.
643+
"""
644+
645+
if offset[0] < 0 or offset[1] < 0:
646+
raise ValueError(f"Offset values must be non-negative integers. Got: {offset}")
647+
648+
mask_array = np.full((masks.shape[0], resolution_wh[1], resolution_wh[0]), False)
649+
mask_array[
650+
:,
651+
offset[1] : masks.shape[1] + offset[1],
652+
offset[0] : masks.shape[2] + offset[0],
653+
] = masks
654+
655+
return mask_array
656+
657+
624658
def scale_boxes(xyxy: np.ndarray, factor: float) -> np.ndarray:
625659
"""
626660
Scale the dimensions of bounding boxes.

0 commit comments

Comments
 (0)