diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 78fa8d635..0bdd91a23 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -231,7 +231,8 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { void BetaCudaDeviceInterface::initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) { + const UniqueDecodingAVFormatContext& avFormatCtx, + [[maybe_unused]] const SharedAVCodecContext& codecContext) { TORCH_CHECK(avStream != nullptr, "AVStream cannot be null"); timeBase_ = avStream->time_base; frameRateAvgFromFFmpeg_ = avStream->r_frame_rate; diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index 0bf9951d6..fb01415d4 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -40,7 +40,8 @@ class BetaCudaDeviceInterface : public DeviceInterface { void initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) override; + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) override; void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, @@ -48,10 +49,6 @@ class BetaCudaDeviceInterface : public DeviceInterface { std::optional preAllocatedOutputTensor = std::nullopt) override; - bool canDecodePacketDirectly() const override { - return true; - } - int sendPacket(ReferenceAVPacket& packet) override; int sendEOFPacket() override; int receiveFrame(UniqueAVFrame& avFrame) override; diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index e6b96e3e4..0e9b46434 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -48,8 +48,10 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device) void CpuDeviceInterface::initialize( const AVStream* avStream, - [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) { + [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) { TORCH_CHECK(avStream != nullptr, "avStream is null"); + codecContext_ = codecContext; timeBase_ = avStream->time_base; } diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 399b0c6be..9f44c4e8c 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -25,7 +25,8 @@ class CpuDeviceInterface : public DeviceInterface { virtual void initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) override; + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) override; virtual void initializeVideo( const VideoStreamOptions& videoStreamOptions, diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index aea2b2d9a..ba8e495b8 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -117,15 +117,17 @@ CudaDeviceInterface::~CudaDeviceInterface() { void CudaDeviceInterface::initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) { + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) { TORCH_CHECK(avStream != nullptr, "avStream is null"); + codecContext_ = codecContext; timeBase_ = avStream->time_base; // TODO: Ideally, we should keep all interface implementations independent. cpuInterface_ = createDeviceInterface(torch::kCPU); TORCH_CHECK( cpuInterface_ != nullptr, "Failed to create CPU device interface"); - cpuInterface_->initialize(avStream, avFormatCtx); + cpuInterface_->initialize(avStream, avFormatCtx, codecContext); cpuInterface_->initializeVideo( VideoStreamOptions(), {}, diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 1a8f184ec..d240066f4 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -22,7 +22,8 @@ class CudaDeviceInterface : public DeviceInterface { void initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) override; + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) override; void initializeVideo( const VideoStreamOptions& videoStreamOptions, diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index cac29e838..25a36a40f 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -54,7 +54,8 @@ class DeviceInterface { // Initialize the device with parameters generic to all kinds of decoding. virtual void initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) = 0; + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) = 0; // Initialize the device with parameters specific to video decoding. There is // a default empty implementation. @@ -80,52 +81,47 @@ class DeviceInterface { // Extension points for custom decoding paths // ------------------------------------------ - // Override to return true if this device interface can decode packets - // directly. This means that the following two member functions can both - // be called: - // - // 1. sendPacket() - // 2. receiveFrame() - virtual bool canDecodePacketDirectly() const { - return false; - } - - // Moral equivalent of avcodec_send_packet() // Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or // other AVERROR on failure - virtual int sendPacket([[maybe_unused]] ReferenceAVPacket& avPacket) { + // Default implementation uses FFmpeg directly + virtual int sendPacket(ReferenceAVPacket& avPacket) { TORCH_CHECK( - false, - "Send/receive packet decoding not implemented for this device interface"); - return AVERROR(ENOSYS); + codecContext_ != nullptr, + "Codec context not available for default packet sending"); + return avcodec_send_packet(codecContext_.get(), avPacket.get()); } // Send an EOF packet to flush the decoder // Returns AVSUCCESS on success, or other AVERROR on failure + // Default implementation uses FFmpeg directly virtual int sendEOFPacket() { TORCH_CHECK( - false, "Send EOF packet not implemented for this device interface"); - return AVERROR(ENOSYS); + codecContext_ != nullptr, + "Codec context not available for default EOF packet sending"); + return avcodec_send_packet(codecContext_.get(), nullptr); } - // Moral equivalent of avcodec_receive_frame() // Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready, // AVERROR_EOF if end of stream, or other AVERROR on failure - virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) { + // Default implementation uses FFmpeg directly + virtual int receiveFrame(UniqueAVFrame& avFrame) { TORCH_CHECK( - false, - "Send/receive packet decoding not implemented for this device interface"); - return AVERROR(ENOSYS); + codecContext_ != nullptr, + "Codec context not available for default frame receiving"); + return avcodec_receive_frame(codecContext_.get(), avFrame.get()); } // Flush remaining frames from decoder virtual void flush() { - // Default implementation is no-op for standard decoders - // Custom decoders can override this method + TORCH_CHECK( + codecContext_ != nullptr, + "Codec context not available for default flushing"); + avcodec_flush_buffers(codecContext_.get()); } protected: torch::Device device_; + SharedAVCodecContext codecContext_; }; using CreateDeviceInterfaceFn = diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 0570f06cf..97ff082e1 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -149,7 +149,7 @@ int getNumChannels(const UniqueAVFrame& avFrame) { #endif } -int getNumChannels(const UniqueAVCodecContext& avCodecContext) { +int getNumChannels(const SharedAVCodecContext& avCodecContext) { #if LIBAVFILTER_VERSION_MAJOR > 8 || \ (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) return avCodecContext->ch_layout.nb_channels; diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 19cddcc37..337616ddc 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -71,6 +71,14 @@ using UniqueEncodingAVFormatContext = std::unique_ptr< using UniqueAVCodecContext = std::unique_ptr< AVCodecContext, Deleterp>; +using SharedAVCodecContext = std::shared_ptr; + +// create SharedAVCodecContext with custom deleter +inline SharedAVCodecContext makeSharedAVCodecContext(AVCodecContext* ctx) { + return SharedAVCodecContext( + ctx, Deleterp{}); +} + using UniqueAVFrame = std::unique_ptr>; using UniqueAVFilterGraph = std::unique_ptr< @@ -171,7 +179,7 @@ const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec); const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec); int getNumChannels(const UniqueAVFrame& avFrame); -int getNumChannels(const UniqueAVCodecContext& avCodecContext); +int getNumChannels(const SharedAVCodecContext& avCodecContext); void setDefaultChannelLayout( UniqueAVCodecContext& avCodecContext, diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index d06c47922..ba7382c67 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -429,7 +429,6 @@ void SingleStreamDecoder::addStream( TORCH_CHECK( deviceInterface_ != nullptr, "Failed to create device interface. This should never happen, please report."); - deviceInterface_->initialize(streamInfo.stream, formatContext_); // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic @@ -441,7 +440,7 @@ void SingleStreamDecoder::addStream( AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); TORCH_CHECK(codecContext != nullptr); - streamInfo.codecContext.reset(codecContext); + streamInfo.codecContext = makeSharedAVCodecContext(codecContext); int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); @@ -453,14 +452,19 @@ void SingleStreamDecoder::addStream( // Note that we must make sure to register the harware device context // with the codec context before calling avcodec_open2(). Otherwise, decoding // will happen on the CPU and not the hardware device. - deviceInterface_->registerHardwareDeviceWithCodec(codecContext); + deviceInterface_->registerHardwareDeviceWithCodec( + streamInfo.codecContext.get()); retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal)); - codecContext->time_base = streamInfo.stream->time_base; + streamInfo.codecContext->time_base = streamInfo.stream->time_base; + + // Initialize the device interface with the codec context + deviceInterface_->initialize( + streamInfo.stream, formatContext_, streamInfo.codecContext); containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName = - std::string(avcodec_get_name(codecContext->codec_id)); + std::string(avcodec_get_name(streamInfo.codecContext->codec_id)); // We will only need packets from the active stream, so we tell FFmpeg to // discard packets from the other streams. Note that av_read_frame() may still @@ -1149,8 +1153,6 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() { getFFMPEGErrorStringFromErrorCode(status)); decodeStats_.numFlushes++; - avcodec_flush_buffers(streamInfo.codecContext.get()); - deviceInterface_->flush(); } @@ -1169,24 +1171,16 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( cursorWasJustSet_ = false; } - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; UniqueAVFrame avFrame(av_frame_alloc()); AutoAVPacket autoAVPacket; int status = AVSUCCESS; bool reachedEOF = false; - // TODONVDEC P2: Instead of calling canDecodePacketDirectly() and rely on - // if/else blocks to dispatch to the interface or to FFmpeg, consider *always* - // dispatching to the interface. The default implementation of the interface's - // receiveFrame and sendPacket could just be calling avcodec_receive_frame and - // avcodec_send_packet. This would make the decoding loop even more generic. + // The default implementation uses avcodec_receive_frame and + // avcodec_send_packet, while specialized interfaces can override for + // hardware-specific optimizations. while (true) { - if (deviceInterface_->canDecodePacketDirectly()) { - status = deviceInterface_->receiveFrame(avFrame); - } else { - status = - avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get()); - } + status = deviceInterface_->receiveFrame(avFrame); if (status != AVSUCCESS && status != AVERROR(EAGAIN)) { // Non-retriable error @@ -1222,13 +1216,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( if (status == AVERROR_EOF) { // End of file reached. We must drain the decoder - if (deviceInterface_->canDecodePacketDirectly()) { - status = deviceInterface_->sendEOFPacket(); - } else { - status = avcodec_send_packet( - streamInfo.codecContext.get(), - /*avpkt=*/nullptr); - } + status = deviceInterface_->sendEOFPacket(); TORCH_CHECK( status >= AVSUCCESS, "Could not flush decoder: ", @@ -1253,11 +1241,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // We got a valid packet. Send it to the decoder, and we'll receive it in // the next iteration. - if (deviceInterface_->canDecodePacketDirectly()) { - status = deviceInterface_->sendPacket(packet); - } else { - status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get()); - } + status = deviceInterface_->sendPacket(packet); TORCH_CHECK( status >= AVSUCCESS, "Could not push packet to decoder: ", diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 48821ff09..10f820550 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -221,7 +221,7 @@ class SingleStreamDecoder { AVMediaType avMediaType = AVMEDIA_TYPE_UNKNOWN; AVRational timeBase = {}; - UniqueAVCodecContext codecContext; + SharedAVCodecContext codecContext; // The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was // called.