189 changes: 189 additions & 0 deletions test/test_cvcuda.py
@@ -0,0 +1,189 @@
import pytest
import torch
from torchvision import _is_cvcuda_available
from torchvision.transforms.v2 import functional as F

CVCUDA_AVAILABLE = _is_cvcuda_available()
CUDA_AVAILABLE = torch.cuda.is_available()


if CVCUDA_AVAILABLE:
import nvcv


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
class TestToNvcvTensor:
"""Tests for to_nvcv_tensor function following patterns from TestToPil"""

def test_1_channel_uint8_tensor_to_nvcv_tensor(self):
img_data = torch.ByteTensor(1, 4, 4).random_(0, 255).cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
# Check that the conversion succeeded and format is correct
assert nvcv_img is not None

def test_1_channel_int16_tensor_to_nvcv_tensor(self):
img_data = torch.ShortTensor(1, 4, 4).random_().cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_1_channel_int32_tensor_to_nvcv_tensor(self):
img_data = torch.IntTensor(1, 4, 4).random_().cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_1_channel_float32_tensor_to_nvcv_tensor(self):
img_data = torch.Tensor(1, 4, 4).uniform_().cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_2_channel_uint8_tensor_to_nvcv_tensor(self):
img_data = torch.ByteTensor(2, 4, 4).random_(0, 255).cuda()
# NVCV doesn't support 2-channel uint8 images
with pytest.raises(TypeError, match="Unsupported dtype.*for 2-channel image"):
F.to_nvcv_tensor(img_data)

def test_2_channel_float32_tensor_to_nvcv_tensor(self):
img_data = torch.Tensor(2, 4, 4).uniform_().cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_3_channel_uint8_tensor_to_nvcv_tensor(self):
img_data = torch.ByteTensor(3, 4, 4).random_(0, 255).cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_3_channel_float32_tensor_to_nvcv_tensor(self):
img_data = torch.Tensor(3, 4, 4).uniform_().cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_4_channel_uint8_tensor_to_nvcv_tensor(self):
img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_4_channel_float32_tensor_to_nvcv_tensor(self):
img_data = torch.Tensor(4, 4, 4).uniform_().cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_2d_uint8_tensor_to_nvcv_tensor(self):
img_data = torch.ByteTensor(4, 4).random_(0, 255).cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_2d_float32_tensor_to_nvcv_tensor(self):
img_data = torch.Tensor(4, 4).uniform_().cuda()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_1_channel_uint8_ndarray_to_nvcv_tensor(self):
img_data = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_3_channel_uint8_ndarray_to_nvcv_tensor(self):
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_4_channel_uint8_ndarray_to_nvcv_tensor(self):
img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
nvcv_img = F.to_nvcv_tensor(img_data)
assert nvcv_img is not None

def test_explicit_format_rgb8(self):
img_data = torch.ByteTensor(3, 4, 4).random_(0, 255).cuda()
nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.RGB8)
assert nvcv_img is not None

def test_explicit_format_bgr8(self):
img_data = torch.ByteTensor(3, 4, 4).random_(0, 255).cuda()
nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.BGR8)
assert nvcv_img is not None

def test_explicit_format_hsv8(self):
img_data = torch.ByteTensor(3, 4, 4).random_(0, 255).cuda()
# HSV8 should work for 3-channel images
nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.HSV8)
assert nvcv_img is not None

def test_explicit_format_rgba8(self):
img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).cuda()
nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.RGBA8)
assert nvcv_img is not None

def test_explicit_format_bgra8(self):
img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).cuda()
# BGRA8 should work for 4-channel images
nvcv_img = F.to_nvcv_tensor(img_data, format=nvcv.Format.BGRA8)
assert nvcv_img is not None

def test_invalid_input_type(self):
with pytest.raises(TypeError, match=r"pic should be Tensor or ndarray"):
F.to_nvcv_tensor("invalid_input")

def test_invalid_dimensions(self):
# Test 1D array (too few dimensions)
with pytest.raises(ValueError, match=r"pic should be 2/3/4 dimensional"):
F.to_nvcv_tensor(torch.ByteTensor(4).cuda())

# Test 5D array (too many dimensions)
with pytest.raises(ValueError, match=r"pic should be 2/3/4 dimensional"):
F.to_nvcv_tensor(torch.ByteTensor(1, 1, 3, 4, 4).cuda())

def test_too_many_channels(self):
with pytest.raises(ValueError, match=r"pic should not have > 4 channels"):
F.to_nvcv_tensor(torch.ByteTensor(5, 4, 4).random_(0, 255).cuda())

def test_unsupported_dtype_for_channels(self):
# Float64 is not supported
img_data = torch.DoubleTensor(3, 4, 4).uniform_().cuda()
with pytest.raises(TypeError, match=r"Unsupported dtype"):
F.to_nvcv_tensor(img_data)


def make_nvcv_image(num_channels=3, dtype=torch.uint8):
"""Helper function to create NVCV Tensor for testing"""
if dtype == torch.uint8:
img_data = torch.ByteTensor(num_channels, 4, 4).random_(0, 255).cuda()
else:
img_data = torch.Tensor(num_channels, 4, 4).uniform_().cuda()
return F.to_nvcv_tensor(img_data)


def transform_cls_to_functional(get_transform_cls):
    """Wrap a lazily resolved transform class so it can be called like a functional."""

    def wrapper(inpt):
        transform_cls = get_transform_cls()
        return transform_cls()(inpt)

    return wrapper


@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
class TestNVCVToTensor:
@pytest.mark.parametrize("num_channels", [1, 3, 4])
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize(
"fn",
[
F.nvcv_to_tensor,
            # NVCVToTensor only exists when CV-CUDA is installed, so resolve the import
            # lazily inside the lambda instead of at module import time.
            transform_cls_to_functional(
                lambda: __import__("torchvision.transforms.v2", fromlist=["NVCVToTensor"]).NVCVToTensor
            ),
],
)
def test_functional_and_transform(self, num_channels, dtype, fn):
input = make_nvcv_image(num_channels=num_channels, dtype=dtype)
output = fn(input)

assert isinstance(output, torch.Tensor)
# Convert input to tensor to compare sizes
input_tensor = F.nvcv_to_tensor(input)
assert F.get_size(output) == F.get_size(input_tensor)

def test_functional_error(self):
with pytest.raises(TypeError, match="nvcv_img should be NVCV Tensor"):
F.nvcv_to_tensor(object())
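
For context, a minimal round-trip sketch of the functional API these tests exercise (a usage sketch, not part of the diff; it assumes CV-CUDA/nvcv is installed and a CUDA device is available, and the shapes and dtypes are illustrative, mirroring the tests above):

import torch
from torchvision import _is_cvcuda_available
from torchvision.transforms.v2 import functional as F

if _is_cvcuda_available() and torch.cuda.is_available():
    # CHW uint8 image on the GPU, like the ones built in the tests above.
    img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8, device="cuda")

    nvcv_tensor = F.to_nvcv_tensor(img)   # torch.Tensor -> NVCV Tensor (format inferred as RGB8)
    back = F.nvcv_to_tensor(nvcv_tensor)  # NVCV Tensor -> torch.Tensor in CHW layout

    assert isinstance(back, torch.Tensor)
    assert F.get_size(back) == [4, 4]     # (H, W) of the original image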
9 changes: 9 additions & 0 deletions torchvision/__init__.py
@@ -99,6 +99,15 @@ def _is_tracing():
return torch._C._get_tracing_state()


def _is_cvcuda_available() -> bool:
try:
import cvcuda # type: ignore[import-not-found]
import nvcv # type: ignore[import-not-found]
except ImportError:
return False
return True


def disable_beta_transforms_warning():
# Noop, only exists to avoid breaking existing code.
# See https://github.com/pytorch/vision/issues/7896
4 changes: 4 additions & 0 deletions torchvision/transforms/v2/__init__.py
@@ -59,3 +59,7 @@
from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size

from ._deprecated import ToTensor # usort: skip
from torchvision import _is_cvcuda_available

if _is_cvcuda_available():
from ._cvcuda import NVCVToTensor, ToNVCVTensor
87 changes: 87 additions & 0 deletions torchvision/transforms/v2/_cvcuda.py
@@ -0,0 +1,87 @@
from torchvision.transforms.v2 import functional as F
from torchvision.utils import _log_api_usage_once


class ToNVCVTensor:
"""Convert a tensor or an ndarray to NVCV Tensor

This transform does not support torchscript.

Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to an NVCV Tensor.

Args:
        format (`nvcv.Format`_): color format specification from the nvcv.Format enum (optional).
If ``format`` is ``None`` (default) the format is inferred from the input data:

- **1 channel images**: Inferred based on dtype
- uint8 → U8, int16 → S16, int32 → S32, float32 → F32
- **2 channel images**: float32 → _2F32 (only float32 is supported for 2-channel images)
- **3 channel images**: Defaults to RGB-based formats
- uint8 → RGB8, float32 → RGBf32
- **4 channel images**: Defaults to RGBA-based formats
- uint8 → RGBA8, float32 → RGBAf32

Explicit format examples: nvcv.Format.RGB8, nvcv.Format.BGR8, nvcv.Format.HSV8,
nvcv.Format.RGBA8, nvcv.Format.BGRA8

.. _nvcv.Format: https://cvcuda.github.io/CV-CUDA/_python_api/nvcv/format.html
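
    Example (a minimal sketch; assumes CV-CUDA/nvcv is installed and the input is a CUDA tensor):
        >>> import torch
        >>> import nvcv
        >>> import torchvision.transforms.v2 as T
        >>> img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8, device="cuda")
        >>> rgb = T.ToNVCVTensor()(img)                         # format inferred as RGB8
        >>> bgr = T.ToNVCVTensor(format=nvcv.Format.BGR8)(img)  # explicit format override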
"""

def __init__(self, format=None):
_log_api_usage_once(self)
self.format = format

def __call__(self, pic):
"""
Args:
pic (Tensor or numpy.ndarray): Image to be converted to NVCV Tensor.

Returns:
NVCV Tensor: Image converted to NVCV Tensor.

"""
return F.to_nvcv_tensor(pic, self.format)

def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
if self.format is not None:
format_string += f"format={self.format}"
format_string += ")"
return format_string


class NVCVToTensor:
"""Convert an NVCV Image to a tensor of the same type - this does not scale values.

This transform does not support torchscript.

Converts an NVCV Image with H height, W width, and C channels to a PyTorch Tensor
of shape (C x H x W). The conversion happens directly on GPU when the NVCV Image
is stored on GPU, avoiding unnecessary data transfers.

Example:
>>> import nvcv
>>> import torchvision.transforms.v2 as T
>>> # Create an NVCV Image (320x240 RGB)
>>> nvcv_img = nvcv.Image(nvcv.Size2D(320, 240), nvcv.Format.RGB8)
>>> tensor = T.NVCVToTensor()(nvcv_img)
>>> print(tensor.shape)
torch.Size([3, 240, 320])
"""

def __init__(self) -> None:
_log_api_usage_once(self)

def __call__(self, pic):
"""
Args:
pic (nvcv.Image): NVCV Image to be converted to tensor.

Returns:
Tensor: Converted image in CHW format.
"""
return F.nvcv_to_tensor(pic)

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
5 changes: 5 additions & 0 deletions torchvision/transforms/v2/functional/__init__.py
@@ -165,3 +165,8 @@
from ._type_conversion import pil_to_tensor, to_image, to_pil_image

from ._deprecated import get_image_size, to_tensor # usort: skip

from torchvision import _is_cvcuda_available

if _is_cvcuda_available():
from ._cvcuda import nvcv_to_tensor, to_nvcv_tensor