diff --git a/test/common_utils.py b/test/common_utils.py index 74ad31fea72..4121ae1ffd4 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,7 +20,7 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_image, to_pil_image +from torchvision.transforms.v2.functional import to_image, to_nvcv_tensor, to_pil_image from torchvision.utils import _Image_fromarray @@ -400,6 +400,10 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) +def make_image_nvcv(*args, **kwargs): + return to_nvcv_tensor(make_image(*args, **kwargs)) + + def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): y = torch.randint(0, canvas_size[0], size=(num_points, 1), dtype=dtype, device=device) x = torch.randint(0, canvas_size[1], size=(num_points, 1), dtype=dtype, device=device) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 0f985ab9604..d650b8c27a6 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -29,6 +29,7 @@ make_bounding_boxes, make_detection_masks, make_image, + make_image_nvcv, make_image_pil, make_image_tensor, make_keypoints, @@ -43,7 +44,7 @@ from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate -from torchvision import tv_tensors +from torchvision import _is_cvcuda_available, tv_tensors from torchvision.ops.boxes import box_iou from torchvision.transforms._functional_tensor import _max_value as get_max_value @@ -54,6 +55,10 @@ from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal +CVCUDA_AVAILABLE = _is_cvcuda_available() +CUDA_AVAILABLE = torch.cuda.is_available() + + # turns all warnings into errors for this module pytestmark = [pytest.mark.filterwarnings("error")] @@ -6733,6 +6738,150 @@ def test_functional_error(self): F.pil_to_tensor(object()) +@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA") +class TestToNVCVTensor: + """Tests for to_nvcv_tensor function following patterns from TestToPil""" + + def test_1_channel_uint8_tensor_to_nvcv_tensor(self): + img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.uint8, device="cuda") + nvcv_img = F.to_nvcv_tensor(img_data) + # Check that the conversion succeeded and format is correct + assert nvcv_img is not None + + def test_1_channel_int16_tensor_to_nvcv_tensor(self): + img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int16, device="cuda") + nvcv_img = F.to_nvcv_tensor(img_data) + assert nvcv_img is not None + + def test_1_channel_int32_tensor_to_nvcv_tensor(self): + img_data = torch.randint(0, 256, (1, 4, 4), dtype=torch.int32, device="cuda") + nvcv_img = F.to_nvcv_tensor(img_data) + assert nvcv_img is not None + + def test_1_channel_float32_tensor_to_nvcv_tensor(self): + img_data = torch.rand(1, 4, 4, device="cuda") + nvcv_img = F.to_nvcv_tensor(img_data) + assert nvcv_img is not None + + def test_3_channel_uint8_tensor_to_nvcv_tensor(self): + img_data = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8, device="cuda") + nvcv_img = F.to_nvcv_tensor(img_data) + assert nvcv_img is not None + + def test_3_channel_float32_tensor_to_nvcv_tensor(self): + img_data = torch.rand(3, 4, 4, device="cuda") + nvcv_img = F.to_nvcv_tensor(img_data) + assert nvcv_img is not None + + def test_2d_uint8_tensor_to_nvcv_tensor(self): + img_data = torch.randint(0, 256, (4, 4), dtype=torch.uint8, device="cuda") + nvcv_img = F.to_nvcv_tensor(img_data) + assert nvcv_img is not None + + def test_2d_float32_tensor_to_nvcv_tensor(self): + img_data = torch.rand(4, 4, device="cuda") + nvcv_img = F.to_nvcv_tensor(img_data) + assert nvcv_img is not None + + def test_unsupported_num_channels(self): + # Test 2-channel image (CHW format: 2 channels x 5 height x 5 width) + img_data = torch.rand(2, 5, 5, device="cuda") + with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"): + F.to_nvcv_tensor(img_data) + + # Test 4-channel image (CHW format: 4 channels x 5 height x 5 width) + img_data = torch.randint(0, 256, (4, 5, 5), dtype=torch.uint8, device="cuda") + with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"): + F.to_nvcv_tensor(img_data) + + # Test 5-channel image (CHW format: 5 channels x 5 height x 5 width) + img_data = torch.randint(0, 256, (5, 5, 5), dtype=torch.uint8, device="cuda") + with pytest.raises(ValueError, match="Only 1 and 3 channel images are supported"): + F.to_nvcv_tensor(img_data) + + def test_invalid_input_type(self): + with pytest.raises(TypeError, match=r"pic should be `torch.Tensor`"): + F.to_nvcv_tensor("invalid_input") + + def test_invalid_dimensions(self): + # Test 1D array (too few dimensions) + with pytest.raises(ValueError, match=r"pic should be 2/3/4 dimensional"): + F.to_nvcv_tensor(torch.randint(0, 256, (4,), dtype=torch.uint8, device="cuda")) + + # Test 5D array (too many dimensions) + with pytest.raises(ValueError, match=r"pic should be 2/3/4 dimensional"): + F.to_nvcv_tensor(torch.randint(0, 256, (1, 1, 3, 4, 4), dtype=torch.uint8, device="cuda")) + + def test_unsupported_dtype_for_channels(self): + # Float64 is not supported + img_data = torch.rand(3, 4, 4, dtype=torch.float64, device="cuda") + with pytest.raises(TypeError, match=r"Unsupported dtype"): + F.to_nvcv_tensor(img_data) + + @pytest.mark.parametrize("num_channels", [1, 3]) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + def test_round_trip(self, num_channels, dtype): + # Setup: Create a tensor in CHW format (PyTorch standard) + if dtype == torch.uint8: + original_tensor = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype, device="cuda") + else: + original_tensor = torch.rand(num_channels, 4, 4, dtype=dtype, device="cuda") + + # Execute: Convert to NVCV and back to tensor + # CHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> CHW + nvcv_tensor = F.to_nvcv_tensor(original_tensor) + result_tensor = F.nvcv_to_tensor(nvcv_tensor) + + # Assert: The round-trip conversion preserves the original tensor + # Use allclose for robust comparison that handles floating-point precision + assert torch.allclose(result_tensor, original_tensor, rtol=1e-5, atol=1e-7) + + @pytest.mark.parametrize("num_channels", [1, 3]) + @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32]) + @pytest.mark.parametrize("batch_size", [1, 2, 4]) + def test_round_trip_batched(self, num_channels, dtype, batch_size): + # Setup: Create a batched tensor in NCHW format + if dtype == torch.uint8: + original_tensor = torch.randint(0, 256, (batch_size, num_channels, 4, 4), dtype=dtype, device="cuda") + else: + original_tensor = torch.rand(batch_size, num_channels, 4, 4, dtype=dtype, device="cuda") + + # Execute: Convert to NVCV and back to tensor + # NCHW -> (to_nvcv_tensor) -> NVCV NHWC -> (nvcv_to_tensor) -> NCHW + nvcv_tensor = F.to_nvcv_tensor(original_tensor) + result_tensor = F.nvcv_to_tensor(nvcv_tensor) + + # Assert: The round-trip conversion preserves the original batched tensor + # Use allclose for robust comparison that handles floating-point precision + assert torch.allclose(result_tensor, original_tensor, rtol=1e-5, atol=1e-7) + # Also verify batch size is preserved + assert result_tensor.shape[0] == batch_size + + +@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA") +class TestNVCVToTensor: + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) + @pytest.mark.parametrize( + "fn", + [F.nvcv_to_tensor, transform_cls_to_functional(transforms.NVCVToTensor)], + ) + def test_functional_and_transform(self, color_space, fn): + input = make_image_nvcv(color_space=color_space) + + output = fn(input) + + assert isinstance(output, torch.Tensor) + # Convert input to tensor to compare sizes + input_tensor = F.nvcv_to_tensor(input) + assert F.get_size(output) == F.get_size(input_tensor) + + def test_functional_error(self): + with pytest.raises(TypeError, match="nvcv_img should be `nvcv.Tensor`"): + F.nvcv_to_tensor(object()) + + class TestLambda: @pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0]) @pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)]) diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 5d06156c25f..a0804d6ca8c 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -99,6 +99,16 @@ def _is_tracing(): return torch._C._get_tracing_state() +def _is_cvcuda_available() -> bool: + try: + import cvcuda # type: ignore[import-not-found] + import nvcv # type: ignore[import-not-found] + + return True + except ImportError: + return False + + def disable_beta_transforms_warning(): # Noop, only exists to avoid breaking existing code. # See https://github.com/pytorch/vision/issues/7896 diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 895bf6e2f71..3797ec92988 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -55,7 +55,7 @@ ToDtype, ) from ._temporal import UniformTemporalSubsample -from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor +from ._type_conversion import NVCVToTensor, PILToTensor, ToImage, ToNVCVTensor, ToPILImage, ToPureTensor from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size from ._deprecated import ToTensor # usort: skip diff --git a/torchvision/transforms/v2/_type_conversion.py b/torchvision/transforms/v2/_type_conversion.py index 7cac62868b9..ccd50741c61 100644 --- a/torchvision/transforms/v2/_type_conversion.py +++ b/torchvision/transforms/v2/_type_conversion.py @@ -6,8 +6,8 @@ from torchvision import tv_tensors from torchvision.transforms.v2 import functional as F, Transform - from torchvision.transforms.v2._utils import is_pure_tensor +from torchvision.utils import _log_api_usage_once class PILToTensor(Transform): @@ -90,3 +90,69 @@ class ToPureTensor(Transform): def transform(self, inpt: Any, params: dict[str, Any]) -> torch.Tensor: return inpt.as_subclass(torch.Tensor) + + +class ToNVCVTensor: + """Convert a torch.Tensor to nvcv.Tensor + + This transform does not support torchscript. + + Converts a torch.*Tensor of shape C x H x W to a nvcv.Tensor. + Only 1-channel and 3-channel images are supported. + """ + + def __init__(self): + _log_api_usage_once(self) + + def __call__(self, pic): + """ + Args: + pic (torch.Tensor): Image to be converted to nvcv.Tensor. + + Returns: + nvcv.Tensor: Image converted to nvcv.Tensor. + + """ + return F.to_nvcv_tensor(pic) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class NVCVToTensor: + """Convert a `nvcv.Tensor` to a `torch.Tensor` of the same type - this does not scale values. + + This transform does not support torchscript. + + Converts a `nvcv.Tensor` to a `torch.Tensor`. Supports both batched and unbatched inputs: + - Unbatched: (H, W, C) or (H, W) → (C, H, W) or (1, H, W) + - Batched: (N, H, W, C) or (N, H, W) → (N, C, H, W) or (N, 1, H, W) + + The conversion happens directly on GPU when the `nvcv.Tensor` is stored on GPU, + avoiding unnecessary data transfers. + + Example: + >>> import nvcv + >>> import torchvision.transforms.v2 as T + >>> # Create an NVCV Image (320x240 RGB) + >>> nvcv_img = nvcv.Image(nvcv.Size2D(320, 240), nvcv.Format.RGB8) + >>> tensor = T.NVCVToTensor()(nvcv_img) + >>> print(tensor.shape) + torch.Size([3, 240, 320]) + """ + + def __init__(self) -> None: + _log_api_usage_once(self) + + def __call__(self, pic): + """ + Args: + pic (nvcv.Image): NVCV Image to be converted to tensor. + + Returns: + Tensor: Converted image in CHW format. + """ + return F.nvcv_to_tensor(pic) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 13fbaa588fe..9cf7a27f87f 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -162,6 +162,6 @@ to_dtype_video, ) from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video -from ._type_conversion import pil_to_tensor, to_image, to_pil_image +from ._type_conversion import nvcv_to_tensor, pil_to_tensor, to_image, to_nvcv_tensor, to_pil_image from ._deprecated import get_image_size, to_tensor # usort: skip diff --git a/torchvision/transforms/v2/functional/_type_conversion.py b/torchvision/transforms/v2/functional/_type_conversion.py index c5a731fe143..9e21ef42a97 100644 --- a/torchvision/transforms/v2/functional/_type_conversion.py +++ b/torchvision/transforms/v2/functional/_type_conversion.py @@ -1,10 +1,14 @@ -from typing import Union +from typing import TYPE_CHECKING, Union import numpy as np import PIL.Image import torch from torchvision import tv_tensors from torchvision.transforms import functional as _F +from torchvision.utils import _log_api_usage_once + +if TYPE_CHECKING: + import nvcv # type: ignore[import-not-found] @torch.jit.unused @@ -25,3 +29,176 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tenso to_pil_image = _F.to_pil_image pil_to_tensor = _F.pil_to_tensor + + +def _infer_nvcv_format(img_tensor: torch.Tensor): + """Infer NVCV format from tensor shape and dtype. + + Args: + img_tensor: Tensor with shape (H, W, C) where C is number of channels. + + Returns: + tuple: (nvcv_format, processed_tensor) where processed_tensor may have reduced dimensions + for single channel images. + + Raises: + TypeError: If dtype is not supported for the given number of channels. + ValueError: If number of channels is not 1 or 3. + """ + import nvcv # type: ignore[import-not-found] + + num_channels = img_tensor.shape[2] + dtype = img_tensor.dtype + + # Handle single channel images + if num_channels == 1: + img_tensor = img_tensor[:, :, 0] + if dtype == torch.uint8: + return nvcv.Format.U8, img_tensor + elif dtype == torch.int16: + return nvcv.Format.S16, img_tensor + elif dtype == torch.int32: + return nvcv.Format.S32, img_tensor + elif dtype == torch.float32: + return nvcv.Format.F32, img_tensor + else: + raise TypeError(f"Unsupported dtype {dtype} for single channel image") + + # Handle 3 channel images (defaults to RGB) + elif num_channels == 3: + if dtype == torch.uint8: + return nvcv.Format.RGB8, img_tensor + elif dtype == torch.float32: + return nvcv.Format.RGBf32, img_tensor + else: + raise TypeError(f"Unsupported dtype {dtype} for 3-channel image") + + raise ValueError(f"Only 1 and 3 channel images are supported. Got {num_channels} channels.") + + +@torch.jit.unused +def to_nvcv_tensor(pic) -> "nvcv.Tensor": + """Convert a torch.Tensor to nvcv.Tensor. This function does not support torchscript. + + See :class:`~torchvision.transforms.v2.ToNVCVTensor` for more details. + + Args: + pic (torch.Tensor): Image to be converted to nvcv.Tensor. + Tensor can be in CHW format (unbatched) or NCHW format (batched). + Only 1-channel and 3-channel images are supported. + + Returns: + nvcv.Tensor: Image converted to nvcv.Tensor with NHWC layout. + """ + import cvcuda # type: ignore[import-not-found] + import nvcv # type: ignore[import-not-found] + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(to_nvcv_tensor) + + # Validate input type + if not isinstance(pic, torch.Tensor): + raise TypeError(f"pic should be `torch.Tensor`. Got {type(pic)}.") + + # Handle different tensor formats and track if input was batched (NCHW) or unbatched (CHW/HW) + if pic.ndim == 4: + # Batched tensor in NCHW format, permute to NHWC + img_tensor = pic.permute(0, 2, 3, 1) + input_was_batched = True + elif pic.ndim == 3: + # Unbatched tensor in CHW format, permute to HWC + img_tensor = pic.permute(1, 2, 0) + input_was_batched = False + else: + # 2D or other formats (unbatched single-channel) + img_tensor = pic + input_was_batched = False + + # Ensure image has channel dimension for unbatched case + if img_tensor.ndim == 2: + img_tensor = img_tensor.unsqueeze(2) # H W -> H W C + + # Validate dimensions + if img_tensor.ndim not in (3, 4): + raise ValueError(f"pic should be 2/3/4 dimensional. Got {img_tensor.ndim} dimensions.") + + # For batched inputs, use the first image to infer format + sample_img = img_tensor[0] if img_tensor.ndim == 4 else img_tensor + _, sample_img = _infer_nvcv_format(sample_img) + + # If format inference modified the tensor (e.g., removed channel dimension for single channel) + # apply the same transformation to all images + if sample_img.ndim == 2 and img_tensor.ndim == 4: + # Batched single channel case: remove channel dimension + img_tensor = img_tensor.squeeze(-1) + elif sample_img.ndim == 2 and img_tensor.ndim == 3: + # Unbatched single channel case: replace with 2D tensor + img_tensor = sample_img + + # Add batch dimension if not present (NVCV expects batched tensors) + if not input_was_batched: + img_tensor = img_tensor.unsqueeze(0) # Add batch dimension at index 0 + + # Determine layout based on final tensor shape + # After all transformations, tensor is either NHW (single-channel) or NHWC (multi-channel) + if img_tensor.ndim == 3: + layout = nvcv.TensorLayout.NHW # Batched single-channel + else: # img_tensor.ndim == 4 + layout = nvcv.TensorLayout.NHWC # Batched multi-channel + + # Convert to NVCV tensor with the appropriate layout + return cvcuda.as_tensor(img_tensor.cuda().contiguous(), layout) + + +@torch.jit.unused +def nvcv_to_tensor(nvcv_img: "nvcv.Tensor") -> torch.Tensor: + """Convert a nvcv.Tensor to a PyTorch tensor. This function does not support torchscript. + + Args: + nvcv_img (nvcv.Tensor): nvcv.Tensor to be converted to PyTorch tensor. + Expected to be in NHWC or NHW layout (for batched images) or HWC or HW layout (for unbatched). + + Returns: + torch.Tensor: Converted image in CHW format (unbatched) or NCHW format (batched). + """ + import nvcv # type: ignore[import-not-found] + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(nvcv_to_tensor) + + # Validate input type + if not isinstance(nvcv_img, nvcv.Tensor): + raise TypeError(f"nvcv_img should be `nvcv.Tensor`. Got {type(nvcv_img)}.") + + # Convert NVCV Tensor to PyTorch tensor via CUDA array interface + # NVCV tensors expose __cuda_array_interface__ which PyTorch can consume directly + cuda_tensor = torch.as_tensor(nvcv_img.cuda(), device="cuda") + + # Handle different dimensionalities + # NVCV stores images in NHWC (batched multi-channel), NHW (batched single-channel), + # HWC (unbatched multi-channel), or HW (unbatched single-channel) format + if cuda_tensor.ndim == 4: + # Batched multi-channel image in NHWC format + # Convert NHWC -> NCHW + img = cuda_tensor.permute(0, 3, 1, 2).contiguous() + elif cuda_tensor.ndim == 3: + # Could be either: + # 1. Unbatched multi-channel (HWC) - last dim is 1 or 3 + # 2. Batched single-channel (NHW) - last dim is width + # We distinguish by checking if last dimension is 1 or 3 (our supported channel counts) + if cuda_tensor.shape[2] in (1, 3): + # Unbatched multi-channel image in HWC format + # Convert HWC -> CHW + img = cuda_tensor.permute(2, 0, 1).contiguous() + else: + # Batched single-channel image in NHW format + # Convert NHW -> NCHW by adding channel dimension + img = cuda_tensor.unsqueeze(1).contiguous() + elif cuda_tensor.ndim == 2: + # Unbatched single-channel image in HW format + # Convert HW -> CHW by adding channel dimension + img = cuda_tensor.unsqueeze(0).contiguous() + else: + raise ValueError(f"Image should be 2/3/4 dimensional. Got {cuda_tensor.ndim} dimensions.") + + return img