Skip to content

Commit 1d7bf88

Browse files
committed
Metrics docs overhaul
1 parent e9abfe5 commit 1d7bf88

File tree

6 files changed

+164
-45
lines changed

6 files changed

+164
-45
lines changed

docs/detection/metrics.md renamed to docs/detection/legacy_metrics.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
comments: true
33
---
44

5-
# Metrics
5+
# Legacy Metrics
6+
7+
Starting with `0.23.0`, a new metrics module is being introduced to supervision.
8+
Metrics here are part of the legacy evaluation API and will be deprecated in the future.
69

710
<div class="md-typeset">
811
<h2><a href="#supervision.metrics.detection.ConfusionMatrix">ConfusionMatrix</a></h2>

docs/metrics/intersection_over_union.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@ comments: true
33
status: new
44
---
55

6-
# Detections
6+
# Intersection over Union
7+
8+
<div class="md-typeset">
9+
<h2><a href="#supervision.metrics.intersection_over_union.IntersectionOverUnion">IntersectionOverUnion</a></h2>
10+
</div>
711

812
:::supervision.metrics.intersection_over_union.IntersectionOverUnion
13+
14+
<div class="md-typeset">
15+
<h2><a href="#supervision.metrics.intersection_over_union.IntersectionOverUnionResult">IntersectionOverUnionResult</a></h2>
16+
</div>
17+
18+
:::supervision.metrics.intersection_over_union.IntersectionOverUnionResult
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
---
2+
comments: true
3+
status: new
4+
---
5+
6+
# Mean Average Precision
7+
8+
<div class="md-typeset">
9+
<h2><a href="#supervision.metrics.mean_average_precision.MeanAveragePrecision">MeanAveragePrecision</a></h2>
10+
</div>
11+
12+
:::supervision.metrics.mean_average_precision.MeanAveragePrecision
13+
14+
<div class="md-typeset">
15+
<h2><a href="#supervision.metrics.mean_average_precision.MeanAveragePrecisionResult">MeanAveragePrecisionResult</a></h2>
16+
</div>
17+
18+
:::supervision.metrics.mean_average_precision.MeanAveragePrecisionResult

mkdocs.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ nav:
4747
- Detection and Segmentation:
4848
- Core: detection/core.md
4949
- Annotators: detection/annotators.md
50-
- Metrics: detection/metrics.md
5150
- Double Detection Filter: detection/double_detection_filter.md
5251
- Utils: detection/utils.md
5352
- Keypoint Detection:
@@ -67,6 +66,8 @@ nav:
6766
- Utils: datasets/utils.md
6867
- Metrics:
6968
- IoU: metrics/intersection_over_union.md
69+
- mAP: metrics/mean_average_precision.md
70+
- Legacy Metrics: detection/legacy_metrics.md
7071
- Utils:
7172
- Video: utils/video.md
7273
- Image: utils/image.md

