Skip to content

Commit c04edbc

Browse files
authored
Merge pull request #735 from revtheundead/add-center-point-crossed-condition-to-line-counter
Added an option to count detection upon the center point of the bounding box crossing the line counter
2 parents 536e0d8 + 6128576 commit c04edbc

File tree

6 files changed

+249
-53
lines changed

6 files changed

+249
-53
lines changed

examples/speed_estimation/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ supervision package for multiple tasks such as tracking, annotations, etc.
1111

1212
https://github.com/roboflow/supervision/assets/26109316/d50118c1-2ae4-458d-915a-5d860fd36f71
1313

14-
> [!IMPORTANT]
14+
> [!IMPORTANT]
1515
> Adjust the [`SOURCE`](https://github.com/roboflow/supervision/blob/e32b05a636dab2ea1f39299e529c4b22b8baa8da/examples/speed_estimation/ultralytics_example.py#L10)
1616
> and [`TARGET`](https://github.com/roboflow/supervision/blob/e32b05a636dab2ea1f39299e529c4b22b8baa8da/examples/speed_estimation/ultralytics_example.py#L15)
17-
> configuration if you plan to run a speed estimation script on your video file. Those must be adjusted separately for each camera view. You can learn more
17+
> configuration if you plan to run a speed estimation script on your video file. Those must be adjusted separately for each camera view. You can learn more
1818
> from our YouTube [tutorial](https://youtu.be/uWP6UjDeZvY).
1919
2020
## 💻 install

supervision/detection/line_counter.py

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from typing import Dict, Optional, Tuple
1+
from typing import Dict, Iterable, Optional, Tuple
22

33
import cv2
44
import numpy as np
55

66
from supervision.detection.core import Detections
77
from supervision.draw.color import Color
8-
from supervision.geometry.core import Point, Rect, Vector
8+
from supervision.geometry.core import Point, Position, Rect, Vector
99

1010

1111
class LineZone:
@@ -26,16 +26,71 @@ class LineZone:
2626
to outside.
2727
"""
2828

29-
def __init__(self, start: Point, end: Point):
29+
def __init__(
30+
self,
31+
start: Point,
32+
end: Point,
33+
triggering_anchors: Iterable[Position] = (
34+
Position.TOP_LEFT,
35+
Position.TOP_RIGHT,
36+
Position.BOTTOM_LEFT,
37+
Position.BOTTOM_RIGHT,
38+
),
39+
):
3040
"""
3141
Args:
3242
start (Point): The starting point of the line.
3343
end (Point): The ending point of the line.
44+
triggering_anchors (List[sv.Position]): A list of positions
45+
specifying which anchors of the detections bounding box
46+
to consider when deciding on whether the detection
47+
has passed the line counter or not. By default, this
48+
contains the four corners of the detection's bounding box
3449
"""
3550
self.vector = Vector(start=start, end=end)
51+
self.limits = self.calculate_region_of_interest_limits(vector=self.vector)
3652
self.tracker_state: Dict[str, bool] = {}
3753
self.in_count: int = 0
3854
self.out_count: int = 0
55+
self.triggering_anchors = triggering_anchors
56+
57+
@staticmethod
58+
def calculate_region_of_interest_limits(vector: Vector) -> Tuple[Vector, Vector]:
59+
magnitude = vector.magnitude
60+
61+
if magnitude == 0:
62+
raise ValueError("The magnitude of the vector cannot be zero.")
63+
64+
delta_x = vector.end.x - vector.start.x
65+
delta_y = vector.end.y - vector.start.y
66+
67+
unit_vector_x = delta_x / magnitude
68+
unit_vector_y = delta_y / magnitude
69+
70+
perpendicular_vector_x = -unit_vector_y
71+
perpendicular_vector_y = unit_vector_x
72+
73+
start_region_limit = Vector(
74+
start=vector.start,
75+
end=Point(
76+
x=vector.start.x + perpendicular_vector_x,
77+
y=vector.start.y + perpendicular_vector_y,
78+
),
79+
)
80+
end_region_limit = Vector(
81+
start=vector.end,
82+
end=Point(
83+
x=vector.end.x - perpendicular_vector_x,
84+
y=vector.end.y - perpendicular_vector_y,
85+
),
86+
)
87+
return start_region_limit, end_region_limit
88+
89+
@staticmethod
90+
def is_point_in_limits(point: Point, limits: Tuple[Vector, Vector]) -> bool:
91+
cross_product_1 = limits[0].cross_product(point)
92+
cross_product_2 = limits[1].cross_product(point)
93+
return (cross_product_1 > 0) == (cross_product_2 > 0)
3994

4095
def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:
4196
"""
@@ -54,20 +109,35 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:
54109
crossed_in = np.full(len(detections), False)
55110
crossed_out = np.full(len(detections), False)
56111

57-
for i, (xyxy, tracker_id) in enumerate(
58-
zip(detections.xyxy, detections.tracker_id)
59-
):
112+
if len(detections) == 0:
113+
return crossed_in, crossed_out
114+
115+
all_anchors = np.array(
116+
[
117+
detections.get_anchors_coordinates(anchor)
118+
for anchor in self.triggering_anchors
119+
]
120+
)
121+
122+
for i, tracker_id in enumerate(detections.tracker_id):
60123
if tracker_id is None:
61124
continue
62125

63-
x1, y1, x2, y2 = xyxy
64-
anchors = [
65-
Point(x=x1, y=y1),
66-
Point(x=x1, y=y2),
67-
Point(x=x2, y=y1),
68-
Point(x=x2, y=y2),
126+
box_anchors = [Point(x=x, y=y) for x, y in all_anchors[:, i, :]]
127+
128+
in_limits = all(
129+
[
130+
self.is_point_in_limits(point=anchor, limits=self.limits)
131+
for anchor in box_anchors
132+
]
133+
)
134+
135+
if not in_limits:
136+
continue
137+
138+
triggers = [
139+
self.vector.cross_product(point=anchor) > 0 for anchor in box_anchors
69140
]
70-
triggers = [self.vector.is_in(point=anchor) for anchor in anchors]
71141

72142
if len(set(triggers)) == 2:
73143
continue

supervision/geometry/core.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dataclasses import dataclass
44
from enum import Enum
5+
from math import sqrt
56
from typing import Tuple
67

78

@@ -43,13 +44,38 @@ class Vector:
4344
start: Point
4445
end: Point
4546

46-
def is_in(self, point: Point) -> bool:
47-
v1 = Vector(self.start, self.end)
48-
v2 = Vector(self.start, point)
49-
cross_product = (v1.end.x - v1.start.x) * (v2.end.y - v2.start.y) - (
50-
v1.end.y - v1.start.y
51-
) * (v2.end.x - v2.start.x)
52-
return cross_product < 0
47+
@property
48+
def magnitude(self) -> float:
49+
"""
50+
Calculate the magnitude (length) of the vector.
51+
52+
Returns:
53+
float: The magnitude of the vector.
54+
"""
55+
dx = self.end.x - self.start.x
56+
dy = self.end.y - self.start.y
57+
return sqrt(dx**2 + dy**2)
58+
59+
def cross_product(self, point: Point) -> float:
60+
"""
61+
Calculate the 2D cross product (also known as the vector product or outer
62+
product) of the vector and a point, treated as vectors in 2D space.
63+
64+
Args:
65+
point (Point): The point to be evaluated, treated as the endpoint of a
66+
vector originating from the 'start' of the main vector.
67+
68+
Returns:
69+
float: The scalar value of the cross product. It is positive if 'point'
70+
lies to the left of the vector (when moving from 'start' to 'end'),
71+
negative if it lies to the right, and 0 if it is collinear with the
72+
vector.
73+
"""
74+
dx_vector = self.end.x - self.start.x
75+
dy_vector = self.end.y - self.start.y
76+
dx_point = point.x - self.start.x
77+
dy_point = point.y - self.start.y
78+
return (dx_vector * dy_point) - (dy_vector * dx_point)
5379

5480

5581
@dataclass

test/detection/test_line_counter.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from contextlib import ExitStack as DoesNotRaise
2+
from typing import Optional, Tuple
3+
4+
import pytest
5+
6+
from supervision import LineZone
7+
from supervision.geometry.core import Point, Vector
8+
9+
10+
@pytest.mark.parametrize(
11+
"vector, expected_result, exception",
12+
[
13+
(
14+
Vector(start=Point(x=0.0, y=0.0), end=Point(x=0.0, y=0.0)),
15+
None,
16+
pytest.raises(ValueError),
17+
),
18+
(
19+
Vector(start=Point(x=1.0, y=1.0), end=Point(x=1.0, y=1.0)),
20+
None,
21+
pytest.raises(ValueError),
22+
),
23+
(
24+
Vector(start=Point(x=0.0, y=0.0), end=Point(x=0.0, y=4.0)),
25+
(
26+
Vector(start=Point(x=0.0, y=0.0), end=Point(x=-1.0, y=0.0)),
27+
Vector(start=Point(x=0.0, y=4.0), end=Point(x=1.0, y=4.0)),
28+
),
29+
DoesNotRaise(),
30+
),
31+
(
32+
Vector(Point(0.0, 0.0), Point(4.0, 0.0)),
33+
(
34+
Vector(start=Point(x=0.0, y=0.0), end=Point(x=0.0, y=1.0)),
35+
Vector(start=Point(x=4.0, y=0.0), end=Point(x=4.0, y=-1.0)),
36+
),
37+
DoesNotRaise(),
38+
),
39+
(
40+
Vector(Point(0.0, 0.0), Point(3.0, 4.0)),
41+
(
42+
Vector(start=Point(x=0, y=0), end=Point(x=-0.8, y=0.6)),
43+
Vector(start=Point(x=3, y=4), end=Point(x=3.8, y=3.4)),
44+
),
45+
DoesNotRaise(),
46+
),
47+
(
48+
Vector(Point(0.0, 0.0), Point(4.0, 3.0)),
49+
(
50+
Vector(start=Point(x=0, y=0), end=Point(x=-0.6, y=0.8)),
51+
Vector(start=Point(x=4, y=3), end=Point(x=4.6, y=2.2)),
52+
),
53+
DoesNotRaise(),
54+
),
55+
(
56+
Vector(Point(0.0, 0.0), Point(3.0, -4.0)),
57+
(
58+
Vector(start=Point(x=0, y=0), end=Point(x=0.8, y=0.6)),
59+
Vector(start=Point(x=3, y=-4), end=Point(x=2.2, y=-4.6)),
60+
),
61+
DoesNotRaise(),
62+
),
63+
],
64+
)
65+
def test_calculate_region_of_interest_limits(
66+
vector: Vector,
67+
expected_result: Optional[Tuple[Vector, Vector]],
68+
exception: Exception,
69+
) -> None:
70+
with exception:
71+
result = LineZone.calculate_region_of_interest_limits(vector=vector)
72+
assert result == expected_result

test/geometry/test_core.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
3+
from supervision.geometry.core import Point, Vector
4+
5+
6+
@pytest.mark.parametrize(
7+
"vector, point, expected_result",
8+
[
9+
(Vector(start=Point(x=0, y=0), end=Point(x=5, y=5)), Point(x=-1, y=1), 10.0),
10+
(Vector(start=Point(x=0, y=0), end=Point(x=5, y=5)), Point(x=6, y=6), 0.0),
11+
(Vector(start=Point(x=0, y=0), end=Point(x=5, y=5)), Point(x=3, y=6), 15.0),
12+
(Vector(start=Point(x=5, y=5), end=Point(x=0, y=0)), Point(x=-1, y=1), -10.0),
13+
(Vector(start=Point(x=5, y=5), end=Point(x=0, y=0)), Point(x=6, y=6), 0.0),
14+
(Vector(start=Point(x=5, y=5), end=Point(x=0, y=0)), Point(x=3, y=6), -15.0),
15+
(Vector(start=Point(x=0, y=0), end=Point(x=1, y=0)), Point(x=0, y=0), 0.0),
16+
(Vector(start=Point(x=0, y=0), end=Point(x=1, y=0)), Point(x=0, y=-1), -1.0),
17+
(Vector(start=Point(x=0, y=0), end=Point(x=1, y=0)), Point(x=0, y=1), 1.0),
18+
(Vector(start=Point(x=1, y=0), end=Point(x=0, y=0)), Point(x=0, y=0), 0.0),
19+
(Vector(start=Point(x=1, y=0), end=Point(x=0, y=0)), Point(x=0, y=-1), 1.0),
20+
(Vector(start=Point(x=1, y=0), end=Point(x=0, y=0)), Point(x=0, y=1), -1.0),
21+
(Vector(start=Point(x=1, y=1), end=Point(x=1, y=3)), Point(x=0, y=0), 2.0),
22+
(Vector(start=Point(x=1, y=1), end=Point(x=1, y=3)), Point(x=1, y=4), 0.0),
23+
(Vector(start=Point(x=1, y=1), end=Point(x=1, y=3)), Point(x=2, y=4), -2.0),
24+
(Vector(start=Point(x=1, y=3), end=Point(x=1, y=1)), Point(x=0, y=0), -2.0),
25+
(Vector(start=Point(x=1, y=3), end=Point(x=1, y=1)), Point(x=1, y=4), 0.0),
26+
(Vector(start=Point(x=1, y=3), end=Point(x=1, y=1)), Point(x=2, y=4), 2.0),
27+
],
28+
)
29+
def test_vector_cross_product(
30+
vector: Vector, point: Point, expected_result: float
31+
) -> None:
32+
result = vector.cross_product(point=point)
33+
assert result == expected_result
34+
35+
36+
@pytest.mark.parametrize(
37+
"vector, expected_result",
38+
[
39+
(Vector(start=Point(x=0, y=0), end=Point(x=0, y=0)), 0.0),
40+
(Vector(start=Point(x=1, y=0), end=Point(x=0, y=0)), 1.0),
41+
(Vector(start=Point(x=0, y=1), end=Point(x=0, y=0)), 1.0),
42+
(Vector(start=Point(x=0, y=0), end=Point(x=1, y=0)), 1.0),
43+
(Vector(start=Point(x=0, y=0), end=Point(x=0, y=1)), 1.0),
44+
(Vector(start=Point(x=-1, y=0), end=Point(x=0, y=0)), 1.0),
45+
(Vector(start=Point(x=0, y=-1), end=Point(x=0, y=0)), 1.0),
46+
(Vector(start=Point(x=0, y=0), end=Point(x=-1, y=0)), 1.0),
47+
(Vector(start=Point(x=0, y=0), end=Point(x=0, y=-1)), 1.0),
48+
(Vector(start=Point(x=0, y=0), end=Point(x=3, y=4)), 5.0),
49+
(Vector(start=Point(x=0, y=0), end=Point(x=-3, y=4)), 5.0),
50+
(Vector(start=Point(x=0, y=0), end=Point(x=3, y=-4)), 5.0),
51+
(Vector(start=Point(x=0, y=0), end=Point(x=-3, y=-4)), 5.0),
52+
(Vector(start=Point(x=0, y=0), end=Point(x=4, y=3)), 5.0),
53+
(Vector(start=Point(x=3, y=4), end=Point(x=0, y=0)), 5.0),
54+
(Vector(start=Point(x=4, y=3), end=Point(x=0, y=0)), 5.0),
55+
],
56+
)
57+
def test_vector_magnitude(vector: Vector, expected_result: float) -> None:
58+
result = vector.magnitude
59+
assert result == expected_result

test/geometry/test_dataclasses.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

0 commit comments

Comments
 (0)