diff --git a/src/torchcodec/_core/AVIOTensorContext.cpp b/src/torchcodec/_core/AVIOTensorContext.cpp index 3f45f5be5..238475761 100644 --- a/src/torchcodec/_core/AVIOTensorContext.cpp +++ b/src/torchcodec/_core/AVIOTensorContext.cpp @@ -18,15 +18,15 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB int read(void* opaque, uint8_t* buf, int buf_size) { auto tensorContext = static_cast(opaque); TORCH_CHECK( - tensorContext->current <= tensorContext->data.numel(), - "Tried to read outside of the buffer: current=", - tensorContext->current, + tensorContext->current_pos <= tensorContext->data.numel(), + "Tried to read outside of the buffer: current_pos=", + tensorContext->current_pos, ", size=", tensorContext->data.numel()); int64_t numBytesRead = std::min( static_cast(buf_size), - tensorContext->data.numel() - tensorContext->current); + tensorContext->data.numel() - tensorContext->current_pos); TORCH_CHECK( numBytesRead >= 0, @@ -34,8 +34,8 @@ int read(void* opaque, uint8_t* buf, int buf_size) { numBytesRead, ", size=", tensorContext->data.numel(), - ", current=", - tensorContext->current); + ", current_pos=", + tensorContext->current_pos); if (numBytesRead == 0) { return AVERROR_EOF; @@ -43,9 +43,9 @@ int read(void* opaque, uint8_t* buf, int buf_size) { std::memcpy( buf, - tensorContext->data.data_ptr() + tensorContext->current, + tensorContext->data.data_ptr() + tensorContext->current_pos, numBytesRead); - tensorContext->current += numBytesRead; + tensorContext->current_pos += numBytesRead; return numBytesRead; } @@ -54,7 +54,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) { auto tensorContext = static_cast(opaque); int64_t bufSize = static_cast(buf_size); - if (tensorContext->current + bufSize > tensorContext->data.numel()) { + if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) { TORCH_CHECK( tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE, "We tried to allocate an output encoded tensor larger than ", @@ -68,13 +68,17 @@ int write(void* opaque, const uint8_t* buf, int buf_size) { } TORCH_CHECK( - tensorContext->current + bufSize <= tensorContext->data.numel(), + tensorContext->current_pos + bufSize <= tensorContext->data.numel(), "Re-allocation of the output tensor didn't work. ", "This should not happen, please report on TorchCodec bug tracker"); uint8_t* outputTensorData = tensorContext->data.data_ptr(); - std::memcpy(outputTensorData + tensorContext->current, buf, bufSize); - tensorContext->current += bufSize; + std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize); + tensorContext->current_pos += bufSize; + // Track the maximum position written so getOutputTensor's narrow() does not + // truncate the file if final seek was backwards + tensorContext->max_pos = + std::max(tensorContext->current_pos, tensorContext->max_pos); return buf_size; } @@ -88,7 +92,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) { ret = tensorContext->data.numel(); break; case SEEK_SET: - tensorContext->current = offset; + tensorContext->current_pos = offset; ret = offset; break; default: @@ -101,7 +105,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) { } // namespace AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data) - : tensorContext_{data, 0} { + : tensorContext_{data, 0, 0} { TORCH_CHECK(data.numel() > 0, "data must not be empty"); TORCH_CHECK(data.is_contiguous(), "data must be contiguous"); TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8"); @@ -110,14 +114,17 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data) } AVIOToTensorContext::AVIOToTensorContext() - : tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} { + : tensorContext_{ + torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), + 0, + 0} { createAVIOContext( nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true); } torch::Tensor AVIOToTensorContext::getOutputTensor() { return tensorContext_.data.narrow( - /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.current); + /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOTensorContext.h b/src/torchcodec/_core/AVIOTensorContext.h index 15f97da55..bcd97052b 100644 --- a/src/torchcodec/_core/AVIOTensorContext.h +++ b/src/torchcodec/_core/AVIOTensorContext.h @@ -15,7 +15,8 @@ namespace detail { struct TensorContext { torch::Tensor data; - int64_t current; + int64_t current_pos; + int64_t max_pos; }; } // namespace detail diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 14ef1cb94..1d9c2c089 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -4,10 +4,6 @@ #include "src/torchcodec/_core/Encoder.h" #include "torch/types.h" -extern "C" { -#include -} - namespace facebook::torchcodec { namespace { @@ -542,10 +538,17 @@ torch::Tensor validateFrames(const torch::Tensor& frames) { } // namespace VideoEncoder::~VideoEncoder() { + // TODO-VideoEncoder: Unify destructor with ~AudioEncoder() if (avFormatContext_ && avFormatContext_->pb) { - avio_flush(avFormatContext_->pb); - avio_close(avFormatContext_->pb); - avFormatContext_->pb = nullptr; + if (avFormatContext_->pb->error == 0) { + avio_flush(avFormatContext_->pb); + } + if (!avioContextHolder_) { + if (avFormatContext_->pb->error == 0) { + avio_close(avFormatContext_->pb); + } + avFormatContext_->pb = nullptr; + } } } @@ -581,6 +584,36 @@ VideoEncoder::VideoEncoder( initializeEncoder(videoStreamOptions); } +VideoEncoder::VideoEncoder( + const torch::Tensor& frames, + int frameRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + const VideoStreamOptions& videoStreamOptions) + : frames_(validateFrames(frames)), + inFrameRate_(frameRate), + avioContextHolder_(std::move(avioContextHolder)) { + setFFmpegLogLevel(); + // Map mkv -> matroska when used as format name + formatName = (formatName == "mkv") ? "matroska" : formatName; + AVFormatContext* avFormatContext = nullptr; + int status = avformat_alloc_output_context2( + &avFormatContext, nullptr, formatName.data(), nullptr); + + TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext. ", + "Check the desired format? Got format=", + formatName, + ". ", + getFFMPEGErrorStringFromErrorCode(status)); + avFormatContext_.reset(avFormatContext); + + avFormatContext_->pb = avioContextHolder_->getAVIOContext(); + + initializeEncoder(videoStreamOptions); +} + void VideoEncoder::initializeEncoder( const VideoStreamOptions& videoStreamOptions) { const AVCodec* avCodec = @@ -751,6 +784,17 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame( return avFrame; } +torch::Tensor VideoEncoder::encodeToTensor() { + TORCH_CHECK( + avioContextHolder_ != nullptr, + "Cannot encode to tensor, avio tensor context doesn't exist."); + encode(); + auto avioToTensorContext = + dynamic_cast(avioContextHolder_.get()); + TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder."); + return avioToTensorContext->getOutputTensor(); +} + void VideoEncoder::encodeFrame( AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame) { diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 62d30a624..7aff0bdbc 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -141,8 +141,17 @@ class VideoEncoder { std::string_view fileName, const VideoStreamOptions& videoStreamOptions); + VideoEncoder( + const torch::Tensor& frames, + int frameRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + const VideoStreamOptions& videoStreamOptions); + void encode(); + torch::Tensor encodeToTensor(); + private: void initializeEncoder(const VideoStreamOptions& videoStreamOptions); UniqueAVFrame convertTensorToAVFrame( @@ -167,6 +176,8 @@ class VideoEncoder { int outHeight_ = -1; AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE; + std::unique_ptr avioContextHolder_; + bool encodeWasCalled_ = false; }; diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 24e54af0e..eb8dd9697 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -26,6 +26,7 @@ encode_audio_to_file_like, encode_audio_to_tensor, encode_video_to_file, + encode_video_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index f29f33395..94a3fba1b 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -32,12 +32,14 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def( "encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); - m.def( - "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()"); m.def( "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor"); m.def( "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); + m.def( + "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()"); + m.def( + "encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( @@ -498,21 +500,6 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio( return makeOpsAudioFramesOutput(result); } -void encode_video_to_file( - const at::Tensor& frames, - int64_t frame_rate, - std::string_view file_name, - std::optional crf = std::nullopt) { - VideoStreamOptions videoStreamOptions; - videoStreamOptions.crf = crf; - VideoEncoder( - frames, - validateInt64ToInt(frame_rate, "frame_rate"), - file_name, - videoStreamOptions) - .encode(); -} - void encode_audio_to_file( const at::Tensor& samples, int64_t sample_rate, @@ -587,6 +574,38 @@ void _encode_audio_to_file_like( encoder.encode(); } +void encode_video_to_file( + const at::Tensor& frames, + int64_t frame_rate, + std::string_view file_name, + std::optional crf = std::nullopt) { + VideoStreamOptions videoStreamOptions; + videoStreamOptions.crf = crf; + VideoEncoder( + frames, + validateInt64ToInt(frame_rate, "frame_rate"), + file_name, + videoStreamOptions) + .encode(); +} + +at::Tensor encode_video_to_tensor( + const at::Tensor& frames, + int64_t frame_rate, + std::string_view format, + std::optional crf = std::nullopt) { + auto avioContextHolder = std::make_unique(); + VideoStreamOptions videoStreamOptions; + videoStreamOptions.crf = crf; + return VideoEncoder( + frames, + validateInt64ToInt(frame_rate, "frame_rate"), + format, + std::move(avioContextHolder), + videoStreamOptions) + .encodeToTensor(); +} + // For testing only. We need to implement this operation as a core library // function because what we're testing is round-tripping pts values as // double-precision floating point numbers from C++ to Python and back to C++. @@ -847,9 +866,10 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("encode_audio_to_file", &encode_audio_to_file); - m.impl("encode_video_to_file", &encode_video_to_file); m.impl("encode_audio_to_tensor", &encode_audio_to_tensor); m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like); + m.impl("encode_video_to_file", &encode_video_to_file); + m.impl("encode_video_to_tensor", &encode_video_to_tensor); m.impl("seek_to_pts", &seek_to_pts); m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 6fc30e5af..03cf8cf6d 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -92,15 +92,18 @@ def load_torchcodec_shared_libraries(): encode_audio_to_file = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.encode_audio_to_file.default ) -encode_video_to_file = torch._dynamo.disallow_in_graph( - torch.ops.torchcodec_ns.encode_video_to_file.default -) encode_audio_to_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.encode_audio_to_tensor.default ) _encode_audio_to_file_like = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns._encode_audio_to_file_like.default ) +encode_video_to_file = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.encode_video_to_file.default +) +encode_video_to_tensor = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.encode_video_to_tensor.default +) create_from_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_tensor.default ) @@ -254,16 +257,6 @@ def encode_audio_to_file_abstract( return -@register_fake("torchcodec_ns::encode_video_to_file") -def encode_video_to_file_abstract( - frames: torch.Tensor, - frame_rate: int, - filename: str, - crf: Optional[int] = None, -) -> None: - return - - @register_fake("torchcodec_ns::encode_audio_to_tensor") def encode_audio_to_tensor_abstract( samples: torch.Tensor, @@ -289,6 +282,26 @@ def _encode_audio_to_file_like_abstract( return +@register_fake("torchcodec_ns::encode_video_to_file") +def encode_video_to_file_abstract( + frames: torch.Tensor, + frame_rate: int, + filename: str, + crf: Optional[int], +) -> None: + return + + +@register_fake("torchcodec_ns::encode_video_to_tensor") +def encode_video_to_tensor_abstract( + frames: torch.Tensor, + frame_rate: int, + format: str, + crf: Optional[int], +) -> torch.Tensor: + return torch.empty([], dtype=torch.long) + + @register_fake("torchcodec_ns::create_from_tensor") def create_from_tensor_abstract( video_tensor: torch.Tensor, seek_mode: Optional[str] diff --git a/test/test_ops.py b/test/test_ops.py index 0c1d90cfc..31afbdd14 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -29,6 +29,7 @@ create_from_tensor, encode_audio_to_file, encode_video_to_file, + encode_video_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, @@ -41,6 +42,7 @@ get_next_frame, seek_to_pts, ) +from torchcodec.decoders import VideoDecoder from .utils import ( all_supported_devices, @@ -1328,6 +1330,7 @@ def test_bad_input(self, tmp_path): class TestVideoEncoderOps: + # TODO-VideoEncoder: Parametrize test after moving to test_encoders def test_bad_input(self, tmp_path): output_file = str(tmp_path / ".mp4") @@ -1378,17 +1381,25 @@ def test_bad_input(self, tmp_path): filename="./bad/path.mp3", ) - def decode(self, file_path) -> torch.Tensor: - decoder = create_from_file(str(file_path), seek_mode="approximate") - add_video_stream(decoder) - frames, *_ = get_frames_in_range(decoder, start=0, stop=60) - return frames + with pytest.raises( + RuntimeError, + match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format", + ): + encode_video_to_tensor( + frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8), + frame_rate=10, + format="bad_format", + ) + + def decode(self, source=None) -> torch.Tensor: + return VideoDecoder(source).get_frames_in_range(start=0, stop=60) @pytest.mark.parametrize( "format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow)) ) - def test_video_encoder_round_trip(self, tmp_path, format): - # Test that decode(encode(decode(asset))) == decode(asset) + @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + def test_video_encoder_round_trip(self, tmp_path, format, method): + # Test that decode(encode(decode(frames))) == decode(frames) ffmpeg_version = get_ffmpeg_major_version() # In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm. # As a result, we skip the round trip test. @@ -1400,15 +1411,25 @@ def test_video_encoder_round_trip(self, tmp_path, format): ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) ): pytest.skip("Codec for webm is not available in this FFmpeg installation.") - asset = TEST_SRC_2_720P - source_frames = self.decode(str(asset.path)).data + source_frames = self.decode(TEST_SRC_2_720P.path).data + + params = dict( + frame_rate=30, crf=0 + ) # Frame rate is fixed with num frames decoded + if method == "to_file": + encoded_path = str(tmp_path / f"encoder_output.{format}") + encode_video_to_file( + frames=source_frames, + filename=encoded_path, + **params, + ) + round_trip_frames = self.decode(encoded_path).data + else: # to_tensor + encoded_tensor = encode_video_to_tensor( + source_frames, format=format, **params + ) + round_trip_frames = self.decode(encoded_tensor).data - encoded_path = str(tmp_path / f"encoder_output.{format}") - frame_rate = 30 # Frame rate is fixed with num frames decoded - encode_video_to_file( - frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0 - ) - round_trip_frames = self.decode(encoded_path).data assert source_frames.shape == round_trip_frames.shape assert source_frames.dtype == round_trip_frames.dtype @@ -1424,6 +1445,40 @@ def test_video_encoder_round_trip(self, tmp_path, format): assert psnr(s_frame, rt_frame) > 30 assert_close(s_frame, rt_frame, atol=atol, rtol=0) + @pytest.mark.parametrize( + "format", + ( + "mov", + "mp4", + "avi", + "mkv", + "flv", + "gif", + pytest.param("webm", marks=pytest.mark.slow), + ), + ) + def test_against_to_file(self, tmp_path, format): + # Test that to_file and to_tensor produce the same results + ffmpeg_version = get_ffmpeg_major_version() + if format == "webm" and ( + ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) + ): + pytest.skip("Codec for webm is not available in this FFmpeg installation.") + + source_frames = self.decode(TEST_SRC_2_720P.path).data + params = dict(frame_rate=30, crf=0) + + encoded_file = tmp_path / f"output.{format}" + encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params) + encoded_tensor = encode_video_to_tensor(source_frames, format=format, **params) + + torch.testing.assert_close( + self.decode(encoded_file).data, + self.decode(encoded_tensor).data, + atol=0, + rtol=0, + ) + @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") @pytest.mark.parametrize( "format", @@ -1439,18 +1494,12 @@ def test_video_encoder_round_trip(self, tmp_path, format): ) def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): ffmpeg_version = get_ffmpeg_major_version() - if format == "webm": - if ffmpeg_version == 4: - pytest.skip( - "Codec for webm is not available in the FFmpeg4 installation." - ) - if IS_WINDOWS and ffmpeg_version in (6, 7): - pytest.skip( - "Codec for webm is not available in the FFmpeg6/7 installation on Windows." - ) - asset = TEST_SRC_2_720P - source_frames = self.decode(str(asset.path)).data - frame_rate = 30 + if format == "webm" and ( + ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) + ): + pytest.skip("Codec for webm is not available in this FFmpeg installation.") + + source_frames = self.decode(TEST_SRC_2_720P.path).data # Encode with FFmpeg CLI temp_raw_path = str(tmp_path / "temp_input.raw") @@ -1458,8 +1507,8 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") + frame_rate = 30 crf = 0 - quality_params = ["-crf", str(crf)] # Some codecs (ex. MPEG4) do not support CRF. # Flags not supported by the selected codec will be ignored. ffmpeg_cmd = [ @@ -1475,7 +1524,8 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): str(frame_rate), "-i", temp_raw_path, - *quality_params, + "-crf", + str(crf), ffmpeg_encoded_path, ] subprocess.run(ffmpeg_cmd, check=True)