
Commit ea776b5

Merge pull request #1609 from roboflow/feat/metrics-precision-recall
Feat/metrics precision recall
2 parents bda4003 + 3e8a88a commit ea776b5

10 files changed: +1359 −31 lines

docs/metrics/common_values.md

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+---
+comments: true
+status: new
+---
+
+# Common Values
+
+This page contains supplementary values, types and enums that metrics use.
+
+<div class="md-typeset">
+  <h2><a href="#supervision.metrics.core.MetricTarget">MetricTarget</a></h2>
+</div>
+
+:::supervision.metrics.core.MetricTarget
+
+<div class="md-typeset">
+  <h2><a href="#supervision.metrics.core.AveragingMethod">AveragingMethod</a></h2>
+</div>
+
+:::supervision.metrics.core.AveragingMethod
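For orientation, here is a minimal sketch of how these two enums are intended to be used, based on the `F1Score` constructor added later in this commit. The `supervision.metrics.core` import path follows the mkdocstrings references above; the `sv.Detections(...)` placeholders and the particular combination of options are illustrative only.

```python
import supervision as sv
from supervision.metrics import F1Score
from supervision.metrics.core import AveragingMethod, MetricTarget

predictions = sv.Detections(...)
targets = sv.Detections(...)

# Evaluate plain xyxy boxes and average per-class scores equally (macro averaging).
f1_metric = F1Score(
    metric_target=MetricTarget.BOXES,
    averaging_method=AveragingMethod.MACRO,
)
print(f1_metric.update(predictions, targets).compute())
```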

docs/metrics/precision.md

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+---
+comments: true
+status: new
+---
+
+# Precision
+
+<div class="md-typeset">
+  <h2><a href="#supervision.metrics.precision.Precision">Precision</a></h2>
+</div>
+
+:::supervision.metrics.precision.Precision
+
+<div class="md-typeset">
+  <h2><a href="#supervision.metrics.precision.PrecisionResult">PrecisionResult</a></h2>
+</div>
+
+:::supervision.metrics.precision.PrecisionResult
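A minimal usage sketch, assuming `Precision` follows the same `update()`/`compute()` pattern as the `F1Score` and `MeanAveragePrecision` classes documented later in this commit; the `sv.Detections(...)` placeholders are illustrative.

```python
import supervision as sv
from supervision.metrics import Precision

predictions = sv.Detections(...)
targets = sv.Detections(...)

precision_metric = Precision()
precision_result = precision_metric.update(predictions, targets).compute()
print(precision_result)
```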

docs/metrics/recall.md

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+---
+comments: true
+status: new
+---
+
+# Recall
+
+<div class="md-typeset">
+  <h2><a href="#supervision.metrics.recall.Recall">Recall</a></h2>
+</div>
+
+:::supervision.metrics.recall.Recall
+
+<div class="md-typeset">
+  <h2><a href="#supervision.metrics.recall.RecallResult">RecallResult</a></h2>
+</div>
+
+:::supervision.metrics.recall.RecallResult
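Recall is expected to follow the same pattern. A short sketch under that assumption, additionally passing lists of `Detections`, mirroring the `Union[Detections, List[Detections]]` signature that `F1Score.update()` uses later in this commit.

```python
import supervision as sv
from supervision.metrics import Recall

# Batched evaluation: lists of Detections are assumed to be accepted,
# mirroring the update() signature added elsewhere in this commit.
predictions = [sv.Detections(...), sv.Detections(...)]
targets = [sv.Detections(...), sv.Detections(...)]

recall_metric = Recall()
recall_result = recall_metric.update(predictions, targets).compute()
print(recall_result)
```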

mkdocs.yml

Lines changed: 3 additions & 0 deletions

@@ -66,7 +66,10 @@ nav:
       - Utils: datasets/utils.md
   - Metrics:
       - mAP: metrics/mean_average_precision.md
+      - Precision: metrics/precision.md
+      - Recall: metrics/recall.md
       - F1 Score: metrics/f1_score.md
+      - Common Values: metrics/common_values.md
       - Legacy Metrics: detection/metrics.md
   - Utils:
       - Video: utils/video.md

supervision/metrics/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@
     MeanAveragePrecision,
     MeanAveragePrecisionResult,
 )
+from supervision.metrics.precision import Precision, PrecisionResult
+from supervision.metrics.recall import Recall, RecallResult
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,

supervision/metrics/core.py

Lines changed: 14 additions & 12 deletions

@@ -37,9 +37,10 @@ class MetricTarget(Enum):
     """
     Specifies what type of detection is used to compute the metric.

-    * BOXES: xyxy bounding boxes
-    * MASKS: Binary masks
-    * ORIENTED_BOUNDING_BOXES: Oriented bounding boxes (OBB)
+    Attributes:
+        BOXES: xyxy bounding boxes
+        MASKS: Binary masks
+        ORIENTED_BOUNDING_BOXES: Oriented bounding boxes (OBB)
     """

     BOXES = "boxes"

@@ -54,15 +55,16 @@ class AveragingMethod(Enum):
     Suppose, before returning the final result, a metric is computed for each class.
     How do you combine those to get the final number?

-    * MACRO: Calculate the metric for each class and average the results. The simplest
-        averaging method, but it does not take class imbalance into account.
-    * MICRO: Calculate the metric globally by counting the total true positives, false
-        positives, and false negatives. Micro averaging is useful when you want to give
-        more importance to classes with more samples. It's also more appropriate if you
-        have an imbalance in the number of instances per class.
-    * WEIGHTED: Calculate the metric for each class and average the results, weighted by
-        the number of true instances of each class. Use weighted averaging if you want
-        to take class imbalance into account.
+    Attributes:
+        MACRO: Calculate the metric for each class and average the results. The simplest
+            averaging method, but it does not take class imbalance into account.
+        MICRO: Calculate the metric globally by counting the total true positives, false
+            positives, and false negatives. Micro averaging is useful when you want to
+            give more importance to classes with more samples. It's also more
+            appropriate if you have an imbalance in the number of instances per class.
+        WEIGHTED: Calculate the metric for each class and average the results, weighted
+            by the number of true instances of each class. Use weighted averaging if
+            you want to take class imbalance into account.
     """

     MACRO = "macro"
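The distinction between the three averaging methods described in the docstring above can be made concrete with a small self-contained calculation; the numbers below are toy values chosen for illustration, not library code.

```python
# Toy numbers, not library code: per-class precision under class imbalance.
tp = {"car": 90, "dog": 1}            # true positives per class
fp = {"car": 10, "dog": 9}            # false positives per class
support = {"car": 120, "dog": 15}     # ground-truth instances per class

per_class = {c: tp[c] / (tp[c] + fp[c]) for c in tp}   # car: 0.90, dog: 0.10

# MACRO: every class counts equally, regardless of size.
macro = sum(per_class.values()) / len(per_class)                          # 0.50

# MICRO: pool the counts first, then compute the metric once.
micro = sum(tp.values()) / (sum(tp.values()) + sum(fp.values()))          # 91 / 110 ≈ 0.83

# WEIGHTED: average per-class values, weighted by ground-truth support.
weighted = sum(per_class[c] * support[c] for c in per_class) / sum(support.values())
# (0.90 * 120 + 0.10 * 15) / 135 ≈ 0.81

print(macro, micro, weighted)
```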

supervision/metrics/f1_score.py

Lines changed: 54 additions & 1 deletion

@@ -23,11 +23,45 @@


 class F1Score(Metric):
+    """
+    F1 Score is a metric used to evaluate object detection models. It is the harmonic
+    mean of precision and recall, calculated at different IoU thresholds.
+
+    In simple terms, F1 Score is a measure of a model's balance between precision and
+    recall (accuracy and completeness), calculated as:
+
+    `F1 = 2 * (precision * recall) / (precision + recall)`
+
+    Example:
+        ```python
+        import supervision as sv
+        from supervision.metrics import F1Score
+
+        predictions = sv.Detections(...)
+        targets = sv.Detections(...)
+
+        f1_metric = F1Score()
+        f1_result = f1_metric.update(predictions, targets).compute()
+
+        print(f1_result)
+        print(f1_result.f1_50)
+        print(f1_result.small_objects.f1_50)
+        ```
+    """
+
     def __init__(
         self,
         metric_target: MetricTarget = MetricTarget.BOXES,
         averaging_method: AveragingMethod = AveragingMethod.WEIGHTED,
     ):
+        """
+        Initialize the F1Score metric.
+
+        Args:
+            metric_target (MetricTarget): The type of detection data to use.
+            averaging_method (AveragingMethod): The averaging method used to compute the
+                F1 scores. Determines how the F1 scores are aggregated across classes.
+        """
         self._metric_target = metric_target
         if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
             raise NotImplementedError(

@@ -40,6 +74,9 @@ def __init__(
         self._targets_list: List[Detections] = []

     def reset(self) -> None:
+        """
+        Reset the metric to its initial state, clearing all stored data.
+        """
         self._predictions_list = []
         self._targets_list = []


@@ -48,6 +85,16 @@ def update(
         predictions: Union[Detections, List[Detections]],
         targets: Union[Detections, List[Detections]],
     ) -> F1Score:
+        """
+        Add new predictions and targets to the metric, but do not compute the result.
+
+        Args:
+            predictions (Union[Detections, List[Detections]]): The predicted detections.
+            targets (Union[Detections, List[Detections]]): The target detections.
+
+        Returns:
+            (F1Score): The updated metric instance.
+        """
         if not isinstance(predictions, list):
             predictions = [predictions]
         if not isinstance(targets, list):

@@ -65,6 +112,13 @@ def update(
         return self

     def compute(self) -> F1ScoreResult:
+        """
+        Calculate the F1 score metric based on the stored predictions and ground-truth
+        data, at different IoU thresholds.
+
+        Returns:
+            (F1ScoreResult): The F1 score metric result.
+        """
         result = self._compute(self._predictions_list, self._targets_list)

         small_predictions, small_targets = self._filter_predictions_and_targets_by_size(

@@ -373,7 +427,6 @@ class F1ScoreResult:
     The results of the F1 score metric calculation.

     Defaults to `0` if no detections or targets were provided.
-    Provides a custom `__str__` method for pretty printing.

     Attributes:
         metric_target (MetricTarget): the type of data used for the metric -
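A quick worked example of the formula quoted in the new class docstring; this is plain arithmetic, independent of the library.

```python
# Worked example of F1 = 2 * (precision * recall) / (precision + recall).
precision = 0.8
recall = 0.6
f1 = 2 * (precision * recall) / (precision + recall)
print(f1)  # 0.96 / 1.4 ≈ 0.686
```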

supervision/metrics/mean_average_precision.py

Lines changed: 26 additions & 18 deletions

@@ -23,6 +23,27 @@


 class MeanAveragePrecision(Metric):
+    """
+    Mean Average Precision (mAP) is a metric used to evaluate object detection models.
+    It is the average of the precision-recall curves at different IoU thresholds.
+
+    Example:
+        ```python
+        import supervision as sv
+        from supervision.metrics import MeanAveragePrecision
+
+        predictions = sv.Detections(...)
+        targets = sv.Detections(...)
+
+        map_metric = MeanAveragePrecision()
+        map_result = map_metric.update(predictions, targets).compute()
+
+        print(map_result)
+        print(map_result.map50_95)
+        map_result.plot()
+        ```
+    """
+
     def __init__(
         self,
         metric_target: MetricTarget = MetricTarget.BOXES,

@@ -47,6 +68,9 @@ def __init__(
         self._targets_list: List[Detections] = []

     def reset(self) -> None:
+        """
+        Reset the metric to its initial state, clearing all stored data.
+        """
         self._predictions_list = []
         self._targets_list = []


@@ -95,26 +119,10 @@
     ) -> MeanAveragePrecisionResult:
         """
         Calculate Mean Average Precision based on predicted and ground-truth
-            detections at different thresholds.
+        detections at different thresholds.

         Returns:
-            (MeanAveragePrecisionResult): New instance of MeanAveragePrecision.
-
-        Example:
-            ```python
-            import supervision as sv
-            from supervision.metrics import MeanAveragePrecision
-
-            predictions = sv.Detections(...)
-            targets = sv.Detections(...)
-
-            map_metric = MeanAveragePrecision()
-            map_result = map_metric.update(predictions, targets).compute()
-
-            print(map_result)
-            print(map_result.map50_95)
-            map_result.plot()
-            ```
+            (MeanAveragePrecisionResult): The Mean Average Precision result.
         """
         result = self._compute(self._predictions_list, self._targets_list)

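For readers unfamiliar with the metric, the sketch below shows one common way to compute average precision for a single class at a single IoU threshold once detections have been matched to ground truth; mAP then averages such values over classes and over IoU thresholds (typically 0.50 to 0.95). This is a simplified, self-contained reference calculation, not the library's internal implementation.

```python
import numpy as np

def average_precision(scores: np.ndarray, is_tp: np.ndarray, n_gt: int) -> float:
    """AP for one class at one IoU threshold, given already-matched detections."""
    order = np.argsort(-scores)                    # rank detections by confidence
    tp = np.cumsum(is_tp[order])                   # cumulative true positives
    fp = np.cumsum(~is_tp[order])                  # cumulative false positives
    recall = tp / n_gt
    precision = tp / (tp + fp)

    # Replace the precision curve with its monotonically decreasing envelope,
    # then integrate it over recall (all-point interpolation).
    precision = np.maximum.accumulate(precision[::-1])[::-1]
    recall_steps = np.diff(np.concatenate(([0.0], recall)))
    return float(np.sum(recall_steps * precision))

# Example: 4 detections sorted by confidence, 3 ground-truth boxes.
scores = np.array([0.9, 0.8, 0.7, 0.6])
is_tp = np.array([True, False, True, True])
print(average_precision(scores, is_tp, n_gt=3))    # ≈ 0.83
```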