Skip to content

Commit fb8b8fa

Browse files
committed
changes
1 parent d14deb8 commit fb8b8fa

File tree

13 files changed

+238
-50
lines changed

13 files changed

+238
-50
lines changed

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ class CpuDeviceInterface : public DeviceInterface {
1818

1919
virtual ~CpuDeviceInterface() {}
2020

21-
std::optional<const AVCodec*> findCodec(
22-
[[maybe_unused]] const AVCodecID& codecId) override {
23-
return std::nullopt;
24-
}
25-
2621
virtual void initialize(
2722
const AVStream* avStream,
2823
const UniqueDecodingAVFormatContext& avFormatCtx,

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -329,11 +329,40 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
329329
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
330330
}
331331

332+
namespace {
333+
// Helper function to check if a codec supports CUDA hardware acceleration
334+
bool codecSupportsCudaHardware(const AVCodec* codec) {
335+
const AVCodecHWConfig* config = nullptr;
336+
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; ++j) {
337+
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
338+
return true;
339+
}
340+
}
341+
return false;
342+
}
343+
} // namespace
344+
332345
// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
333346
// we have to do this because of an FFmpeg bug where hardware decoding is not
334347
// appropriately set, so we just go off and find the matching codec for the CUDA
335348
// device
336-
std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
349+
350+
std::optional<const AVCodec*> CudaDeviceInterface::findEncoder(
351+
const AVCodecID& codecId) {
352+
void* i = nullptr;
353+
const AVCodec* codec = nullptr;
354+
while ((codec = av_codec_iterate(&i)) != nullptr) {
355+
if (codec->id != codecId || !av_codec_is_encoder(codec)) {
356+
continue;
357+
}
358+
if (codecSupportsCudaHardware(codec)) {
359+
return codec;
360+
}
361+
}
362+
return std::nullopt;
363+
}
364+
365+
std::optional<const AVCodec*> CudaDeviceInterface::findDecoder(
337366
const AVCodecID& codecId) {
338367
void* i = nullptr;
339368
const AVCodec* codec = nullptr;
@@ -342,12 +371,8 @@ std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
342371
continue;
343372
}
344373

