Skip to content

Commit e4af1df

Browse files
author
Daniel Flores
committed
to_tensor, AVIOTensorContext fix
1 parent 61202b9 commit e4af1df

File tree

8 files changed

+161
-50
lines changed

8 files changed

+161
-50
lines changed

src/torchcodec/_core/AVIOTensorContext.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
7575
uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
7676
std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
7777
tensorContext->current += bufSize;
78+
// Track the highest position written so far, so that getOutputTensor()'s
79+
// narrow() does not truncate the output when the final seek moved backwards.
80+
if (tensorContext->current > tensorContext->max) {
81+
tensorContext->max = tensorContext->current;
82+
}
7883
return buf_size;
7984
}
8085

@@ -101,7 +106,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
101106
} // namespace
102107

103108
AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
104-
: tensorContext_{data, 0} {
109+
: tensorContext_{data, 0, 0} {
105110
TORCH_CHECK(data.numel() > 0, "data must not be empty");
106111
TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
107112
TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
@@ -110,14 +115,17 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
110115
}
111116

112117
// Sets up a write-mode AVIO context whose callbacks operate on
// tensorContext_. The backing tensor starts as an uninitialized uint8 buffer
// of INITIAL_TENSOR_SIZE elements, with both positions at zero.
AVIOToTensorContext::AVIOToTensorContext()
    : tensorContext_{
          /*data=*/torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}),
          /*current=*/0,
          /*max=*/0} {
  // Only the write and seek callbacks are supplied; the first argument is
  // left null.
  createAVIOContext(
      nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true);
}
117125

118126
// Returns a view over the written portion of the backing tensor.
//
// The length is the highest position ever written (tensorContext_.max), not
// the current position, so the result is not shortened when the last
// operation was a backwards seek.
torch::Tensor AVIOToTensorContext::getOutputTensor() {
  const int64_t writtenLength = tensorContext_.max;
  return tensorContext_.data.narrow(
      /*dim=*/0, /*start=*/0, /*length=*/writtenLength);
}
122130

123131
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOTensorContext.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace detail {
1616
struct TensorContext {
1717
torch::Tensor data;
1818
int64_t current;
19+
int64_t max;
1920
};
2021

2122
} // namespace detail

src/torchcodec/_core/Encoder.cpp

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
#include "src/torchcodec/_core/Encoder.h"
55
#include "torch/types.h"
66

7-
extern "C" {
8-
#include <libavutil/pixdesc.h>
9-
}
10-
117
namespace facebook::torchcodec {
128

139
namespace {
@@ -543,9 +539,15 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
543539

544540
// Flushes pending output and, when the encoder opened the I/O context itself,
// closes it.
//
// Flush and close are both skipped if the I/O context has recorded an error.
// When an AVIOContextHolder is attached, the context is neither closed nor
// detached here (presumably the holder owns it - confirm).
VideoEncoder::~VideoEncoder() {
  if (!avFormatContext_ || !avFormatContext_->pb) {
    return;
  }

  auto* const ioContext = avFormatContext_->pb;
  if (ioContext->error == 0) {
    avio_flush(ioContext);
  }

  if (avioContextHolder_) {
    return;
  }

  // The error field is re-read on purpose: the flush above may have set it.
  if (ioContext->error == 0) {
    avio_close(ioContext);
  }
  avFormatContext_->pb = nullptr;
}
551553

@@ -581,6 +583,34 @@ VideoEncoder::VideoEncoder(
581583
initializeEncoder(videoStreamOptions);
582584
}
583585

586+
// Creates an encoder that writes its output through a caller-supplied AVIO
// context rather than to a file.
//
// frames: validated by validateFrames() before being stored.
// frameRate: stored as the input frame rate.
// formatName: output format name passed to avformat_alloc_output_context2().
// avioContextHolder: ownership is taken; must not be null.
// videoStreamOptions: forwarded to initializeEncoder().
//
// Raises (via TORCH_CHECK) if the holder is null or the format context cannot
// be allocated.
VideoEncoder::VideoEncoder(
    const torch::Tensor& frames,
    int frameRate,
    std::string_view formatName,
    std::unique_ptr<AVIOContextHolder> avioContextHolder,
    const VideoStreamOptions& videoStreamOptions)
    : frames_(validateFrames(frames)),
      inFrameRate_(frameRate),
      avioContextHolder_(std::move(avioContextHolder)) {
  // The holder is dereferenced below; reject a null one with a clear error
  // instead of crashing.
  TORCH_CHECK(
      avioContextHolder_ != nullptr, "avioContextHolder must not be null.");

  setFFmpegLogLevel();

  // NOTE(review): formatName.data() is handed to a C API that expects a
  // NUL-terminated string, which std::string_view does not guarantee. Confirm
  // that every caller passes a view over NUL-terminated storage.
  AVFormatContext* avFormatContext = nullptr;
  int status = avformat_alloc_output_context2(
      &avFormatContext, nullptr, formatName.data(), nullptr);

  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      "Check the desired format? Got format=",
      formatName,
      ". ",
      getFFMPEGErrorStringFromErrorCode(status));
  avFormatContext_.reset(avFormatContext);

