Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions src/torchcodec/_core/AVIOTensorContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,34 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
int read(void* opaque, uint8_t* buf, int buf_size) {
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
TORCH_CHECK(
tensorContext->current <= tensorContext->data.numel(),
"Tried to read outside of the buffer: current=",
tensorContext->current,
tensorContext->current_pos <= tensorContext->data.numel(),
"Tried to read outside of the buffer: current_pos=",
tensorContext->current_pos,
", size=",
tensorContext->data.numel());

int64_t numBytesRead = std::min(
static_cast<int64_t>(buf_size),
tensorContext->data.numel() - tensorContext->current);
tensorContext->data.numel() - tensorContext->current_pos);

TORCH_CHECK(
numBytesRead >= 0,
"Tried to read negative bytes: numBytesRead=",
numBytesRead,
", size=",
tensorContext->data.numel(),
", current=",
tensorContext->current);
", current_pos=",
tensorContext->current_pos);

if (numBytesRead == 0) {
return AVERROR_EOF;
}

std::memcpy(
buf,
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current,
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current_pos,
numBytesRead);
tensorContext->current += numBytesRead;
tensorContext->current_pos += numBytesRead;
return numBytesRead;
}

Expand All @@ -54,7 +54,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
auto tensorContext = static_cast<detail::TensorContext*>(opaque);

int64_t bufSize = static_cast<int64_t>(buf_size);
if (tensorContext->current + bufSize > tensorContext->data.numel()) {
if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) {
TORCH_CHECK(
tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE,
"We tried to allocate an output encoded tensor larger than ",
Expand All @@ -68,13 +68,17 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
}

TORCH_CHECK(
tensorContext->current + bufSize <= tensorContext->data.numel(),
tensorContext->current_pos + bufSize <= tensorContext->data.numel(),
"Re-allocation of the output tensor didn't work. ",
"This should not happen, please report on TorchCodec bug tracker");

uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
tensorContext->current += bufSize;
std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize);
tensorContext->current_pos += bufSize;
// Track the maximum position written so getOutputTensor's narrow() does not
// truncate the file if final seek was backwards
tensorContext->max_pos =
std::max(tensorContext->current_pos, tensorContext->max_pos);
return buf_size;
}

Expand All @@ -88,7 +92,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
ret = tensorContext->data.numel();
break;
case SEEK_SET:
tensorContext->current = offset;
tensorContext->current_pos = offset;
ret = offset;
break;
default:
Expand All @@ -101,7 +105,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
} // namespace

AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
: tensorContext_{data, 0} {
: tensorContext_{data, 0, 0} {
TORCH_CHECK(data.numel() > 0, "data must not be empty");
TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
Expand All @@ -110,14 +114,17 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
}

AVIOToTensorContext::AVIOToTensorContext()
: tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
: tensorContext_{
torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}),
0,
0} {
createAVIOContext(
nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true);
}

torch::Tensor AVIOToTensorContext::getOutputTensor() {
return tensorContext_.data.narrow(
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.current);
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos);
}

} // namespace facebook::torchcodec
3 changes: 2 additions & 1 deletion src/torchcodec/_core/AVIOTensorContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ namespace detail {

struct TensorContext {
torch::Tensor data;
int64_t current;
int64_t current_pos;
int64_t max_pos;
};

} // namespace detail
Expand Down
58 changes: 51 additions & 7 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
#include "src/torchcodec/_core/Encoder.h"
#include "torch/types.h"

extern "C" {
#include <libavutil/pixdesc.h>
}

