Skip to content

Commit 3827dfe

Browse files
mollyxu and Molly Xu authored
Refactor receiveFrame and sendPacket logic to dispatch directly to interface (#954)
Co-authored-by: Molly Xu <[email protected]>
1 parent 0816c8b commit 3827dfe

11 files changed

+62
-70
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -230,7 +230,8 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
230230

231231
void BetaCudaDeviceInterface::initialize(
232232
const AVStream* avStream,
233-
const UniqueDecodingAVFormatContext& avFormatCtx) {
233+
const UniqueDecodingAVFormatContext& avFormatCtx,
234+
[[maybe_unused]] const SharedAVCodecContext& codecContext) {
234235
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
235236
timeBase_ = avStream->time_base;
236237
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -40,18 +40,15 @@ class BetaCudaDeviceInterface : public DeviceInterface {
4040

4141
void initialize(
4242
const AVStream* avStream,
43-
const UniqueDecodingAVFormatContext& avFormatCtx) override;
43+
const UniqueDecodingAVFormatContext& avFormatCtx,
44+
const SharedAVCodecContext& codecContext) override;
4445

4546
void convertAVFrameToFrameOutput(
4647
UniqueAVFrame& avFrame,
4748
FrameOutput& frameOutput,
4849
std::optional<torch::Tensor> preAllocatedOutputTensor =
4950
std::nullopt) override;
5051

51-
bool canDecodePacketDirectly() const override {
52-
return true;
53-
}
54-
5552
int sendPacket(ReferenceAVPacket& packet) override;
5653
int sendEOFPacket() override;
5754
int receiveFrame(UniqueAVFrame& avFrame) override;

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,10 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
4848

4949
void CpuDeviceInterface::initialize(
5050
const AVStream* avStream,
51-
[[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) {
51+
[[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx,
52+
const SharedAVCodecContext& codecContext) {
5253
TORCH_CHECK(avStream != nullptr, "avStream is null");
54+
codecContext_ = codecContext;
5355
timeBase_ = avStream->time_base;
5456
}
5557

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@ class CpuDeviceInterface : public DeviceInterface {
2525

2626
virtual void initialize(
2727
const AVStream* avStream,
28-
const UniqueDecodingAVFormatContext& avFormatCtx) override;
28+
const UniqueDecodingAVFormatContext& avFormatCtx,
29+
const SharedAVCodecContext& codecContext) override;
2930

3031
virtual void initializeVideo(
3132
const VideoStreamOptions& videoStreamOptions,

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,15 +114,17 @@ CudaDeviceInterface::~CudaDeviceInterface() {
114114

115115
void CudaDeviceInterface::initialize(
116116
const AVStream* avStream,
117-
const UniqueDecodingAVFormatContext& avFormatCtx) {
117+
const UniqueDecodingAVFormatContext& avFormatCtx,
118+
const SharedAVCodecContext& codecContext) {
118119
TORCH_CHECK(avStream != nullptr, "avStream is null");
120+
codecContext_ = codecContext;
119121
timeBase_ = avStream->time_base;
120122

121123
// TODO: Ideally, we should keep all interface implementations independent.
122124
cpuInterface_ = createDeviceInterface(torch::kCPU);
123125
TORCH_CHECK(
124126
cpuInterface_ != nullptr, "Failed to create CPU device interface");
125-
cpuInterface_->initialize(avStream, avFormatCtx);
127+
cpuInterface_->initialize(avStream, avFormatCtx, codecContext);
126128
cpuInterface_->initializeVideo(
127129
VideoStreamOptions(),
128130
{},

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,8 @@ class CudaDeviceInterface : public DeviceInterface {
2222

2323
void initialize(
2424
const AVStream* avStream,
25-
const UniqueDecodingAVFormatContext& avFormatCtx) override;
25+
const UniqueDecodingAVFormatContext& avFormatCtx,
26+
const SharedAVCodecContext& codecContext) override;
2627

2728
void initializeVideo(
2829
const VideoStreamOptions& videoStreamOptions,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 21 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,8 @@ class DeviceInterface {
5454
// Initialize the device with parameters generic to all kinds of decoding.
5555
virtual void initialize(
5656
const AVStream* avStream,
57-
const UniqueDecodingAVFormatContext& avFormatCtx) = 0;
57+
const UniqueDecodingAVFormatContext& avFormatCtx,
58+
const SharedAVCodecContext& codecContext) = 0;
5859

5960
// Initialize the device with parameters specific to video decoding. There is
6061
// a default empty implementation.
@@ -80,52 +81,47 @@ class DeviceInterface {
8081
// Extension points for custom decoding paths
8182
// ------------------------------------------
8283

83-
// Override to return true if this device interface can decode packets
84-
// directly. This means that the following two member functions can both
85-
// be called:
86-
//
87-
// 1. sendPacket()
88-
// 2. receiveFrame()
89-
virtual bool canDecodePacketDirectly() const {
90-
return false;
91-
}
92-
93-
// Moral equivalent of avcodec_send_packet()
9484
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or
9585
// other AVERROR on failure
96-
virtual int sendPacket([[maybe_unused]] ReferenceAVPacket& avPacket) {
86+
// Default implementation uses FFmpeg directly
87+
virtual int sendPacket(ReferenceAVPacket& avPacket) {
9788
TORCH_CHECK(
98-
false,
99-
"Send/receive packet decoding not implemented for this device interface");
100-
return AVERROR(ENOSYS);
89+
codecContext_ != nullptr,
90+
"Codec context not available for default packet sending");
91+
return avcodec_send_packet(codecContext_.get(), avPacket.get());
10192
}
10293

10394
// Send an EOF packet to flush the decoder
10495
// Returns AVSUCCESS on success, or other AVERROR on failure
96+
// Default implementation uses FFmpeg directly
10597
virtual int sendEOFPacket() {
10698
TORCH_CHECK(
107-
false, "Send EOF packet not implemented for this device interface");
108-
return AVERROR(ENOSYS);
99+
codecContext_ != nullptr,
100+
"Codec context not available for default EOF packet sending");
101+
return avcodec_send_packet(codecContext_.get(), nullptr);
109102
}
110103

111-
// Moral equivalent of avcodec_receive_frame()
112104
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready,
113105
// AVERROR_EOF if end of stream, or other AVERROR on failure
114-
virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) {
106+
// Default implementation uses FFmpeg directly
107+
virtual int receiveFrame(UniqueAVFrame& avFrame) {
115108
TORCH_CHECK(
116-
false,
117-
"Send/receive packet decoding not implemented for this device interface");
118-
return AVERROR(ENOSYS);
109+
codecContext_ != nullptr,
110+
"Codec context not available for default frame receiving");
111+
return avcodec_receive_frame(codecContext_.get(), avFrame.get());
119112
}
120113

121114
// Flush remaining frames from decoder
122115
virtual void flush() {
123-
// Default implementation is no-op for standard decoders
124-
// Custom decoders can override this method
116+
TORCH_CHECK(
117+
codecContext_ != nullptr,
118+
"Codec context not available for default flushing");
119+
avcodec_flush_buffers(codecContext_.get());
125120
}
126121

127122
protected:
128123
torch::Device device_;
124+
SharedAVCodecContext codecContext_;
129125
};
130126

131127
using CreateDeviceInterfaceFn =

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ int getNumChannels(const UniqueAVFrame& avFrame) {
149149
#endif
150150
}
151151

152-
int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
152+
int getNumChannels(const SharedAVCodecContext& avCodecContext) {
153153
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
154154
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
155155
return avCodecContext->ch_layout.nb_channels;

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,14 @@ using UniqueEncodingAVFormatContext = std::unique_ptr<
7171
using UniqueAVCodecContext = std::unique_ptr<
7272
AVCodecContext,
7373
Deleterp<AVCodecContext, void, avcodec_free_context>>;
74+
using SharedAVCodecContext = std::shared_ptr<AVCodecContext>;
75+
76+
// create SharedAVCodecContext with custom deleter
77+
inline SharedAVCodecContext makeSharedAVCodecContext(AVCodecContext* ctx) {
78+
return SharedAVCodecContext(
79+
ctx, Deleterp<AVCodecContext, void, avcodec_free_context>{});
80+
}
81+
7482
using UniqueAVFrame =
7583
std::unique_ptr<AVFrame, Deleterp<AVFrame, void, av_frame_free>>;
7684
using UniqueAVFilterGraph = std::unique_ptr<
@@ -171,7 +179,7 @@ const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec);
171179
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
172180

173181
int getNumChannels(const UniqueAVFrame& avFrame);
174-
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
182+
int getNumChannels(const SharedAVCodecContext& avCodecContext);
175183

176184
void setDefaultChannelLayout(
177185
UniqueAVCodecContext& avCodecContext,

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 15 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -429,7 +429,6 @@ void SingleStreamDecoder::addStream(
429429
TORCH_CHECK(
430430
deviceInterface_ != nullptr,
431431
"Failed to create device interface. This should never happen, please report.");
432-
deviceInterface_->initialize(streamInfo.stream, formatContext_);
433432

434433
// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
435434
// addStream() which is supposed to be generic
@@ -441,7 +440,7 @@ void SingleStreamDecoder::addStream(
441440

442441
AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
443442
TORCH_CHECK(codecContext != nullptr);
444-
streamInfo.codecContext.reset(codecContext);
443+
streamInfo.codecContext = makeSharedAVCodecContext(codecContext);
445444

446445
int retVal = avcodec_parameters_to_context(
447446
streamInfo.codecContext.get(), streamInfo.stream->codecpar);
@@ -453,14 +452,19 @@ void SingleStreamDecoder::addStream(
453452
// Note that we must make sure to register the harware device context
454453
// with the codec context before calling avcodec_open2(). Otherwise, decoding
455454
// will happen on the CPU and not the hardware device.
456-
deviceInterface_->registerHardwareDeviceWithCodec(codecContext);
455+
deviceInterface_->registerHardwareDeviceWithCodec(
456+
streamInfo.codecContext.get());
457457
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
458458
TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));
459459

460-
codecContext->time_base = streamInfo.stream->time_base;
460+
streamInfo.codecContext->time_base = streamInfo.stream->time_base;
461+
462+
// Initialize the device interface with the codec context
463+
deviceInterface_->initialize(
464+
streamInfo.stream, formatContext_, streamInfo.codecContext);
461465

462466
containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
463-
std::string(avcodec_get_name(codecContext->codec_id));
467+
std::string(avcodec_get_name(streamInfo.codecContext->codec_id));
464468

465469
// We will only need packets from the active stream, so we tell FFmpeg to
466470
// discard packets from the other streams. Note that av_read_frame() may still
@@ -1149,8 +1153,6 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
11491153
getFFMPEGErrorStringFromErrorCode(status));
11501154

11511155
decodeStats_.numFlushes++;
1152-
avcodec_flush_buffers(streamInfo.codecContext.get());
1153-
11541156
deviceInterface_->flush();
11551157
}
11561158

@@ -1169,24 +1171,16 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11691171
cursorWasJustSet_ = false;
11701172
}
11711173

1172-
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
11731174
UniqueAVFrame avFrame(av_frame_alloc());
11741175
AutoAVPacket autoAVPacket;
11751176
int status = AVSUCCESS;
11761177
bool reachedEOF = false;
11771178

1178-
// TODONVDEC P2: Instead of calling canDecodePacketDirectly() and rely on
1179-
// if/else blocks to dispatch to the interface or to FFmpeg, consider *always*
1180-
// dispatching to the interface. The default implementation of the interface's
1181-
// receiveFrame and sendPacket could just be calling avcodec_receive_frame and
1182-
// avcodec_send_packet. This would make the decoding loop even more generic.
1179+
// The default implementation uses avcodec_receive_frame and
1180+
// avcodec_send_packet, while specialized interfaces can override for
1181+
// hardware-specific optimizations.
11831182
while (true) {
1184-
if (deviceInterface_->canDecodePacketDirectly()) {
1185-
status = deviceInterface_->receiveFrame(avFrame);
1186-
} else {
1187-
status =
1188-
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
1189-
}
1183+
status = deviceInterface_->receiveFrame(avFrame);
11901184

11911185
if (status != AVSUCCESS && status != AVERROR(EAGAIN)) {
11921186
// Non-retriable error
@@ -1222,13 +1216,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
12221216

12231217
if (status == AVERROR_EOF) {
12241218
// End of file reached. We must drain the decoder
1225-
if (deviceInterface_->canDecodePacketDirectly()) {
1226-
status = deviceInterface_->sendEOFPacket();
1227-
} else {
1228-
status = avcodec_send_packet(
1229-
streamInfo.codecContext.get(),
1230-
/*avpkt=*/nullptr);
1231-
}
1219+
status = deviceInterface_->sendEOFPacket();
12321220
TORCH_CHECK(
12331221
status >= AVSUCCESS,
12341222
"Could not flush decoder: ",
@@ -1253,11 +1241,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
12531241

12541242
// We got a valid packet. Send it to the decoder, and we'll receive it in
12551243
// the next iteration.
1256-
if (deviceInterface_->canDecodePacketDirectly()) {
1257-
status = deviceInterface_->sendPacket(packet);
1258-
} else {
1259-
status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
1260-
}
1244+
status = deviceInterface_->sendPacket(packet);
12611245
TORCH_CHECK(
12621246
status >= AVSUCCESS,
12631247
"Could not push packet to decoder: ",

0 commit comments

Comments
 (0)