|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +from contextlib import suppress |
3 | 4 | from dataclasses import dataclass, field
|
4 | 5 | from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
5 | 6 |
|
@@ -100,6 +101,90 @@ def __eq__(self, other: KeyPoints) -> bool:
|
100 | 101 | ]
|
101 | 102 | )
|
102 | 103 |
|
| 104 | + @classmethod |
| 105 | + def from_inference(cls, inference_result: Union[dict, Any]) -> KeyPoints: |
| 106 | + """ |
| 107 | + Create a `sv.KeyPoints` object from the [Roboflow](https://roboflow.com/) |
| 108 | + API inference result or the [Inference](https://inference.roboflow.com/) |
| 109 | + package results. When a keypoint detection model is used, this method |
| 110 | + extracts the keypoint coordinates, class IDs, confidences, and class names. |
| 111 | +
|
| 112 | + Args: |
| 113 | + inference_result (dict, any): The result from the |
| 114 | + Roboflow API or Inference package containing predictions with keypoints. |
| 115 | +
|
| 116 | + Returns: |
| 117 | + (KeyPoints): A KeyPoints object containing the keypoint coordinates, |
| 118 | + class IDs, and confidences of each keypoint. |
| 119 | +
|
| 120 | + Example: |
| 121 | + ```python |
| 122 | + import cv2 |
| 123 | + import supervision as sv |
| 124 | + from inference import get_model |
| 125 | +
|
| 126 | + image = cv2.imread(<SOURCE_IMAGE_PATH>) |
| 127 | + model = get_model(model_id=<POSE_MODEL_ID>, api_key=<ROBOFLOW_API_KEY>) |
| 128 | +
|
| 129 | + result = model.infer(image)[0] |
| 130 | + key_points = sv.KeyPoints.from_inference(result) |
| 131 | + ``` |
| 132 | +
|
| 133 | + ```python |
| 134 | + import cv2 |
| 135 | + import supervision as sv |
| 136 | + from inference_sdk import InferenceHTTPClient |
| 137 | +
|
| 138 | + image = cv2.imread(<SOURCE_IMAGE_PATH>) |
| 139 | + client = InferenceHTTPClient( |
| 140 | + api_url="https://detect.roboflow.com", |
| 141 | + api_key=<ROBOFLOW_API_KEY> |
| 142 | + ) |
| 143 | +
|
| 144 | + result = client.infer(image, model_id=<POSE_MODEL_ID>) |
| 145 | + key_points = sv.KeyPoints.from_inference(result) |
| 146 | + ``` |
| 147 | + """ |
| 148 | + if isinstance(inference_result, list): |
| 149 | + raise ValueError( |
| 150 | + "from_inference() operates on a single result at a time." |
| 151 | + "You can retrieve it like so: inference_result = model.infer(image)[0]" |
| 152 | + ) |
| 153 | + |
| 154 | + # Unpack the result if received from inference.get_model, |
| 155 | + # rather than inference_sdk.InferenceHTTPClient |
| 156 | + with suppress(AttributeError): |
| 157 | + inference_result = inference_result.dict(exclude_none=True, by_alias=True) |
| 158 | + |
| 159 | + if not inference_result.get("predictions"): |
| 160 | + return cls.empty() |
| 161 | + |
| 162 | + xy = [] |
| 163 | + confidence = [] |
| 164 | + class_id = [] |
| 165 | + class_names = [] |
| 166 | + |
| 167 | + for prediction in inference_result["predictions"]: |
| 168 | + prediction_xy = [] |
| 169 | + prediction_confidence = [] |
| 170 | + for keypoint in prediction["keypoints"]: |
| 171 | + prediction_xy.append([keypoint["x"], keypoint["y"]]) |
| 172 | + prediction_confidence.append(keypoint["confidence"]) |
| 173 | + xy.append(prediction_xy) |
| 174 | + confidence.append(prediction_confidence) |
| 175 | + |
| 176 | + class_id.append(prediction["class_id"]) |
| 177 | + class_names.append(prediction["class"]) |
| 178 | + |
| 179 | + data = {CLASS_NAME_DATA_FIELD: np.array(class_names)} |
| 180 | + |
| 181 | + return cls( |
| 182 | + xy=np.array(xy, dtype=np.float32), |
| 183 | + confidence=np.array(confidence, dtype=np.float32), |
| 184 | + class_id=np.array(class_id, dtype=int), |
| 185 | + data=data, |
| 186 | + ) |
| 187 | + |
103 | 188 | @classmethod
|
104 | 189 | def from_ultralytics(cls, ultralytics_results) -> KeyPoints:
|
105 | 190 | """
|
|
0 commit comments