Skip to content
48 changes: 45 additions & 3 deletions supervision/classification/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,54 @@ def __post_init__(self) -> None:
_validate_class_ids(self.class_id, n)
_validate_confidence(self.confidence, n)

def __len__(self) -> int:
"""
Returns the number of classifications.
"""
return len(self.class_id)

@classmethod
def from_clip(cls, clip_results) -> Classifications:
"""
Creates a Classifications instance from a
[clip](https://github.com/openai/clip) inference result.

Args:
clip_results (np.ndarray): The inference result from clip model.

Returns:
Classifications: A new Classifications object.

Example:
```python
>>> from PIL import Image
>>> import clip
>>> import supervision as sv

>>> model, preprocess = clip.load('ViT-B/32')

>>> image = cv2.imread(SOURCE_IMAGE_PATH)
>>> image = preprocess(image).unsqueeze(0)

>>> text = clip.tokenize(["a diagram", "a dog", "a cat"])
>>> output, _ = model(image, text)
>>> classifications = sv.Classifications.from_clip(output)
```
"""

confidence = clip_results.softmax(dim=-1).cpu().detach().numpy()[0]

if len(confidence) == 0:
return cls(class_id=np.array([]), confidence=np.array([]))

class_ids = np.arange(len(confidence))
return cls(class_id=class_ids, confidence=confidence)

@classmethod
def from_ultralytics(cls, ultralytics_results) -> Classifications:
"""
Creates a Classifications instance from a
(https://github.com/ultralytics/ultralytics) inference result.
[ultralytics](https://github.com/ultralytics/ultralytics) inference result.

Args:
ultralytics_results (ultralytics.engine.results.Results):
Expand Down Expand Up @@ -72,7 +115,7 @@ def from_ultralytics(cls, ultralytics_results) -> Classifications:
def from_timm(cls, timm_results) -> Classifications:
"""
Creates a Classifications instance from a
timm (https://huggingface.co/docs/hub/timm) inference result.
[timm](https://huggingface.co/docs/hub/timm) inference result.

Args:
timm_results: The inference result from timm model.
Expand Down Expand Up @@ -109,7 +152,6 @@ def from_timm(cls, timm_results) -> Classifications:
return cls(class_id=np.array([]), confidence=np.array([]))

class_id = np.arange(len(confidence))

return cls(class_id=class_id, confidence=confidence)

def get_top_k(self, k: int) -> Tuple[np.ndarray, np.ndarray]:
Expand Down