Skip to content

Commit 71200a5

Browse files
authored
Merge pull request #1147 from roboflow/feat/keypoints-from-inference
Add `from_inference` to KeyPoints
2 parents 5f00d8b + 66f6a86 commit 71200a5

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

supervision/keypoint/core.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from contextlib import suppress
34
from dataclasses import dataclass, field
45
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
56

@@ -100,6 +101,90 @@ def __eq__(self, other: KeyPoints) -> bool:
100101
]
101102
)
102103

104+
@classmethod
105+
def from_inference(cls, inference_result: Union[dict, Any]) -> KeyPoints:
106+
"""
107+
Create a `sv.KeyPoints` object from the [Roboflow](https://roboflow.com/)
108+
API inference result or the [Inference](https://inference.roboflow.com/)
109+
package results. When a keypoint detection model is used, this method
110+
extracts the keypoint coordinates, class IDs, confidences, and class names.
111+
112+
Args:
113+
inference_result (dict, any): The result from the
114+
Roboflow API or Inference package containing predictions with keypoints.
115+
116+
Returns:
117+
(KeyPoints): A KeyPoints object containing the keypoint coordinates,
118+
class IDs, and confidences of each keypoint.
119+
120+
Example:
121+
```python
122+
import cv2
123+
import supervision as sv
124+
from inference import get_model
125+
126+
image = cv2.imread(<SOURCE_IMAGE_PATH>)
127+
model = get_model(model_id=<POSE_MODEL_ID>, api_key=<ROBOFLOW_API_KEY>)
128+
129+
result = model.infer(image)[0]
130+
key_points = sv.KeyPoints.from_inference(result)
131+
```
132+
133+
```python
134+
import cv2
135+
import supervision as sv
136+
from inference_sdk import InferenceHTTPClient
137+
138+
image = cv2.imread(<SOURCE_IMAGE_PATH>)
139+
client = InferenceHTTPClient(
140+
api_url="https://detect.roboflow.com",
141+
api_key=<ROBOFLOW_API_KEY>
142+
)
143+
144+
result = client.infer(image, model_id=<POSE_MODEL_ID>)
145+
key_points = sv.KeyPoints.from_inference(result)
146+
```
147+
"""
148+
if isinstance(inference_result, list):
149+
raise ValueError(
150+
"from_inference() operates on a single result at a time."
151+
"You can retrieve it like so: inference_result = model.infer(image)[0]"
152+
)
153+
154+
# Unpack the result if received from inference.get_model,
155+
# rather than inference_sdk.InferenceHTTPClient
156+
with suppress(AttributeError):
157+
inference_result = inference_result.dict(exclude_none=True, by_alias=True)
158+
159+
if not inference_result.get("predictions"):
160+
return cls.empty()
161+
162+
xy = []
163+
confidence = []
164+
class_id = []
165+
class_names = []
166+
167+
for prediction in inference_result["predictions"]:
168+
prediction_xy = []
169+
prediction_confidence = []
170+
for keypoint in prediction["keypoints"]:
171+
prediction_xy.append([keypoint["x"], keypoint["y"]])
172+
prediction_confidence.append(keypoint["confidence"])
173+
xy.append(prediction_xy)
174+
confidence.append(prediction_confidence)
175+
176+
class_id.append(prediction["class_id"])
177+
class_names.append(prediction["class"])
178+
179+
data = {CLASS_NAME_DATA_FIELD: np.array(class_names)}
180+
181+
return cls(
182+
xy=np.array(xy, dtype=np.float32),
183+
confidence=np.array(confidence, dtype=np.float32),
184+
class_id=np.array(class_id, dtype=int),
185+
data=data,
186+
)
187+
103188
@classmethod
104189
def from_ultralytics(cls, ultralytics_results) -> KeyPoints:
105190
"""

0 commit comments

Comments
 (0)