Skip to content

Commit 58b4921

Browse files
committed
Use cuda filters to support 10-bit videos
For: #776 Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 141c5b1 commit 58b4921

File tree

6 files changed

+99
-19
lines changed

6 files changed

+99
-19
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,68 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
199199
return;
200200
}
201201

202+
std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext(
203+
const VideoStreamOptions& videoStreamOptions,
204+
const UniqueAVFrame& avFrame,
205+
const AVRational& timeBase) {
206+
enum AVPixelFormat frameFormat =
207+
static_cast<enum AVPixelFormat>(avFrame->format);
208+
209+
if (avFrame->format != AV_PIX_FMT_CUDA) {
210+
auto cpuDevice = torch::Device(torch::kCPU);
211+
auto cpuInterface = createDeviceInterface(cpuDevice);
212+
return cpuInterface->initializeFiltersContext(
213+
videoStreamOptions, avFrame, timeBase);
214+
}
215+
216+
auto frameDims =
217+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
218+
int height = frameDims.height;
219+
int width = frameDims.width;
220+
221+
auto hwFramesCtx =
222+
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
223+
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
224+
225+
if (actualFormat == AV_PIX_FMT_NV12) {
226+
return nullptr;
227+
}
228+
229+
std::unique_ptr<FiltersContext> filtersContext =
230+
std::make_unique<FiltersContext>();
231+
232+
filtersContext->inputWidth = avFrame->width;
233+
filtersContext->inputHeight = avFrame->height;
234+
filtersContext->inputFormat = frameFormat;
235+
filtersContext->inputAspectRatio = avFrame->sample_aspect_ratio;
236+
filtersContext->timeBase = timeBase;
237+
filtersContext->hwFramesCtx.reset(av_buffer_ref(avFrame->hw_frames_ctx));
238+
239+
std::stringstream filters;
240+
241+
unsigned version_int = avfilter_version();
242+
if (version_int < AV_VERSION_INT(8, 0, 103)) {
243+
// Color conversion support ('format=' option) was added to scale_cuda from
244+
// n5.0. With the earlier version of ffmpeg we have no choice but use CPU
245+
// filters. See:
246+
// https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
247+
filtersContext->outputFormat = AV_PIX_FMT_RGB24;
248+
249+
filters << "hwdownload,format=" << av_pix_fmt_desc_get(actualFormat)->name;
250+
filters << ",scale=" << width << ":" << height;
251+
filters << ":sws_flags=bilinear";
252+
} else {
253+
// Actual output color format will be set via filter options
254+
filtersContext->outputFormat = AV_PIX_FMT_CUDA;
255+
256+
filters << "scale_cuda=" << width << ":" << height;
257+
filters << ":format=nv12:interp_algo=bilinear";
258+
}
259+
260+
filtersContext->filters = filters.str();
261+
return filtersContext;
262+
}
263+
202264
void CudaDeviceInterface::convertAVFrameToFrameOutput(
203265
const VideoStreamOptions& videoStreamOptions,
204266
[[maybe_unused]] const AVRational& timeBase,

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ class CudaDeviceInterface : public DeviceInterface {
2121

2222
void initializeContext(AVCodecContext* codecContext) override;
2323

24+
std::unique_ptr<FiltersContext> initializeFiltersContext(
25+
const VideoStreamOptions& videoStreamOptions,
26+
const UniqueAVFrame& avFrame,
27+
const AVRational& timeBase) override;
28+
2429
void convertAVFrameToFrameOutput(
2530
const VideoStreamOptions& videoStreamOptions,
2631
const AVRational& timeBase,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <stdexcept>
1313
#include <string>
1414
#include "FFMPEGCommon.h"
15+
#include "src/torchcodec/_core/FilterGraph.h"
1516
#include "src/torchcodec/_core/Frame.h"
1617
#include "src/torchcodec/_core/StreamOptions.h"
1718

@@ -41,6 +42,18 @@ class DeviceInterface {
4142
// support CUDA and others only support CPU.
4243
virtual void initializeContext(AVCodecContext* codecContext) = 0;
4344

45+
// Returns FilterContext if device interface can't handle conversion of the
46+
// frame on its own within a call to convertAVFrameToFrameOutput().
47+
// FilterContext contains input and output initialization parameters
48+
// describing required conversion. Output can further be passed to
49+
// convertAVFrameToFrameOutput() to generate output tensor.
50+
virtual std::unique_ptr<FiltersContext> initializeFiltersContext(
51+
[[maybe_unused]] const VideoStreamOptions& videoStreamOptions,
52+
[[maybe_unused]] const UniqueAVFrame& avFrame,
53+
[[maybe_unused]] const AVRational& timeBase) {
54+
return nullptr;
55+
};
56+
4457
virtual void convertAVFrameToFrameOutput(
4558
const VideoStreamOptions& videoStreamOptions,
4659
const AVRational& timeBase,

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,17 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
12541254
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
12551255
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
12561256
} else if (deviceInterface_) {
1257+
std::unique_ptr<FiltersContext> filtersContext =
1258+
deviceInterface_->initializeFiltersContext(
1259+
streamInfo.videoStreamOptions, avFrame, streamInfo.timeBase);
1260+
if (filtersContext) {
1261+
if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
1262+
filterGraph_ = std::make_unique<FilterGraph>(
1263+
*filtersContext, streamInfo.videoStreamOptions);
1264+
prevFiltersContext_ = std::move(filtersContext);
1265+
}
1266+
avFrame = filterGraph_->convert(avFrame);
1267+
}
12571268
deviceInterface_->convertAVFrameToFrameOutput(
12581269
streamInfo.videoStreamOptions,
12591270
streamInfo.timeBase,

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,10 @@ class SingleStreamDecoder {
351351
SeekMode seekMode_;
352352
ContainerMetadata containerMetadata_;
353353
UniqueDecodingAVFormatContext formatContext_;
354+
// Previous frame filter context. Used to know whether a new FilterGraph
355+
// should be created to process a next frame.
356+
std::unique_ptr<FiltersContext> prevFiltersContext_;
357+
std::unique_ptr<FilterGraph> filterGraph_;
354358
std::unique_ptr<DeviceInterface> deviceInterface_;
355359
std::map<int, StreamInfo> streamInfos_;
356360
const int NO_ACTIVE_STREAM = -2;

test/test_decoders.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,22 +1225,6 @@ def test_full_and_studio_range_bt709_video(self, asset):
12251225
elif cuda_version_used_for_building_torch() == (12, 8):
12261226
assert psnr(gpu_frame, cpu_frame) > 20
12271227

1228-
@needs_cuda
1229-
def test_10bit_videos_cuda(self):
1230-
# Assert that we raise proper error on different kinds of 10bit videos.
1231-
1232-
# TODO we should investigate how to support 10bit videos on GPU.
1233-
# See https://github.com/pytorch/torchcodec/issues/776
1234-
1235-
asset = H265_10BITS
1236-
1237-
decoder = VideoDecoder(asset.path, device="cuda")
1238-
with pytest.raises(
1239-
RuntimeError,
1240-
match="The AVFrame is p010le, but we expected AV_PIX_FMT_NV12.",
1241-
):
1242-
decoder.get_frame_at(0)
1243-
12441228
@needs_cuda
12451229
def test_10bit_gpu_fallsback_to_cpu(self):
12461230
# Test for 10-bit videos that aren't supported by NVDEC: we decode and
@@ -1272,12 +1256,13 @@ def test_10bit_gpu_fallsback_to_cpu(self):
12721256
frames_cpu = decoder_cpu.get_frames_at(frame_indices).data
12731257
assert_frames_equal(frames_gpu.cpu(), frames_cpu)
12741258

1259+
@pytest.mark.parametrize("device", all_supported_devices())
12751260
@pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS))
1276-
def test_10bit_videos_cpu(self, asset):
1277-
# This just validates that we can decode 10-bit videos on CPU.
1261+
def test_10bit_videos(self, device, asset):
1262+
# This just validates that we can decode 10-bit videos.
12781263
# TODO validate against the ref that the decoded frames are correct
12791264

1280-
decoder = VideoDecoder(asset.path)
1265+
decoder = VideoDecoder(asset.path, device=device)
12811266
decoder.get_frame_at(10)
12821267

12831268
def setup_frame_mappings(tmp_path, file, stream_index):

0 commit comments

Comments
 (0)