Skip to content

Commit d63504c

Browse files
Dan-Flores and Daniel Flores authored
Create Python API for VideoEncoder (#990)
Co-authored-by: Daniel Flores <[email protected]>
1 parent 44ae3d5 commit d63504c

File tree

5 files changed

+207
-65
lines changed

5 files changed

+207
-65
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,6 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
531531
frames.sizes()[1] == 3,
532532
"frame must have 3 channels (R, G, B), got ",
533533
frames.sizes()[1]);
534-
// TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted
535534
return frames.contiguous();
536535
}
537536

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from ._audio_encoder import AudioEncoder # noqa
2+
from ._video_encoder import VideoEncoder # noqa
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from pathlib import Path
2+
from typing import Union
3+
4+
import torch
5+
from torch import Tensor
6+
7+
from torchcodec import _core
8+
9+
10+
class VideoEncoder:
11+
"""A video encoder.
12+
13+
Args:
14+
frames (``torch.Tensor``): The frames to encode. This must be a 4D
15+
tensor of shape ``(N, C, H, W)`` where N is the number of frames,
16+
C is 3 channels (RGB), H is height, and W is width.
17+
Values must be uint8 in the range ``[0, 255]``.
18+
frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
19+
"""
20+
21+
def __init__(self, frames: Tensor, *, frame_rate: int):
22+
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
23+
if not isinstance(frames, Tensor):
24+
raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.")
25+
if frames.ndim != 4:
26+
raise ValueError(f"Expected 4D frames, got {frames.shape = }.")
27+
if frames.dtype != torch.uint8:
28+
raise ValueError(f"Expected uint8 frames, got {frames.dtype = }.")
29+
if frame_rate <= 0:
30+
raise ValueError(f"{frame_rate = } must be > 0.")
31+
32+
self._frames = frames
33+
self._frame_rate = frame_rate
34+
35+
def to_file(
36+
self,
37+
dest: Union[str, Path],
38+
) -> None:
39+
"""Encode frames into a file.
40+
41+
Args:
42+
dest (str or ``pathlib.Path``): The path to the output file, e.g.
43+
``video.mp4``. The extension of the file determines the video
44+
container format.
45+
"""
46+
_core.encode_video_to_file(
47+
frames=self._frames,
48+
frame_rate=self._frame_rate,
49+
filename=str(dest),
50+
)
51+
52+
def to_tensor(
53+
self,
54+
format: str,
55+
) -> Tensor:
56+
"""Encode frames into raw bytes, as a 1D uint8 Tensor.
57+
58+
Args:
59+
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
60+
"mkv", "avi", "webm", "flv", or "gif".
61+
62+
Returns:
63+
Tensor: The raw encoded bytes as a 1D uint8 Tensor.
64+
"""
65+
return _core.encode_video_to_tensor(
66+
frames=self._frames,
67+
frame_rate=self._frame_rate,
68+
format=format,
69+
)
70+
71+
def to_file_like(
72+
self,
73+
file_like,
74+
format: str,
75+
) -> None:
76+
"""Encode frames into a file-like object.
77+
78+
Args:
79+
file_like: A file-like object that supports ``write()`` and
80+
``seek()`` methods, such as io.BytesIO(), an open file in binary
81+
write mode, etc. Methods must have the following signature:
82+
``write(data: bytes) -> int`` and ``seek(offset: int, whence:
83+
int = 0) -> int``.
84+
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
85+
"mkv", "avi", "webm", "flv", or "gif".
86+
"""
87+
_core.encode_video_to_file_like(
88+
frames=self._frames,
89+
frame_rate=self._frame_rate,
90+
format=format,
91+
file_like=file_like,
92+
)

test/test_encoders.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from torchcodec.decoders import AudioDecoder
1313

14-
from torchcodec.encoders import AudioEncoder
14+
from torchcodec.encoders import AudioEncoder, VideoEncoder
1515