supervision/metrics/intersection_over_union.py

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,13 @@ def update(
6565
) -> IntersectionOverUnion:
6666
"""
6767
Add data to the metric, without computing the result.
68-
Should call all update methods of the shared data store.
6968
7069
Args:
7170
data_1 (Union[Detection, List[Detections]]): The first set of data.
7271
data_2 (Union[Detection, List[Detections]]): The second set of data.
7372
7473
Returns:
75-
Metric: The metric object itself. You can get the metric result
76-
by calling the `compute` method.
74+
(IntersectionOverUnion): The updated metric instance.
7775
"""
7876
if self._is_store_shared:
7977
# Should be updated by the parent metric
@@ -103,10 +101,32 @@ def compute(self) -> IntersectionOverUnionResult:
103101
Uses the data set with the `update` method.
104102
105103
Returns:
106-
Dict[int, npt.NDArray[np.float32]]: A dictionary with class IDs as keys.
107-
If no class ID is provided, the key is the value CLASS_ID_NONE. The values
108-
are (N, M) arrays where N is the number of predictions and M is the number
109-
of targets.
104+
IntersectionOverUnionResult: IoU results.
105+
106+
Example:
107+
```python
108+
import supervision as sv
109+
from supervision.metrics import IntersectionOverUnion
110+
111+
detections_1 = sv.Detections(...)
112+
detections_2 = sv.Detections(...)
113+
114+
iou_metric = IntersectionOverUnion(class_agnostic=False)
115+
iou_result = iou_metric.update(detections_1, detections_2).compute()
116+
print(iou_result)
117+
118+
class_id = 2
119+
ious = iou_result[class_id]
120+
121+
class_id = -1 # no class
122+
ious = iou_result[class_id]
123+
124+
for class_id, ious in iou_result:
125+
...
126+
127+
iou_result.plot()
128+
```
129+
110130
"""
111131
ious_by_class = {}
112132
for class_id in self._store.get_classes():
@@ -135,25 +155,75 @@ def compute(self) -> IntersectionOverUnionResult:
135155
@dataclass
136156
class IntersectionOverUnionResult:
137157
ious_by_class: Dict[int, npt.NDArray[np.float32]]
158+
"""The IoU matrices for each class."""
159+
138160
metric_target: MetricTarget
161+
"""
162+
Defines the type of data used for the metric - boxes, masks or
163+
oriented bounding boxes.
164+
"""
139165

140166
@property
141167
def class_ids(self) -> List[int]:
142168
return list(self.ious_by_class.keys())
143169

144170
def __getitem__(self, class_id: int) -> npt.NDArray[np.float32]:
171+
"""
172+
Get the IoU matrix for a specific class.
173+
174+
Args:
175+
class_id (int): The class ID. Set `-1` to access "no class" data.
176+
If class-agnostic IoU was used, all class IDs will be `-1`.
177+
178+
Returns:
179+
(npt.NDArray[np.float32]): The IoU matrix for the class.
180+
181+
Example:
182+
```python
183+
class_id = 2
184+
ious = iou_result[class_id]
185+
```
186+
"""
145187
return self.ious_by_class[class_id]
146188

147189
def __iter__(self):
190+
"""
191+
Iterate over the IoU matrices for each class.
192+
193+
Returns:
194+
(Iterator[Tuple[int, npt.NDArray[np.float32]]]): An iterator
195+
with class IDs as keys and IoU matrices as values.
196+
197+
Example:
198+
```python
199+
for class_id, ious in iou_result:
200+
...
201+
```
202+
"""
148203
return iter(self.ious_by_class.items())
149204

150205
def __str__(self) -> str:
206+
"""
207+
Format the IoU results as a pretty string.
208+
209+
Example:
210+
```python
211+
print(iou_result)
212+
```
213+
"""
151214
out_str = f"{self.__class__.__name__}:\n"
152215
for class_id, iou in self.ious_by_class.items():
153216
out_str += f"IoUs for class {class_id}:\n{str(iou)}\n"
154217
return out_str
155218

156219
def to_pandas(self) -> Dict[int, "pd.DataFrame"]:
220+
"""
221+
Convert the results to multiple pandas DataFrames.
222+
223+
Returns:
224+
(Dict[int, pd.DataFrame]): A dictionary with class IDs as keys and pandas
225+
DataFrames as values.
226+
"""
157227
ensure_pandas_installed()
158228
import pandas as pd
159229

supervision/metrics/mean_average_precision.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,16 @@ def update(
5252
predictions: Union[Detections, List[Detections]],
5353
targets: Union[Detections, List[Detections]],
5454
) -> MeanAveragePrecision:
55+
"""
56+
Add new predictions and targets to the metric, but do not compute the result.
57+
58+
Args:
59+
predictions (Union[Detections, List[Detections]]): The predicted detections.
60+
targets (Union[Detections, List[Detections]]): The ground-truth detections.
61+
62+
Returns:
63+
(MeanAveragePrecision): The updated metric instance.
64+
"""
5565
if not isinstance(predictions, list):
5666
predictions = [predictions]
5767
if not isinstance(targets, list):
@@ -88,45 +98,22 @@ def compute(
8898
number of ground-truth objects. Each row is expected to be in
8999
`(x_min, y_min, x_max, y_max, class)` format.
90100
Returns:
91-
MeanAveragePrecision: New instance of MeanAveragePrecision.
101+
(MeanAveragePrecisionResult): The computed mAP results.
92102
93103
Example:
94104
```python
95105
import supervision as sv
96-
import numpy as np
97-
98-
targets = (
99-
[
100-
np.array(
101-
[
102-
[0.0, 0.0, 3.0, 3.0, 1],
103-
[2.0, 2.0, 5.0, 5.0, 1],
104-
[6.0, 1.0, 8.0, 3.0, 2],
105-
]
106-
),
107-
np.array([[1.0, 1.0, 2.0, 2.0, 2]]),
108-
]
109-
)
106+
from supervision.metrics import MeanAveragePrecision
110107
111-
predictions = [
112-
np.array(
113-
[
114-
[0.0, 0.0, 3.0, 3.0, 1, 0.9],
115-
[0.1, 0.1, 3.0, 3.0, 0, 0.9],
116-
[6.0, 1.0, 8.0, 3.0, 1, 0.8],
117-
[1.0, 6.0, 2.0, 7.0, 1, 0.8],
118-
]
119-
),
120-
np.array([[1.0, 1.0, 2.0, 2.0, 2, 0.8]])
121-
]
108+
predictions = sv.Detections(...)
109+
targets = sv.Detections(...)
122110
123-
mean_average_precison = sv.MeanAveragePrecision.from_tensors(
124-
predictions=predictions,
125-
targets=targets,
126-
)
111+
map_metric = MeanAveragePrecision()
112+
map_result = map_metric.update(predictions, targets).compute()
127113
128-
print(mean_average_precison.map50_95)
129-
# 0.6649
114+
print(map_result)
115+
print(map_result.map50_95)
116+
map_result.plot()
130117
```
131118
"""
132119
(
@@ -243,6 +230,7 @@ def _compute(
243230
map50=map50,
244231
map75=map75,
245232
per_class_ap50_95=average_precisions,
233+
metric_target=self._metric_target,
246234
)
247235

248236
@staticmethod
@@ -256,7 +244,7 @@ def compute_average_precision(recall: np.ndarray, precision: np.ndarray) -> floa
256244
precision (np.ndarray): The precision curve.
257245
258246
Returns:
259-
float: Average precision.
247+
(float): Average precision.
260248
"""
261249
extended_recall = np.concatenate(([0.0], recall, [1.0]))
262250
extended_precision = np.concatenate(([1.0], precision, [0.0]))
@@ -320,7 +308,7 @@ def _average_precisions_per_class(
320308
eps (float, optional): Small value to prevent division by zero.
321309
322310
Returns:
323-
np.ndarray: Average precision for different IoU levels.
311+
(np.ndarray): Average precision for different IoU levels.
324312
"""
325313
eps = 1e-16
326314

@@ -361,15 +349,44 @@ def _average_precisions_per_class(
361349
@dataclass
362350
class MeanAveragePrecisionResult:
363351
iou_thresholds: np.ndarray
352+
"""Array of IoU thresholds used in the calculations"""
364353
map50_95: float
354+
"""Mean Average Precision over IoU thresholds from 0.5 to 0.95"""
355+
365356
map50: float
357+
"""Mean Average Precision at IoU threshold of 0.5"""
358+
366359
map75: float
360+
"""Mean Average Precision at IoU threshold of 0.75"""
361+
367362
per_class_ap50_95: np.ndarray
363+
"""Average precision for each class at different IoU thresholds"""
364+
365+
metric_target: MetricTarget
366+
"""
367+
Defines the type of data used for the metric - boxes, masks or
368+
oriented bounding boxes.
369+
"""
370+
368371
small_objects: Optional[MeanAveragePrecisionResult] = None
372+
"""Mean Average Precision results for small objects"""
373+
369374
medium_objects: Optional[MeanAveragePrecisionResult] = None
375+
"""Mean Average Precision results for medium objects"""
376+
370377
large_objects: Optional[MeanAveragePrecisionResult] = None
378+
"""Mean Average Precision results for large objects"""
371379

372380
def __str__(self) -> str:
381+
"""
382+
Format the mAP results as a pretty string.
383+
384+
Example:
385+
```python
386+
print(map_result)
387+
```
388+
"""
389+
373390
out_str = (
374391
f"{self.__class__.__name__}:\n"
375392
f"iou_thresholds: {self.iou_thresholds}\n"
@@ -402,7 +419,7 @@ def to_pandas(self) -> "pd.DataFrame":
402419
Convert the result to a pandas DataFrame.
403420
404421
Returns:
405-
pd.DataFrame: The result as a DataFrame.
422+
(pd.DataFrame): The result as a DataFrame.
406423
"""
407424
ensure_pandas_installed()
408425
import pandas as pd

0 commit comments

Comments
 (0)