Skip to content

Commit 87a4927

Browse files
authored
Merge pull request #818 from AdonaiVera/allowing_serialise_detections_csv
New function [CSVSink] - allowing to serialise Detections to a CSV file
2 parents 0ccb0b8 + 6bebd4e commit 87a4927

File tree

6 files changed

+611
-1
lines changed

6 files changed

+611
-1
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
---
2+
comments: true
3+
status: new
4+
---
5+
6+
# Save Detections
7+
8+
<div class="md-typeset">
9+
<h2>CSV Sink</h2>
10+
</div>
11+
12+
:::supervision.detection.tools.csv_sink.CSVSink

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ nav:
4949
- Polygon Zone: detection/tools/polygon_zone.md
5050
- Inference Slicer: detection/tools/inference_slicer.md
5151
- Detection Smoother: detection/tools/smoother.md
52+
- Save Detections: detection/tools/save_detections.md
5253
- Annotators: annotators.md
5354
- Trackers: trackers.md
5455
- Datasets: datasets.md

supervision/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from supervision.detection.annotate import BoxAnnotator
3737
from supervision.detection.core import Detections
3838
from supervision.detection.line_counter import LineZone, LineZoneAnnotator
39+
from supervision.detection.tools.csv_sink import CSVSink
3940
from supervision.detection.tools.inference_slicer import InferenceSlicer
4041
from supervision.detection.tools.polygon_zone import PolygonZone, PolygonZoneAnnotator
4142
from supervision.detection.tools.smoother import DetectionsSmoother

supervision/annotators/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class OrientedBoxAnnotator(BaseAnnotator):
9999

100100
def __init__(
101101
self,
102-
color: Union[Color, ColorPalette] = ColorPalette.default(),
102+
color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
103103
thickness: int = 2,
104104
color_lookup: ColorLookup = ColorLookup.CLASS,
105105
):
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from __future__ import annotations
2+
3+
import csv
4+
import os
5+
from typing import Any, Dict, List, Optional
6+
7+
from supervision.detection.core import Detections
8+
9+
BASE_HEADER = [
10+
"x_min",
11+
"y_min",
12+
"x_max",
13+
"y_max",
14+
"class_id",
15+
"confidence",
16+
"tracker_id",
17+
]
18+
19+
20+
class CSVSink:
21+
"""
22+
A utility class for saving detection data to a CSV file. This class is designed to
23+
efficiently serialize detection objects into a CSV format, allowing for the
24+
inclusion of bounding box coordinates and additional attributes like `confidence`,
25+
`class_id`, and `tracker_id`.
26+
27+
!!! tip
28+
29+
CSVSink allow to pass custom data alongside the detection fields, providing
30+
flexibility for logging various types of information.
31+
32+
Args:
33+
file_name (str): The name of the CSV file where the detections will be stored.
34+
Defaults to 'output.csv'.
35+
36+
Example:
37+
```python
38+
import supervision as sv
39+
from ultralytics import YOLO
40+
41+
model = YOLO(<SOURCE_MODEL_PATH>)
42+
csv_sink = sv.CSVSink(<RESULT_CSV_FILE_PATH>)
43+
frames_generator = sv.get_video_frames_generator(<SOURCE_VIDEO_PATH>)
44+
45+
with csv_sink:
46+
for frame in frames_generator:
47+
result = model(frame)[0]
48+
detections = sv.Detections.from_ultralytics(result)
49+
sink.append(detections, custom_data={'<CUSTOM_LABEL>':'<CUSTOM_DATA>'})
50+
```
51+
""" # noqa: E501 // docs
52+
53+
def __init__(self, file_name: str = "output.csv") -> None:
54+
"""
55+
Initialize the CSVSink instance.
56+
57+
Args:
58+
file_name (str): The name of the CSV file.
59+
60+
Returns:
61+
None
62+
"""
63+
self.file_name = file_name
64+
self.file: Optional[open] = None
65+
self.writer: Optional[csv.writer] = None
66+
self.header_written = False
67+
self.field_names = []
68+
69+
def __enter__(self) -> CSVSink:
70+
self.open()
71+
return self
72+
73+
def __exit__(
74+
self,
75+
exc_type: Optional[type],
76+
exc_val: Optional[Exception],
77+
exc_tb: Optional[Any],
78+
) -> None:
79+
self.close()
80+
81+
def open(self) -> None:
82+
"""
83+
Open the CSV file for writing.
84+
85+
Returns:
86+
None
87+
"""
88+
parent_directory = os.path.dirname(self.file_name)
89+
if parent_directory and not os.path.exists(parent_directory):
90+
os.makedirs(parent_directory)
91+
92+
self.file = open(self.file_name, "w", newline="")
93+
self.writer = csv.writer(self.file)
94+
95+
def close(self) -> None:
96+
"""
97+
Close the CSV file.
98+
99+
Returns:
100+
None
101+
"""
102+
if self.file:
103+
self.file.close()
104+
105+
@staticmethod
106+
def parse_detection_data(
107+
detections: Detections, custom_data: Dict[str, Any] = None
108+
) -> List[Dict[str, Any]]:
109+
parsed_rows = []
110+
for i in range(len(detections.xyxy)):
111+
row = {
112+
"x_min": detections.xyxy[i][0],
113+
"y_min": detections.xyxy[i][1],
114+
"x_max": detections.xyxy[i][2],
115+
"y_max": detections.xyxy[i][3],
116+
"class_id": ""
117+
if detections.class_id is None
118+
else str(detections.class_id[i]),
119+
"confidence": ""
120+
if detections.confidence is None
121+
else str(detections.confidence[i]),
122+
"tracker_id": ""
123+
if detections.tracker_id is None
124+
else str(detections.tracker_id[i]),
125+
}
126+
127+
if hasattr(detections, "data"):
128+
for key, value in detections.data.items():
129+
if value.ndim == 0:
130+
row[key] = value
131+
else:
132+
row[key] = value[i]
133+
134+
if custom_data:
135+
row.update(custom_data)
136+
parsed_rows.append(row)
137+
return parsed_rows
138+
139+
def append(
140+
self, detections: Detections, custom_data: Dict[str, Any] = None
141+
) -> None:
142+
"""
143+
Append detection data to the CSV file.
144+
145+
Args:
146+
detections (Detections): The detection data.
147+
custom_data (Dict[str, Any]): Custom data to include.
148+
149+
Returns:
150+
None
151+
"""
152+
if not self.writer:
153+
raise Exception(
154+
f"Cannot append to CSV: The file '{self.file_name}' is not open."
155+
)
156+
field_names = CSVSink.parse_field_names(detections, custom_data)
157+
if not self.header_written:
158+
self.field_names = field_names
159+
self.writer.writerow(field_names)
160+
self.header_written = True
161+
162+
if field_names != self.field_names:
163+
print(
164+
f"Field names do not match the header. "
165+
f"Expected: {self.field_names}, given: {field_names}"
166+
)
167+
168+
parsed_rows = CSVSink.parse_detection_data(detections, custom_data)
169+
for row in parsed_rows:
170+
self.writer.writerow(
171+
[row.get(field_name, "") for field_name in self.field_names]
172+
)
173+
174+
@staticmethod
175+
def parse_field_names(
176+
detections: Detections, custom_data: Dict[str, Any]
177+
) -> List[str]:
178+
dynamic_header = sorted(
179+
set(custom_data.keys()) | set(getattr(detections, "data", {}).keys())
180+
)
181+
return BASE_HEADER + dynamic_header

0 commit comments

Comments
 (0)