1616
from .utils import (
1717
assert_tensor_close_on_at_least,
@@ -564,3 +564,115 @@ def write(self, data):
564564
RuntimeError, match="File like object must implement a seek method"
565565
):
566566
encoder.to_file_like(NoSeekMethod(), format="wav")
567+
568+
569+
class TestVideoEncoder:
570+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
571+
def test_bad_input_parameterized(self, tmp_path, method):
572+
if method == "to_file":
573+
valid_params = dict(dest=str(tmp_path / "output.mp4"))
574+
elif method == "to_tensor":
575+
valid_params = dict(format="mp4")
576+
elif method == "to_file_like":
577+
valid_params = dict(file_like=io.BytesIO(), format="mp4")
578+
else:
579+
raise ValueError(f"Unknown method: {method}")
580+
581+
with pytest.raises(
582+
ValueError, match="Expected uint8 frames, got frames.dtype = torch.float32"
583+
):
584+
encoder = VideoEncoder(
585+
frames=torch.rand(5, 3, 64, 64),
586+
frame_rate=30,
587+
)
588+
getattr(encoder, method)(**valid_params)
589+
590+
with pytest.raises(
591+
ValueError, match=r"Expected 4D frames, got frames.shape = torch.Size"
592+
):
593+
encoder = VideoEncoder(
594+
frames=torch.zeros(10),
595+
frame_rate=30,
596+
)
597+
getattr(encoder, method)(**valid_params)
598+
599+
with pytest.raises(
600+
RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2"
601+
):
602+
encoder = VideoEncoder(
603+
frames=torch.zeros((5, 2, 64, 64), dtype=torch.uint8),
604+
frame_rate=30,
605+
)
606+
getattr(encoder, method)(**valid_params)
607+
608+
def test_bad_input(self, tmp_path):
609+
encoder = VideoEncoder(
610+
frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),
611+
frame_rate=30,
612+
)
613+
614+
with pytest.raises(
615+
RuntimeError,
616+
match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?",
617+
):
618+
encoder.to_file("./file.bad_extension")
619+
620+
with pytest.raises(
621+
RuntimeError,
622+
match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?",
623+
):
624+
encoder.to_file("./bad/path.mp3")
625+
626+
with pytest.raises(
627+
RuntimeError,
628+
match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format",
629+
):
630+
encoder.to_tensor(format="bad_format")
631+
632+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
633+
def test_contiguity(self, method, tmp_path):
634+
# Ensure that 2 sets of video frames with the same pixel values are encoded
635+
# in the same way, regardless of their memory layout. Here we encode 2 equal
636+
# frame tensors, one is contiguous while the other is non-contiguous.
637+
638+
num_frames, channels, height, width = 5, 3, 64, 64
639+
contiguous_frames = torch.randint(
640+
0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8
641+
).contiguous()
642+
assert contiguous_frames.is_contiguous()
643+
644+
# Permute NCHW to NHWC, then update the memory layout, then permute back
645+
non_contiguous_frames = (
646+
contiguous_frames.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2)
647+
)
648+
assert non_contiguous_frames.stride() != contiguous_frames.stride()
649+
assert not non_contiguous_frames.is_contiguous()
650+
assert non_contiguous_frames.is_contiguous(memory_format=torch.channels_last)
651+
652+
torch.testing.assert_close(
653+
contiguous_frames, non_contiguous_frames, rtol=0, atol=0
654+
)
655+
656+
def encode_to_tensor(frames):
657+
if method == "to_file":
658+
dest = str(tmp_path / "output.mp4")
659+
VideoEncoder(frames, frame_rate=30).to_file(dest=dest)
660+
with open(dest, "rb") as f:
661+
return torch.frombuffer(f.read(), dtype=torch.uint8)
662+
elif method == "to_tensor":
663+
return VideoEncoder(frames, frame_rate=30).to_tensor(format="mp4")
664+
elif method == "to_file_like":
665+
file_like = io.BytesIO()
666+
VideoEncoder(frames, frame_rate=30).to_file_like(
667+
file_like, format="mp4"
668+
)
669+
return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8)
670+
else:
671+
raise ValueError(f"Unknown method: {method}")
672+
673+
encoded_from_contiguous = encode_to_tensor(contiguous_frames)
674+
encoded_from_non_contiguous = encode_to_tensor(non_contiguous_frames)
675+
676+
torch.testing.assert_close(
677+
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
678+
)

test/test_ops.py

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,68 +1152,6 @@ def test_bad_input(self, tmp_path):
11521152

11531153

11541154
class TestVideoEncoderOps:
1155-
# TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity)
1156-
# TODO-VideoEncoder: Parametrize test after moving to test_encoders
1157-
def test_bad_input(self, tmp_path):
1158-
output_file = str(tmp_path / ".mp4")
1159-
1160-
with pytest.raises(
1161-
RuntimeError, match="frames must have uint8 dtype, got float"
1162-
):
1163-
encode_video_to_file(
1164-
frames=torch.rand((10, 3, 60, 60), dtype=torch.float),
1165-
frame_rate=10,
1166-
filename=output_file,
1167-
)
1168-
1169-
with pytest.raises(
1170-
RuntimeError, match=r"frames must have 4 dimensions \(N, C, H, W\), got 3"
1171-
):
1172-
encode_video_to_file(
1173-
frames=torch.randint(high=1, size=(3, 60, 60), dtype=torch.uint8),
1174-
frame_rate=10,
1175-
filename=output_file,
1176-
)
1177-
1178-
with pytest.raises(
1179-
RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2"
1180-
):
1181-
encode_video_to_file(
1182-
frames=torch.randint(high=1, size=(10, 2, 60, 60), dtype=torch.uint8),
1183-
frame_rate=10,
1184-
filename=output_file,
1185-
)
1186-
1187-
with pytest.raises(
1188-
RuntimeError,
1189-
match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?",
1190-
):
1191-
encode_video_to_file(
1192-
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
1193-
frame_rate=10,
1194-
filename="./file.bad_extension",
1195-
)
1196-
1197-
with pytest.raises(
1198-
RuntimeError,
1199-
match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?",
1200-
):
1201-
encode_video_to_file(
1202-
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
1203-
frame_rate=10,
1204-
filename="./bad/path.mp3",
1205-
)
1206-
1207-
with pytest.raises(
1208-
RuntimeError,
1209-
match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format",
1210-
):
1211-
encode_video_to_tensor(
1212-
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
1213-
frame_rate=10,
1214-
format="bad_format",
1215-
)
1216-
12171155
def decode(self, source=None) -> torch.Tensor:
12181156
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)
12191157

@@ -1406,7 +1344,7 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
14061344
)
14071345

14081346
def test_to_file_like_custom_file_object(self):
1409-
"""Test with a custom file-like object that implements write and seek."""
1347+
"""Test to_file_like with a custom file-like object that implements write and seek."""
14101348

14111349
class CustomFileObject:
14121350
def __init__(self):

0 commit comments

Comments
 (0)