-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Add Non-Maximum Merging (NMM) to Detections #500
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c78ae33
57b12e6
9f22273
6f47046
5f0dcc2
166a8da
b159873
d7e52be
bee3252
97c4071
204669b
2eb0c7c
c3b77d0
8014e88
26bafec
d2d50fb
2d740bd
145b5fe
6c40935
53f345e
0e2eec0
559ef90
f8f3647
9024396
6fbca83
db1b473
0721bc2
530e1d0
2ee9e08
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,8 @@ | |
|
||
from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES | ||
from supervision.detection.utils import ( | ||
box_iou_batch, | ||
box_non_max_merge, | ||
box_non_max_suppression, | ||
calculate_masks_centroids, | ||
extract_ultralytics_masks, | ||
|
@@ -1150,3 +1152,193 @@ def with_nms( | |
) | ||
|
||
return self[indices] | ||
|
||
def with_nmm( | ||
self, threshold: float = 0.5, class_agnostic: bool = False | ||
) -> Detections: | ||
""" | ||
Perform non-maximum merging on the current set of object detections. | ||
|
||
Args: | ||
threshold (float, optional): The intersection-over-union threshold | ||
to use for non-maximum merging. Defaults to 0.5. | ||
class_agnostic (bool, optional): Whether to perform class-agnostic | ||
non-maximum merging. If True, the class_id of each detection | ||
will be ignored. Defaults to False. | ||
|
||
Returns: | ||
Detections: A new Detections object containing the subset of detections | ||
after non-maximum merging. | ||
|
||
Raises: | ||
AssertionError: If `confidence` is None or `class_id` is None and | ||
class_agnostic is False. | ||
""" | ||
if len(self) == 0: | ||
return self | ||
|
||
assert ( | ||
self.confidence is not None | ||
), "Detections confidence must be given for NMM to be executed." | ||
|
||
if class_agnostic: | ||
predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1))) | ||
else: | ||
assert self.class_id is not None, ( | ||
"Detections class_id must be given for NMM to be executed. If you" | ||
" intended to perform class agnostic NMM set class_agnostic=True." | ||
) | ||
predictions = np.hstack( | ||
( | ||
self.xyxy, | ||
self.confidence.reshape(-1, 1), | ||
self.class_id.reshape(-1, 1), | ||
) | ||
) | ||
|
||
merge_groups = box_non_max_merge( | ||
predictions=predictions, iou_threshold=threshold | ||
) | ||
|
||
result = [] | ||
for merge_group in merge_groups: | ||
unmerged_detections = [self[i] for i in merge_group] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we don't need that list comprehension, just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My explanation was wrong. We're doing this not to copy the result (that's in another case), but to create a list of single-object detections. I believe this is the most concise way. |
||
merged_detections = merge_inner_detections_objects( | ||
unmerged_detections, threshold | ||
) | ||
result.append(merged_detections) | ||
|
||
return Detections.merge(result) | ||
|
||
|
||
def merge_inner_detection_object_pair( | ||
detections_1: Detections, detections_2: Detections | ||
) -> Detections: | ||
""" | ||
Merges two Detections object into a single Detections object. | ||
Assumes each Detections contains exactly one object. | ||
|
||
A `winning` detection is determined based on the confidence score of the two | ||
input detections. This winning detection is then used to specify which | ||
`class_id`, `tracker_id`, and `data` to include in the merged Detections object. | ||
|
||
The resulting `confidence` of the merged object is calculated by the weighted | ||
contribution of ea detection to the merged object. | ||
The bounding boxes and masks of the two input detections are merged into a | ||
single bounding box and mask, respectively. | ||
|
||
Args: | ||
detections_1 (Detections): | ||
The first Detections object | ||
detections_2 (Detections): | ||
The second Detections object | ||
|
||
Returns: | ||
Detections: A new Detections object, with merged attributes. | ||
|
||
Raises: | ||
ValueError: If the input Detections objects do not have exactly 1 detected | ||
object. | ||
|
||
Example: | ||
```python | ||
import cv2 | ||
import supervision as sv | ||
from inference import get_model | ||
|
||
image = cv2.imread(<SOURCE_IMAGE_PATH>) | ||
model = get_model(model_id="yolov8s-640") | ||
|
||
result = model.infer(image)[0] | ||
detections = sv.Detections.from_inference(result) | ||
|
||
merged_detections = merge_object_detection_pair( | ||
detections[0], detections[1]) | ||
``` | ||
""" | ||
if len(detections_1) != 1 or len(detections_2) != 1: | ||
raise ValueError("Both Detections should have exactly 1 detected object.") | ||
|
||
validate_fields_both_defined_or_none(detections_1, detections_2) | ||
|
||
xyxy_1 = detections_1.xyxy[0] | ||
xyxy_2 = detections_2.xyxy[0] | ||
if detections_1.confidence is None and detections_2.confidence is None: | ||
merged_confidence = None | ||
else: | ||
detection_1_area = (xyxy_1[2] - xyxy_1[0]) * (xyxy_1[3] - xyxy_1[1]) | ||
detections_2_area = (xyxy_2[2] - xyxy_2[0]) * (xyxy_2[3] - xyxy_2[1]) | ||
merged_confidence = ( | ||
detection_1_area * detections_1.confidence[0] | ||
+ detections_2_area * detections_2.confidence[0] | ||
) / (detection_1_area + detections_2_area) | ||
merged_confidence = np.array([merged_confidence]) | ||
|
||
merged_x1, merged_y1 = np.minimum(xyxy_1[:2], xyxy_2[:2]) | ||
merged_x2, merged_y2 = np.maximum(xyxy_1[2:], xyxy_2[2:]) | ||
merged_xyxy = np.array([[merged_x1, merged_y1, merged_x2, merged_y2]]) | ||
|
||
if detections_1.mask is None and detections_2.mask is None: | ||
merged_mask = None | ||
else: | ||
merged_mask = np.logical_or(detections_1.mask, detections_2.mask) | ||
|
||
if detections_1.confidence is None and detections_2.confidence is None: | ||
winning_detection = detections_1 | ||
elif detections_1.confidence[0] >= detections_2.confidence[0]: | ||
winning_detection = detections_1 | ||
else: | ||
winning_detection = detections_2 | ||
|
||
return Detections( | ||
xyxy=merged_xyxy, | ||
mask=merged_mask, | ||
confidence=merged_confidence, | ||
class_id=winning_detection.class_id, | ||
tracker_id=winning_detection.tracker_id, | ||
data=winning_detection.data, | ||
) | ||
|
||
|
||
def merge_inner_detections_objects( | ||
detections: List[Detections], threshold=0.5 | ||
) -> Detections: | ||
""" | ||
Given N detections each of length 1 (exactly one object inside), combine them into a | ||
single detection object of length 1. The contained inner object will be the merged | ||
result of all the input detections. | ||
|
||
For example, this lets you merge N boxes into one big box, N masks into one mask, | ||
etc. | ||
""" | ||
detections_1 = detections[0] | ||
for detections_2 in detections[1:]: | ||
box_iou = box_iou_batch(detections_1.xyxy, detections_2.xyxy)[0] | ||
if box_iou < threshold: | ||
break | ||
detections_1 = merge_inner_detection_object_pair(detections_1, detections_2) | ||
return detections_1 | ||
|
||
|
||
def validate_fields_both_defined_or_none( | ||
detections_1: Detections, detections_2: Detections | ||
) -> None: | ||
""" | ||
Verify that for each optional field in the Detections, both instances either have | ||
the field set to None or both have it set to non-None values. | ||
|
||
`data` field is ignored. | ||
|
||
Raises: | ||
ValueError: If one field is None and the other is not, for any of the fields. | ||
""" | ||
attributes = ["mask", "confidence", "class_id", "tracker_id"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we try to get that list automatically? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried it, it's cumbersome, I'll add the code + tests in a separate PR and we can choose whether to keep it. |
||
for attribute in attributes: | ||
value_1 = getattr(detections_1, attribute) | ||
value_2 = getattr(detections_2, attribute) | ||
|
||
if (value_1 is None) != (value_2 is None): | ||
raise ValueError( | ||
f"Field '{attribute}' should be consistently None or not None in both " | ||
"Detections." | ||
) |
Uh oh!
There was an error while loading. Please reload this page.