345-
const AVCodecHWConfig* config = nullptr;
346-
for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr;
347-
++j) {
348-
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
349-
return codec;
350-
}
374+
if (codecSupportsCudaHardware(codec)) {
375+
return codec;
351376
}
352377
}
353378

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class CudaDeviceInterface : public DeviceInterface {
1818

1919
virtual ~CudaDeviceInterface();
2020

21-
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
21+
std::optional<const AVCodec*> findEncoder(const AVCodecID& codecId) override;
22+
std::optional<const AVCodec*> findDecoder(const AVCodecID& codecId) override;
2223

2324
void initialize(
2425
const AVStream* avStream,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ class DeviceInterface {
4646
return device_;
4747
};
4848

49-
virtual std::optional<const AVCodec*> findCodec(
49+
virtual std::optional<const AVCodec*> findEncoder(
50+
[[maybe_unused]] const AVCodecID& codecId) {
51+
return std::nullopt;
52+
};
53+
54+
virtual std::optional<const AVCodec*> findDecoder(
5055
[[maybe_unused]] const AVCodecID& codecId) {
5156
return std::nullopt;
5257
};

src/torchcodec/_core/Encoder.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,25 @@ VideoEncoder::VideoEncoder(
615615

616616
void VideoEncoder::initializeEncoder(
617617
const VideoStreamOptions& videoStreamOptions) {
618+
deviceInterface_ = createDeviceInterface(
619+
videoStreamOptions.device, videoStreamOptions.deviceVariant);
620+
TORCH_CHECK(
621+
deviceInterface_ != nullptr,
622+
"Failed to create device interface. This should never happen, please report.");
623+
618624
const AVCodec* avCodec =
619625
avcodec_find_encoder(avFormatContext_->oformat->video_codec);
620626
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
621627

628+
// Try to find a hardware-accelerated encoder if not using CPU
629+
if (videoStreamOptions.device.type() != torch::kCPU) {
630+
auto hardwareCodec =
631+
deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec);
632+
if (hardwareCodec.has_value()) {
633+
avCodec = hardwareCodec.value();
634+
}
635+
}
636+
622637
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
623638
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
624639
avCodecContext_.reset(avCodecContext);
@@ -668,6 +683,11 @@ void VideoEncoder::initializeEncoder(
668683
std::to_string(videoStreamOptions.crf.value()).c_str(),
669684
0);
670685
}
686+
687+
// Register the hardware device context with the codec
688+
// context before calling avcodec_open2().
689+
deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get());
690+
671691
int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
672692
av_dict_free(&options);
673693

src/torchcodec/_core/Encoder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22
#include <torch/types.h>
33
#include "src/torchcodec/_core/AVIOContextHolder.h"
4+
#include "src/torchcodec/_core/DeviceInterface.h"
45
#include "src/torchcodec/_core/FFMPEGCommon.h"
56
#include "src/torchcodec/_core/StreamOptions.h"
67

@@ -177,6 +178,7 @@ class VideoEncoder {
177178
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;
178179

179180
std::unique_ptr<AVIOContextHolder> avioContextHolder_;
181+
std::unique_ptr<DeviceInterface> deviceInterface_;
180182

181183
bool encodeWasCalled_ = false;
182184
};

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ AVPacket* ReferenceAVPacket::operator->() {
4040

4141
AVCodecOnlyUseForCallingAVFindBestStream
4242
makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) {
43-
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100)
43+
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) // FFmpeg < 5.0.3
4444
return const_cast<AVCodec*>(codec);
4545
#else
4646
return codec;

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ void SingleStreamDecoder::addStream(
435435
// addStream() which is supposed to be generic
436436
if (mediaType == AVMEDIA_TYPE_VIDEO) {
437437
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
438-
deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
438+
deviceInterface_->findDecoder(streamInfo.stream->codecpar->codec_id)
439439
.value_or(avCodec));
440440
}
441441

src/torchcodec/_core/custom_ops.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str device=\"cpu\", int? crf=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device=\"cpu\", int? crf=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device=\"cpu\",int? crf=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -603,9 +603,13 @@ void encode_video_to_file(
603603
const at::Tensor& frames,
604604
int64_t frame_rate,
605605
std::string_view file_name,
606+
std::string_view device = "cpu",
606607
std::optional<int64_t> crf = std::nullopt) {
607608
VideoStreamOptions videoStreamOptions;
608609
videoStreamOptions.crf = crf;
610+
611+
videoStreamOptions.device = torch::Device(std::string(device));
612+
videoStreamOptions.deviceVariant = "ffmpeg";
609613
VideoEncoder(
610614
frames,
611615
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -618,10 +622,14 @@ at::Tensor encode_video_to_tensor(
618622
const at::Tensor& frames,
619623
int64_t frame_rate,
620624
std::string_view format,
625+
std::string_view device = "cpu",
621626
std::optional<int64_t> crf = std::nullopt) {
622627
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
623628
VideoStreamOptions videoStreamOptions;
624629
videoStreamOptions.crf = crf;
630+
631+
videoStreamOptions.device = torch::Device(std::string(device));
632+
videoStreamOptions.deviceVariant = "ffmpeg";
625633
return VideoEncoder(
626634
frames,
627635
validateInt64ToInt(frame_rate, "frame_rate"),
@@ -636,6 +644,7 @@ void _encode_video_to_file_like(
636644
int64_t frame_rate,
637645
std::string_view format,
638646
int64_t file_like_context,
647+
std::string_view device = "cpu",
639648
std::optional<int64_t> crf = std::nullopt) {
640649
auto fileLikeContext =
641650
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
@@ -646,6 +655,9 @@ void _encode_video_to_file_like(
646655
VideoStreamOptions videoStreamOptions;
647656
videoStreamOptions.crf = crf;
648657

658+
videoStreamOptions.device = torch::Device(std::string(device));
659+
videoStreamOptions.deviceVariant = "ffmpeg";
660+
649661
VideoEncoder encoder(
650662
frames,
651663
validateInt64ToInt(frame_rate, "frame_rate"),

src/torchcodec/_core/ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def encode_video_to_file_like(
212212
frame_rate: int,
213213
format: str,
214214
file_like: Union[io.RawIOBase, io.BufferedIOBase],
215+
device: str = "cpu",
215216
crf: Optional[int] = None,
216217
) -> None:
217218
"""Encode video frames to a file-like object.
@@ -221,6 +222,7 @@ def encode_video_to_file_like(
221222
frame_rate: Frame rate in frames per second
222223
format: Video format (e.g., "mp4", "mov", "mkv")
223224
file_like: File-like object that supports write() and seek() methods
225+
device: Device to use for encoding (default: "cpu")
224226
crf: Optional constant rate factor for encoding quality
225227
"""
226228
assert _pybind_ops is not None
@@ -230,6 +232,7 @@ def encode_video_to_file_like(
230232
frame_rate,
231233
format,
232234
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
235+
device,
233236
crf,
234237
)
235238

@@ -318,7 +321,8 @@ def encode_video_to_file_abstract(
318321
frames: torch.Tensor,
319322
frame_rate: int,
320323
filename: str,
321-
crf: Optional[int],
324+
device: str = "cpu",
325+
crf: Optional[int] = None,
322326
) -> None:
323327
return
324328

@@ -328,7 +332,8 @@ def encode_video_to_tensor_abstract(
328332
frames: torch.Tensor,
329333
frame_rate: int,
330334
format: str,
331-
crf: Optional[int],
335+
device: str = "cpu",
336+
crf: Optional[int] = None,
332337
) -> torch.Tensor:
333338
return torch.empty([], dtype=torch.long)
334339

@@ -339,6 +344,7 @@ def _encode_video_to_file_like_abstract(
339344
frame_rate: int,
340345
format: str,
341346
file_like_context: int,
347+
device: str = "cpu",
342348
crf: Optional[int] = None,
343349
) -> None:
344350
return

0 commit comments

Comments
 (0)