Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions supervision/keypoint/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import suppress
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

Expand Down Expand Up @@ -100,6 +101,90 @@ def __eq__(self, other: KeyPoints) -> bool:
]
)

@classmethod
def from_inference(cls, inference_result: Union[dict, Any]) -> KeyPoints:
"""
Create a `sv.KeyPoints` object from the [Roboflow](https://roboflow.com/)
API inference result or the [Inference](https://inference.roboflow.com/)
package results. When a keypoint detection model is used, this method
extracts the keypoint coordinates, class IDs, confidences, and class names.

Args:
inference_result (dict, any): The result from the
Roboflow API or Inference package containing predictions with keypoints.

Returns:
(KeyPoints): A KeyPoints object containing the keypoint coordinates,
class IDs, and confidences of each keypoint.

Example:
```python
import cv2
import supervision as sv
from inference import get_model

image = cv2.imread(<SOURCE_IMAGE_PATH>)
model = get_model(model_id=<POSE_MODEL_ID>, api_key=<ROBOFLOW_API_KEY>)

result = model.infer(image)[0]
key_points = sv.KeyPoints.from_inference(result)
```

```python
import cv2
import supervision as sv
from inference_sdk import InferenceHTTPClient

image = cv2.imread(<SOURCE_IMAGE_PATH>)
client = InferenceHTTPClient(
api_url="https://detect.roboflow.com",
api_key=<ROBOFLOW_API_KEY>
)

result = client.infer(image, model_id=<POSE_MODEL_ID>)
key_points = sv.KeyPoints.from_inference(result)
```
"""
if isinstance(inference_result, list):
raise ValueError(
"from_inference() operates on a single result at a time."
"You can retrieve it like so: inference_result = model.infer(image)[0]"
)

# Unpack the result if received from inference.get_model,
# rather than inference_sdk.InferenceHTTPClient
with suppress(AttributeError):
inference_result = inference_result.dict(exclude_none=True, by_alias=True)

if not inference_result.get("predictions"):
return cls.empty()

xy = []
confidence = []
class_id = []
class_names = []

for prediction in inference_result["predictions"]:
prediction_xy = []
prediction_confidence = []
for keypoint in prediction["keypoints"]:
prediction_xy.append([keypoint["x"], keypoint["y"]])
prediction_confidence.append(keypoint["confidence"])
xy.append(prediction_xy)
confidence.append(prediction_confidence)

class_id.append(prediction["class_id"])
class_names.append(prediction["class"])

data = {CLASS_NAME_DATA_FIELD: np.array(class_names)}

return cls(
xy=np.array(xy, dtype=np.float32),
confidence=np.array(confidence, dtype=np.float32),
class_id=np.array(class_id, dtype=int),
data=data,
)

@classmethod
def from_ultralytics(cls, ultralytics_results) -> KeyPoints:
"""
Expand Down