namespace facebook::torchcodec {

namespace {
Expand Down Expand Up @@ -542,10 +538,17 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
} // namespace

VideoEncoder::~VideoEncoder() {
// TODO-VideoEncoder: Unify destructor with ~AudioEncoder()
if (avFormatContext_ && avFormatContext_->pb) {
avio_flush(avFormatContext_->pb);
avio_close(avFormatContext_->pb);
avFormatContext_->pb = nullptr;
if (avFormatContext_->pb->error == 0) {
avio_flush(avFormatContext_->pb);
}
if (!avioContextHolder_) {
if (avFormatContext_->pb->error == 0) {
avio_close(avFormatContext_->pb);
}
avFormatContext_->pb = nullptr;
}
}
}

Expand Down Expand Up @@ -581,6 +584,36 @@ VideoEncoder::VideoEncoder(
initializeEncoder(videoStreamOptions);
}

VideoEncoder::VideoEncoder(
const torch::Tensor& frames,
int frameRate,
std::string_view formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
const VideoStreamOptions& videoStreamOptions)
: frames_(validateFrames(frames)),
inFrameRate_(frameRate),
avioContextHolder_(std::move(avioContextHolder)) {
setFFmpegLogLevel();
// Map mkv -> matroska when used as format name
formatName = (formatName == "mkv") ? "matroska" : formatName;
AVFormatContext* avFormatContext = nullptr;
int status = avformat_alloc_output_context2(
&avFormatContext, nullptr, formatName.data(), nullptr);

TORCH_CHECK(
avFormatContext != nullptr,
"Couldn't allocate AVFormatContext. ",
"Check the desired format? Got format=",
formatName,
". ",
getFFMPEGErrorStringFromErrorCode(status));
avFormatContext_.reset(avFormatContext);

avFormatContext_->pb = avioContextHolder_->getAVIOContext();

initializeEncoder(videoStreamOptions);
}

void VideoEncoder::initializeEncoder(
const VideoStreamOptions& videoStreamOptions) {
const AVCodec* avCodec =
Expand Down Expand Up @@ -751,6 +784,17 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
return avFrame;
}

torch::Tensor VideoEncoder::encodeToTensor() {
TORCH_CHECK(
avioContextHolder_ != nullptr,
"Cannot encode to tensor, avio tensor context doesn't exist.");
encode();
auto avioToTensorContext =
dynamic_cast<AVIOToTensorContext*>(avioContextHolder_.get());
TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder.");
return avioToTensorContext->getOutputTensor();
}

void VideoEncoder::encodeFrame(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& avFrame) {
Expand Down
11 changes: 11 additions & 0 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,17 @@ class VideoEncoder {
std::string_view fileName,
const VideoStreamOptions& videoStreamOptions);

VideoEncoder(
const torch::Tensor& frames,
int frameRate,
std::string_view formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
const VideoStreamOptions& videoStreamOptions);

void encode();

torch::Tensor encodeToTensor();

private:
void initializeEncoder(const VideoStreamOptions& videoStreamOptions);
UniqueAVFrame convertTensorToAVFrame(
Expand All @@ -167,6 +176,8 @@ class VideoEncoder {
int outHeight_ = -1;
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;

std::unique_ptr<AVIOContextHolder> avioContextHolder_;

bool encodeWasCalled_ = false;
};

Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
encode_audio_to_file_like,
encode_audio_to_tensor,
encode_video_to_file,
encode_video_to_tensor,
get_ffmpeg_library_versions,
get_frame_at_index,
get_frame_at_pts,
Expand Down
56 changes: 38 additions & 18 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
m.def(
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
m.def(
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
m.def(
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
m.def(
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def(
Expand Down Expand Up @@ -498,21 +500,6 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
return makeOpsAudioFramesOutput(result);
}

void encode_video_to_file(
const at::Tensor& frames,
int64_t frame_rate,
std::string_view file_name,
std::optional<int64_t> crf = std::nullopt) {
VideoStreamOptions videoStreamOptions;
videoStreamOptions.crf = crf;
VideoEncoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
file_name,
videoStreamOptions)
.encode();
}

void encode_audio_to_file(
const at::Tensor& samples,
int64_t sample_rate,
Expand Down Expand Up @@ -587,6 +574,38 @@ void _encode_audio_to_file_like(
encoder.encode();
}

void encode_video_to_file(
const at::Tensor& frames,
int64_t frame_rate,
std::string_view file_name,
std::optional<int64_t> crf = std::nullopt) {
VideoStreamOptions videoStreamOptions;
videoStreamOptions.crf = crf;
VideoEncoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
file_name,
videoStreamOptions)
.encode();
}

at::Tensor encode_video_to_tensor(
const at::Tensor& frames,
int64_t frame_rate,
std::string_view format,
std::optional<int64_t> crf = std::nullopt) {
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
VideoStreamOptions videoStreamOptions;
videoStreamOptions.crf = crf;
return VideoEncoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
format,
std::move(avioContextHolder),
videoStreamOptions)
.encodeToTensor();
}

// For testing only. We need to implement this operation as a core library
// function because what we're testing is round-tripping pts values as
// double-precision floating point numbers from C++ to Python and back to C++.
Expand Down Expand Up @@ -847,9 +866,10 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {

TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("encode_audio_to_file", &encode_audio_to_file);
m.impl("encode_video_to_file", &encode_video_to_file);
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
m.impl("encode_video_to_file", &encode_video_to_file);
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
m.impl("seek_to_pts", &seek_to_pts);
m.impl("add_video_stream", &add_video_stream);
m.impl("_add_video_stream", &_add_video_stream);
Expand Down
39 changes: 26 additions & 13 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,18 @@ def load_torchcodec_shared_libraries():
encode_audio_to_file = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.encode_audio_to_file.default
)
encode_video_to_file = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.encode_video_to_file.default
)
encode_audio_to_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.encode_audio_to_tensor.default
)
_encode_audio_to_file_like = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns._encode_audio_to_file_like.default
)
encode_video_to_file = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.encode_video_to_file.default
)
encode_video_to_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.encode_video_to_tensor.default
)
create_from_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.create_from_tensor.default
)
Expand Down Expand Up @@ -254,16 +257,6 @@ def encode_audio_to_file_abstract(
return


@register_fake("torchcodec_ns::encode_video_to_file")
def encode_video_to_file_abstract(
frames: torch.Tensor,
frame_rate: int,
filename: str,
crf: Optional[int] = None,
) -> None:
return


@register_fake("torchcodec_ns::encode_audio_to_tensor")
def encode_audio_to_tensor_abstract(
samples: torch.Tensor,
Expand All @@ -289,6 +282,26 @@ def _encode_audio_to_file_like_abstract(
return


@register_fake("torchcodec_ns::encode_video_to_file")
def encode_video_to_file_abstract(
frames: torch.Tensor,
frame_rate: int,
filename: str,
crf: Optional[int],
) -> None:
return


@register_fake("torchcodec_ns::encode_video_to_tensor")
def encode_video_to_tensor_abstract(
frames: torch.Tensor,
frame_rate: int,
format: str,
crf: Optional[int],
) -> torch.Tensor:
return torch.empty([], dtype=torch.long)


@register_fake("torchcodec_ns::create_from_tensor")
def create_from_tensor_abstract(
video_tensor: torch.Tensor, seek_mode: Optional[str]
Expand Down
Loading
Loading