  // Route all muxer output through the caller-supplied I/O context.
  avFormatContext_->pb = avioContextHolder_->getAVIOContext();

  initializeEncoder(videoStreamOptions);
}
613+
584614
void VideoEncoder::initializeEncoder(
585615
const VideoStreamOptions& videoStreamOptions) {
586616
const AVCodec* avCodec =
@@ -751,6 +781,17 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
751781
return avFrame;
752782
}
753783

784+
// Encodes the frames and returns the tensor produced by the attached
// AVIOToTensorContext.
//
// Raises (via TORCH_CHECK) if no AVIO context holder is attached or if the
// attached holder is not an AVIOToTensorContext.
torch::Tensor VideoEncoder::encodeToTensor() {
  TORCH_CHECK(
      avioContextHolder_ != nullptr,
      "Cannot encode to tensor, avio tensor context doesn't exist.");
  // Check the holder type before encoding so that a mismatched holder fails
  // fast rather than after the whole encode has already run.
  auto* avioToTensorContext =
      dynamic_cast<AVIOToTensorContext*>(avioContextHolder_.get());
  TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder.");
  encode();
  return avioToTensorContext->getOutputTensor();
}
794+
754795
void VideoEncoder::encodeFrame(
755796
AutoAVPacket& autoAVPacket,
756797
const UniqueAVFrame& avFrame) {

src/torchcodec/_core/Encoder.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,17 @@ class VideoEncoder {
141141
std::string_view fileName,
142142
const VideoStreamOptions& videoStreamOptions);
143143

144+
VideoEncoder(
145+
const torch::Tensor& frames,
146+
int frameRate,
147+
std::string_view formatName,
148+
std::unique_ptr<AVIOContextHolder> avioContextHolder,
149+
const VideoStreamOptions& videoStreamOptions);
150+
144151
void encode();
145152

153+
torch::Tensor encodeToTensor();
154+
146155
private:
147156
void initializeEncoder(const VideoStreamOptions& videoStreamOptions);
148157
UniqueAVFrame convertTensorToAVFrame(
@@ -167,6 +176,8 @@ class VideoEncoder {
167176
int outHeight_ = -1;
168177
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;
169178

179+
std::unique_ptr<AVIOContextHolder> avioContextHolder_;
180+
170181
bool encodeWasCalled_ = false;
171182
};
172183

src/torchcodec/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
encode_audio_to_file_like,
2727
encode_audio_to_tensor,
2828
encode_video_to_file,
29+
encode_video_to_tensor,
2930
get_ffmpeg_library_versions,
3031
get_frame_at_index,
3132
get_frame_at_pts,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3232
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3333
m.def(
3434
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
35-
m.def(
36-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
3735
m.def(
3836
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
3937
m.def(
4038
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
39+
m.def(
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
41+
m.def(
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
4143
m.def(
4244
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4345
m.def(
@@ -498,21 +500,6 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
498500
return makeOpsAudioFramesOutput(result);
499501
}
500502

501-
void encode_video_to_file(
502-
const at::Tensor& frames,
503-
int64_t frame_rate,
504-
std::string_view file_name,
505-
std::optional<int64_t> crf = std::nullopt) {
506-
VideoStreamOptions videoStreamOptions;
507-
videoStreamOptions.crf = crf;
508-
VideoEncoder(
509-
frames,
510-
validateInt64ToInt(frame_rate, "frame_rate"),
511-
file_name,
512-
videoStreamOptions)
513-
.encode();
514-
}
515-
516503
void encode_audio_to_file(
517504
const at::Tensor& samples,
518505
int64_t sample_rate,
@@ -587,6 +574,38 @@ void _encode_audio_to_file_like(
587574
encoder.encode();
588575
}
589576

577+
void encode_video_to_file(
578+
const at::Tensor& frames,
579+
int64_t frame_rate,
580+
std::string_view file_name,
581+
std::optional<int64_t> crf = std::nullopt) {
582+
VideoStreamOptions videoStreamOptions;
583+
videoStreamOptions.crf = crf;
584+
VideoEncoder(
585+
frames,
586+
validateInt64ToInt(frame_rate, "frame_rate"),
587+
file_name,
588+
videoStreamOptions)
589+
.encode();
590+
}
591+
592+
// Op implementation: encodes `frames` at `frame_rate` in the container given
// by `format` and returns the encoded output as a tensor. `crf`, when
// provided, is passed through the stream options.
at::Tensor encode_video_to_tensor(
    const at::Tensor& frames,
    int64_t frame_rate,
    std::string_view format,
    std::optional<int64_t> crf = std::nullopt) {
  VideoStreamOptions options;
  options.crf = crf;

  // The encoder takes ownership of the tensor-backed AVIO context.
  auto tensorSink = std::make_unique<AVIOToTensorContext>();

  VideoEncoder encoder(
      frames,
      validateInt64ToInt(frame_rate, "frame_rate"),
      format,
      std::move(tensorSink),
      options);
  return encoder.encodeToTensor();
}
608+
590609
// For testing only. We need to implement this operation as a core library
591610
// function because what we're testing is round-tripping pts values as
592611
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -847,9 +866,10 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
847866

848867
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
849868
m.impl("encode_audio_to_file", &encode_audio_to_file);
850-
m.impl("encode_video_to_file", &encode_video_to_file);
851869
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
852870
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
871+
m.impl("encode_video_to_file", &encode_video_to_file);
872+
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
853873
m.impl("seek_to_pts", &seek_to_pts);
854874
m.impl("add_video_stream", &add_video_stream);
855875
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/_core/ops.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,18 @@ def load_torchcodec_shared_libraries():
9292
encode_audio_to_file = torch._dynamo.disallow_in_graph(
9393
torch.ops.torchcodec_ns.encode_audio_to_file.default
9494
)
95-
encode_video_to_file = torch._dynamo.disallow_in_graph(
96-
torch.ops.torchcodec_ns.encode_video_to_file.default
97-
)
9895
encode_audio_to_tensor = torch._dynamo.disallow_in_graph(
9996
torch.ops.torchcodec_ns.encode_audio_to_tensor.default
10097
)
10198
_encode_audio_to_file_like = torch._dynamo.disallow_in_graph(
10299
torch.ops.torchcodec_ns._encode_audio_to_file_like.default
103100
)
101+
encode_video_to_file = torch._dynamo.disallow_in_graph(
102+
torch.ops.torchcodec_ns.encode_video_to_file.default
103+
)
104+
encode_video_to_tensor = torch._dynamo.disallow_in_graph(
105+
torch.ops.torchcodec_ns.encode_video_to_tensor.default
106+
)
104107
create_from_tensor = torch._dynamo.disallow_in_graph(
105108
torch.ops.torchcodec_ns.create_from_tensor.default
106109
)
@@ -254,16 +257,6 @@ def encode_audio_to_file_abstract(
254257
return
255258

256259

257-
@register_fake("torchcodec_ns::encode_video_to_file")
258-
def encode_video_to_file_abstract(
259-
frames: torch.Tensor,
260-
frame_rate: int,
261-
filename: str,
262-
crf: Optional[int] = None,
263-
) -> None:
264-
return
265-
266-
267260
@register_fake("torchcodec_ns::encode_audio_to_tensor")
268261
def encode_audio_to_tensor_abstract(
269262
samples: torch.Tensor,
@@ -289,6 +282,26 @@ def _encode_audio_to_file_like_abstract(
289282
return
290283

291284

285+
@register_fake("torchcodec_ns::encode_video_to_file")
def encode_video_to_file_abstract(
    frames: torch.Tensor,
    frame_rate: int,
    filename: str,
    crf: Optional[int] = None,
) -> None:
    # Fake implementation for an op that returns nothing. `crf` defaults to
    # None to match the op schema (`int? crf=None`).
    return
293+
294+
295+
@register_fake("torchcodec_ns::encode_video_to_tensor")
def encode_video_to_tensor_abstract(
    frames: torch.Tensor,
    frame_rate: int,
    format: str,
    crf: Optional[int] = None,
) -> torch.Tensor:
    # The real op returns a 1-D uint8 tensor whose length is only known after
    # encoding. Mirror that dtype and rank with an empty placeholder instead
    # of a 0-dim int64 tensor. `crf` defaults to None to match the op schema.
    return torch.empty([0], dtype=torch.uint8)
303+
304+
292305
@register_fake("torchcodec_ns::create_from_tensor")
293306
def create_from_tensor_abstract(
294307
video_tensor: torch.Tensor, seek_mode: Optional[str]

test/test_ops.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
create_from_tensor,
3030
encode_audio_to_file,
3131
encode_video_to_file,
32+
encode_video_to_tensor,
3233
get_ffmpeg_library_versions,
3334
get_frame_at_index,
3435
get_frame_at_pts,
@@ -1378,14 +1379,18 @@ def test_bad_input(self, tmp_path):
13781379
filename="./bad/path.mp3",
13791380
)
13801381

1381-
def decode(self, file_path) -> torch.Tensor:
1382-
decoder = create_from_file(str(file_path), seek_mode="approximate")
1382+
def decode(self, file_path=None, tensor=None) -> torch.Tensor:
    # Decode frames (start=0, stop=60) from a file path or from an encoded
    # tensor. `file_path` takes precedence when both are given.
    if file_path is not None:
        decoder = create_from_file(str(file_path), seek_mode="approximate")
    elif tensor is not None:
        decoder = create_from_tensor(tensor, seek_mode="approximate")
    else:
        # Without this branch `decoder` would be unbound below and the caller
        # would see an unrelated UnboundLocalError.
        raise ValueError("Either file_path or tensor must be provided.")
    add_video_stream(decoder)
    frames, *_ = get_frames_in_range(decoder, start=0, stop=60)
    return frames
13861390

13871391
@pytest.mark.parametrize("format", ("mov", "mp4", "mkv", "webm"))
1388-
def test_video_encoder_round_trip(self, tmp_path, format):
1392+
@pytest.mark.parametrize("output_method", ("to_file", "to_tensor"))
1393+
def test_video_encoder_round_trip(self, tmp_path, format, output_method):
13891394
# Test that decode(encode(decode(asset))) == decode(asset)
13901395
ffmpeg_version = get_ffmpeg_major_version()
13911396
# In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm.
@@ -1399,14 +1404,25 @@ def test_video_encoder_round_trip(self, tmp_path, format):
13991404
):
14001405
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
14011406
asset = TEST_SRC_2_720P
1402-
source_frames = self.decode(str(asset.path)).data
1407+
source_frames = self.decode(file_path=str(asset.path)).data
14031408

1404-
encoded_path = str(tmp_path / f"encoder_output.{format}")
14051409
frame_rate = 30 # Frame rate is fixed with num frames decoded
1406-
encode_video_to_file(
1407-
frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0
1408-
)
1409-
round_trip_frames = self.decode(encoded_path).data
1410+
if output_method == "to_file":
1411+
encoded_path = str(tmp_path / f"encoder_output.{format}")
1412+
encode_video_to_file(
1413+
frames=source_frames,
1414+
frame_rate=frame_rate,
1415+
filename=encoded_path,
1416+
crf=0,
1417+
)
1418+
round_trip_frames = self.decode(file_path=encoded_path).data
1419+
else: # to_tensor
1420+
format = "matroska" if format == "mkv" else format
1421+
encoded_tensor = encode_video_to_tensor(
1422+
source_frames, frame_rate, format, crf=0
1423+
)
1424+
round_trip_frames = self.decode(tensor=encoded_tensor).data
1425+
14101426
assert source_frames.shape == round_trip_frames.shape
14111427
assert source_frames.dtype == round_trip_frames.dtype
14121428

0 commit comments

Comments
 (0)