
Commit 1256833

Applied feedback from SkalskiP in PR #818
1 parent 5e339f9 commit 1256833

File tree

6 files changed: +159 -91 lines changed

docs/detection/tools/csv_sink.md

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+---
+comments: true
+status: new
+---
+
+## Save CSV Detection
+
+:::supervision.detection.tools.csv_sink.CSVSink
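
The `:::` line is an API reference directive (presumably rendered by mkdocstrings, which the supervision docs already use for the other tool pages), so the published page pulls in the `CSVSink` docstring and signatures added later in this commit.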

mkdocs.yml

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ nav:
       - Polygon Zone: detection/tools/polygon_zone.md
       - Inference Slicer: detection/tools/inference_slicer.md
       - Detection Smoother: detection/tools/smoother.md
+      - Save CSV Detection: detection/tools/csv_sink.md
       - Annotators: annotators.md
       - Trackers: trackers.md
       - Datasets: datasets.md

supervision/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@
 from supervision.detection.tools.inference_slicer import InferenceSlicer
 from supervision.detection.tools.polygon_zone import PolygonZone, PolygonZoneAnnotator
 from supervision.detection.tools.smoother import DetectionsSmoother
+from supervision.detection.tools.csv_sink import CSVSink
 from supervision.detection.utils import (
     box_iou_batch,
     calculate_masks_centroids,
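
With this re-export the sink becomes reachable from the package root as `sv.CSVSink`, which is how the renamed test at the bottom of this commit uses it. A minimal sketch of the import paths after this change (the `example.csv` name is illustrative only):

```python
# New, after this commit: top-level re-export.
import supervision as sv
sink = sv.CSVSink(filename="example.csv")

# Equivalent deep import, matching the line added to __init__.py above.
from supervision.detection.tools.csv_sink import CSVSink
sink = CSVSink(filename="example.csv")

# Old location, removed from supervision/utils/file.py further down:
# from supervision.utils.file import CSVSink
```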
supervision/detection/tools/csv_sink.py

Lines changed: 141 additions & 0 deletions

@@ -0,0 +1,141 @@
+from __future__ import annotations
+
+import csv
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+
+from supervision.detection.core import Detections
+
+BASE_HEADER = [
+    "x_min",
+    "y_min",
+    "x_max",
+    "y_max",
+    "class_id",
+    "confidence",
+    "tracker_id",
+]
+
+class CSVSink:
+    """
+    A utility class for saving detection data to a CSV file. This class is designed to
+    efficiently serialize detection objects into a CSV format, allowing for the inclusion of
+    bounding box coordinates and additional attributes like confidence, class ID, and tracker ID.
+
+    The class supports the capability to include custom data alongside the detection fields,
+    providing flexibility for logging various types of information.
+
+    Args:
+        filename (str): The name of the CSV file where the detections will be stored.
+            Defaults to 'output.csv'.
+
+    Example:
+        ```python
+        import numpy as np
+        import supervision as sv
+        from ultralytics import YOLO
+        import time
+
+        model = YOLO("yolov8n.pt")
+        tracker = sv.ByteTrack()
+        box_annotator = sv.BoundingBoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        csv_sink = sv.CSVSink(...)
+
+        def callback(frame: np.ndarray, _: int) -> np.ndarray:
+            start_time = time.time()
+            results = model(frame)[0]
+            detections = sv.Detections.from_ultralytics(results)
+            detections = tracker.update_with_detections(detections)
+
+            labels = [
+                f"#{tracker_id} {results.names[class_id]}"
+                for class_id, tracker_id
+                in zip(detections.class_id, detections.tracker_id)
+            ]
+            time_frame = (time.time() - start_time)
+
+            csv_sink.append(detections, custom_data={"processing_time": time_frame})
+
+            annotated_frame = box_annotator.annotate(
+                frame.copy(), detections=detections)
+            return label_annotator.annotate(
+                annotated_frame, detections=detections, labels=labels)
+
+        csv_sink.open()
+        sv.process_video(
+            source_path="people-walking.mp4",
+            target_path="result.mp4",
+            callback=callback
+        )
+        csv_sink.close()
+        ```
+    """  # noqa: E501 // docs
+
+    def __init__(self, filename: str = "output.csv"):
+        self.filename = filename
+        self.file: Optional[open] = None
+        self.writer: Optional[csv.writer] = None
+        self.header_written = False
+        self.fieldnames = []  # To keep track of header names
+
+    def __enter__(self) -> CSVSink:
+        self.open()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type],
+        exc_val: Optional[Exception],
+        exc_tb: Optional[Any],
+    ) -> None:
+        self.close()
+
+    def open(self) -> None:
+        self.file = open(self.filename, "w", newline="")
+        self.writer = csv.writer(self.file)
+
+    def close(self) -> None:
+        if self.file:
+            self.file.close()
+
+    @staticmethod
+    def parse_detection_data(detections: Detections, custom_data: Dict[str, Any] = None) -> List[Dict[str, Any]]:
+        parsed_rows = []
+        for i in range(len(detections.xyxy)):
+            row = {
+                "x_min": detections.xyxy[i][0],
+                "y_min": detections.xyxy[i][1],
+                "x_max": detections.xyxy[i][2],
+                "y_max": detections.xyxy[i][3],
+                "class_id": detections.class_id[i],
+                "confidence": detections.confidence[i],
+                "tracker_id": detections.tracker_id[i],
+            }
+            if hasattr(detections, "data"):
+                for key, value in detections.data.items():
+                    row[key] = value[i]
+            if custom_data:
+                row.update(custom_data)
+            parsed_rows.append(row)
+        return parsed_rows
+
+    def append(self, detections: Detections, custom_data: Dict[str, Any] = None) -> None:
+        if not self.writer:
+            raise Exception(f"Cannot append to CSV: The file '{self.filename}' is not open.")
+        if not self.header_written:
+            self.write_header(detections, custom_data)
+
+        parsed_rows = CSVSink.parse_detection_data(detections, custom_data)
+        for row in parsed_rows:
+            self.writer.writerow([row.get(fieldname, "") for fieldname in self.fieldnames])
+
+    def write_header(
+        self, detections: Detections, custom_data: Dict[str, Any]
+    ) -> None:
+        dynamic_header = sorted(set(custom_data.keys()) | set(getattr(detections, "data", {}).keys()))
+        self.fieldnames = BASE_HEADER + dynamic_header
+        self.writer.writerow(self.fieldnames)
+        self.header_written = True
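
As a usage note on the class above: the first `append` call fixes the CSV header as `BASE_HEADER` plus the sorted union of `detections.data` keys and `custom_data` keys, and later rows fall back to empty strings for any field missing from that header logic. A sketch of that behaviour (the file name, field values, and the `class_name` key are illustrative; it assumes `sv.Detections` accepts the `data` dict that `parse_detection_data` reads):

```python
import numpy as np
import supervision as sv

# Two detections with an extra per-detection field stored in `data`.
detections = sv.Detections(
    xyxy=np.array([[10.0, 20.0, 30.0, 40.0], [15.0, 25.0, 35.0, 45.0]]),
    class_id=np.array([0, 1]),
    confidence=np.array([0.9, 0.8]),
    tracker_id=np.array([1, 2]),
    data={"class_name": np.array(["person", "car"])},
)

with sv.CSVSink(filename="detections_example.csv") as sink:
    # The first append writes the header: BASE_HEADER plus the sorted
    # union of data keys and custom_data keys -> ..., class_name, frame.
    sink.append(detections, custom_data={"frame": 1})
    # Subsequent appends reuse that header; any field absent from a row
    # is written as an empty string by append().
    sink.append(detections, custom_data={"frame": 2})

# Resulting columns:
# x_min, y_min, x_max, y_max, class_id, confidence, tracker_id, class_name, frame
```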

supervision/utils/file.py

Lines changed: 3 additions & 85 deletions
@@ -1,93 +1,10 @@
 import json
-import csv
 from pathlib import Path
-from typing import List, Optional, Union, Dict, Any
-from supervision.detection.core import Detections
+from typing import List, Optional, Union

 import numpy as np
 import yaml

-class CSVSink:
-    """
-    A utility class for saving detection data to a CSV file. This class is designed to
-    efficiently serialize detection objects into a CSV format, allowing for the inclusion of
-    bounding box coordinates and additional attributes like confidence, class ID, and tracker ID.
-
-    The class supports the capability to include custom data alongside the detection fields,
-    providing flexibility for logging various types of information.
-
-    Args:
-        filename (str): The name of the CSV file where the detections will be stored.
-            Defaults to 'output.csv'.
-
-    Usage:
-        ```python
-        from supervision.utils.detections import Detections
-        # Initialize CSVSink with a filename
-        csv_sink = CSVSink('my_detections.csv')
-
-        # Assuming detections is an instance of Detections containing detection data
-        detections = Detections(...)
-
-        # Open the CSVSink context, append detection data, and close the file automatically
-        with csv_sink as sink:
-            sink.append(detections, custom_data={'frame': 1})
-        ```
-    """
-    def __init__(self, filename: str = 'output.csv'):
-        self.filename = filename
-        self.file: Optional[open] = None
-        self.writer: Optional[csv.writer] = None
-        self.header_written = False
-        self.fieldnames = []  # To keep track of header names
-
-    def __enter__(self) -> 'CSVSink':
-        self.open()
-        return self
-
-    def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[Any]) -> None:
-        self.close()
-
-    def open(self) -> None:
-        self.file = open(self.filename, 'w', newline='')
-        self.writer = csv.writer(self.file)
-
-    def close(self) -> None:
-        if self.file:
-            self.file.close()
-
-    def append(self, detections: Detections, custom_data: Dict[str, Any] = None) -> None:
-        if not self.writer:
-            raise Exception(f"Cannot append to CSV: The file '{self.filename}' is not open. Ensure that the 'open' method is called before appending data.")
-        if not self.header_written:
-            self.write_header(detections, custom_data)
-        for i in range(len(detections.xyxy)):
-            self.write_detection_row(detections, i, custom_data)
-
-    def write_header(self, detections: Detections, custom_data: Dict[str, Any]) -> None:
-        base_header = ['x_min', 'y_min', 'x_max', 'y_max', 'class_id', 'confidence', 'tracker_id']
-        dynamic_header = sorted(set(custom_data.keys()) | set(getattr(detections, 'data', {}).keys()))
-        self.fieldnames = base_header + dynamic_header
-        self.dynamic_fields = dynamic_header  # Store only the dynamic part
-        self.writer.writerow(self.fieldnames)
-        self.header_written = True
-
-    def write_detection_row(self, detections: Detections, index: int, custom_data: Dict[str, Any]) -> None:
-        row_base = [
-            detections.xyxy[index][0], detections.xyxy[index][1],
-            detections.xyxy[index][2], detections.xyxy[index][3],
-            detections.class_id[index], detections.confidence[index],
-            detections.tracker_id[index]
-        ]
-        dynamic_data = {}
-        if hasattr(detections, 'data'):
-            for key, value in detections.data.items():
-                dynamic_data[key] = value[index]
-        if custom_data:
-            dynamic_data.update(custom_data)
-
-        row_dynamic = [dynamic_data.get(key) for key in self.fieldnames[7:]]
-        self.writer.writerow(row_base + row_dynamic)

 class NumpyJsonEncoder(json.JSONEncoder):
     def default(self, obj):
@@ -99,6 +16,7 @@ def default(self, obj):
            return obj.tolist()
        return super(NumpyJsonEncoder, self).default(obj)

+
 def list_files_with_extensions(
     directory: Union[str, Path], extensions: Optional[List[str]] = None
 ) -> List[Path]:
@@ -228,4 +146,4 @@ def save_yaml_file(data: dict, file_path: str) -> None:
     """

     with open(file_path, "w") as outfile:
-        yaml.dump(data, outfile, sort_keys=False, default_flow_style=None)
+        yaml.dump(data, outfile, sort_keys=False, default_flow_style=None)

test/utils/test_csv.py renamed to test/detection/test_csv.py

Lines changed: 5 additions & 6 deletions
@@ -2,10 +2,9 @@
 import csv
 import pytest
 import numpy as np
-from supervision.utils.file import CSVSink
 from supervision.detection.core import Detections
+import supervision as sv

-#pytest test/utils/test_csv.py
 @pytest.fixture(scope="module")
 def detection_instances():
     # Setup detection instances as per the provided example
@@ -42,15 +41,15 @@ def test_csv_sink(detection_instances):
     ]

     # Using the CSVSink class to write the detection data to a CSV file
-    with CSVSink(filename=csv_filename) as sink:
+    with sv.CSVSink(filename=csv_filename) as sink:
         sink.append(detections, custom_data)
         sink.append(second_detections, second_custom_data)

     # Read back the CSV file and verify its contents
     with open(csv_filename, mode='r', newline='') as file:
         reader = csv.reader(file)
         for i, row in enumerate(reader):
-            assert [str(item) for item in expected_rows[i]] == row, f"Row in CSV file did not match expected output: {row} != {expected_rows[i]}"
+            assert [str(item) for item in expected_rows[i]] == row, f"Row in CSV didn't match expected output: {row} != {expected_rows[i]}"

     # Clean up by removing the test CSV file
     os.remove(csv_filename)
@@ -67,7 +66,7 @@ def test_csv_sink_manual(detection_instances):
     ]

     # Using the CSVSink class to write the detection data to a CSV file
-    sink = CSVSink(filename=csv_filename)
+    sink = sv.CSVSink(filename=csv_filename)
     sink.open()
     sink.append(detections, custom_data)
     sink.append(second_detections, second_custom_data)
@@ -77,7 +76,7 @@ def test_csv_sink_manual(detection_instances):
     with open(csv_filename, mode='r', newline='') as file:
         reader = csv.reader(file)
         for i, row in enumerate(reader):
-            assert [str(item) for item in expected_rows[i]] == row, f"Row in CSV file did not match expected output: {row} != {expected_rows[i]}"
+            assert [str(item) for item in expected_rows[i]] == row, f"Row in CSV didn't match expected output: {row} != {expected_rows[i]}"

     # Clean up by removing the test CSV file
     os.remove(csv_filename)
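
With the module moved, the suite is now invoked as `pytest test/detection/test_csv.py`; the removed inline comment pointed at the old `test/utils/test_csv.py` path.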
