Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/detection/tools/slicer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## Slicer

:::supervision.detection.tools.slicer.Slicer
1 change: 1 addition & 0 deletions supervision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from supervision.detection.core import Detections
from supervision.detection.line_counter import LineZone, LineZoneAnnotator
from supervision.detection.tools.polygon_zone import PolygonZone, PolygonZoneAnnotator
from supervision.detection.tools.slicer import Slicer
from supervision.detection.utils import (
box_iou_batch,
filter_polygons_by_area,
Expand Down
7 changes: 7 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ def _validate_mask(mask: Any, n: int) -> None:
raise ValueError("mask must be 3d np.ndarray with (n, H, W) shape")


def validate_inference_callback(callback) -> None:
tmp_img = np.zeros((256, 256, 3), dtype=np.uint8)
res = callback(tmp_img)
if not isinstance(res, Detections):
raise ValueError("Callback function must return sv.Detection type")


def _validate_class_id(class_id: Any, n: int) -> None:
is_valid = class_id is None or (
isinstance(class_id, np.ndarray) and class_id.shape == (n,)
Expand Down
127 changes: 127 additions & 0 deletions supervision/detection/tools/slicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Callable, Optional, Tuple

import numpy as np

from supervision.detection.core import Detections, validate_inference_callback
from supervision.detection.utils import move_boxes


class Slicer:
"""
Slicing inference(SAHI) method for small target detection.
"""

def __init__(
self,
callback: Callable[[np.ndarray], Detections],
sliced_width: Optional[int] = 320,
sliced_height: Optional[int] = 320,
overlap_width_ratio: Optional[float] = 0.2,
overlap_height_ratio: Optional[float] = 0.2,
iou_threshold: Optional[float] = 0.5,
):
"""
Args:
callback (Callable): model callback method which returns detections as sv.Detections
sliced_width (int): width of each slice
sliced_height (int): height of each slice
overlap_width_ratio (float): Fractional overlap in width of each
slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
overlap of 20 pixels). Default 0.2.
overlap_height_ratio (float): Fractional overlap in height of each
slice (e.g. an overlap of 0.2 for a slice of size 100 yields an
overlap of 20 pixels). Default 0.2.
iou_threshold (float): non-maximum suppression iou threshold to remove overlapping detections
"""
self.siced_width = sliced_width
self.sliced_height = sliced_height
self.overlap_width_ratio = overlap_width_ratio
self.overlap_height_ratio = overlap_height_ratio
self.iou_threshold = iou_threshold
self.callback = callback
validate_inference_callback(callback=callback)

def __call__(self, image: np.ndarray) -> Detections:
"""

Args:
image (np.ndarray):

Returns:
sv.Detections

Example:
```python
>>> import supervision as sv
>>> from ultralytics import YOLO

>>> dataset = sv.DetectionDataset.from_yolo(...)

>>> model = YOLO(...)
>>> def callback(slice: np.ndarray) -> sv.Detections:
... result = model(slice)[0]
... return sv.Detections.from_ultralytics(result)

>>> slicer = sv.Slicer(
... callback = callback
... )

>>> detections = slicer(image)
```
"""
detections = []
image_height, image_width, _ = image.shape
offsets = self._offset_generation(
image_width=image_width, image_height=image_height
)

for offset in offsets:
slice = image[offset[1] : offset[3], offset[0] : offset[2]]
det = self.callback(slice)
det = self._reposition_detections(detection=det, offset=offset)
detections.append(det)
detection = Detections.merge(detections_list=detections).with_nms(
threshold=self.iou_threshold
)
return detection

def _offset_generation(self, image_width: int, image_height: int) -> np.ndarray:
"""
Args:
image_width (int): width of the input image
image_height (int): height of the input image

Returns:
list of slice locations according to slicer parameters
"""
width_stride = self.siced_width - int(
self.overlap_width_ratio * self.siced_width
)
height_stride = self.sliced_height - int(
self.overlap_height_ratio * self.sliced_height
)
offsets = []
for h in range(0, image_height, height_stride):
for w in range(0, image_width, width_stride):
xmin = w
ymin = h
xmax = min(image_width, w + self.siced_width)
ymax = min(image_height, h + self.sliced_height)
offsets.append([xmin, ymin, xmax, ymax])
offsets = np.asarray(offsets)
return offsets

@staticmethod
def _reposition_detections(detection: Detections, offset: np.array) -> Detections:
"""
Args:
detection (np.ndarray): result of model inference of the slice
slice_location (Tuple[int, int, int, int]): slice location at which inference was performed
Returns:
(sv.Detections) repositioned detections result based on original image
"""
if len(detection) == 0:
return detection
xyxy = move_boxes(boxes=detection.xyxy, offset=offset)
detection.xyxy = xyxy
return detection
13 changes: 13 additions & 0 deletions supervision/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,16 @@ def process_roboflow_result(
masks = np.array(masks, dtype=bool) if len(masks) > 0 else None

return xyxy, confidence, class_id, masks


def move_boxes(boxes: np.ndarray, offset: np.array) -> np.ndarray:
"""
Args:
boxes (np.ndarray): boxes of model inference of the slice
offset (np.array): slice location at which inference was performed i.e. np.array([x1, y1, x2, y2])

Returns:
(np.ndarray) repositioned bounding boxes
"""
offsets = np.array([offset[0], offset[1], offset[0], offset[1]])
return boxes + offsets