Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/torchcodec/_core/BetaCudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {

void BetaCudaDeviceInterface::initialize(
const AVStream* avStream,
const UniqueDecodingAVFormatContext& avFormatCtx) {
const UniqueDecodingAVFormatContext& avFormatCtx,
[[maybe_unused]] const SharedAVCodecContext& codecContext) {
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
timeBase_ = avStream->time_base;
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;
Expand Down
7 changes: 2 additions & 5 deletions src/torchcodec/_core/BetaCudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,15 @@ class BetaCudaDeviceInterface : public DeviceInterface {

void initialize(
const AVStream* avStream,
const UniqueDecodingAVFormatContext& avFormatCtx) override;
const UniqueDecodingAVFormatContext& avFormatCtx,
const SharedAVCodecContext& codecContext) override;

void convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor =
std::nullopt) override;

bool canDecodePacketDirectly() const override {
return true;
}

int sendPacket(ReferenceAVPacket& packet) override;
int sendEOFPacket() override;
int receiveFrame(UniqueAVFrame& avFrame) override;
Expand Down
4 changes: 3 additions & 1 deletion src/torchcodec/_core/CpuDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)

void CpuDeviceInterface::initialize(
const AVStream* avStream,
[[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) {
[[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx,
const SharedAVCodecContext& codecContext) {
TORCH_CHECK(avStream != nullptr, "avStream is null");
codecContext_ = codecContext;
timeBase_ = avStream->time_base;
}

Expand Down
3 changes: 2 additions & 1 deletion src/torchcodec/_core/CpuDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class CpuDeviceInterface : public DeviceInterface {

virtual void initialize(
const AVStream* avStream,
const UniqueDecodingAVFormatContext& avFormatCtx) override;
const UniqueDecodingAVFormatContext& avFormatCtx,
const SharedAVCodecContext& codecContext) override;

virtual void initializeVideo(
const VideoStreamOptions& videoStreamOptions,
Expand Down
6 changes: 4 additions & 2 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,17 @@ CudaDeviceInterface::~CudaDeviceInterface() {

void CudaDeviceInterface::initialize(
const AVStream* avStream,
const UniqueDecodingAVFormatContext& avFormatCtx) {
const UniqueDecodingAVFormatContext& avFormatCtx,
const SharedAVCodecContext& codecContext) {
TORCH_CHECK(avStream != nullptr, "avStream is null");
codecContext_ = codecContext;
timeBase_ = avStream->time_base;

// TODO: Ideally, we should keep all interface implementations independent.
cpuInterface_ = createDeviceInterface(torch::kCPU);
TORCH_CHECK(
cpuInterface_ != nullptr, "Failed to create CPU device interface");
cpuInterface_->initialize(avStream, avFormatCtx);
cpuInterface_->initialize(avStream, avFormatCtx, codecContext);
cpuInterface_->initializeVideo(
VideoStreamOptions(),
{},
Expand Down
3 changes: 2 additions & 1 deletion src/torchcodec/_core/CudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class CudaDeviceInterface : public DeviceInterface {

void initialize(
const AVStream* avStream,
const UniqueDecodingAVFormatContext& avFormatCtx) override;
const UniqueDecodingAVFormatContext& avFormatCtx,
const SharedAVCodecContext& codecContext) override;

void initializeVideo(
const VideoStreamOptions& videoStreamOptions,
Expand Down
46 changes: 21 additions & 25 deletions src/torchcodec/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class DeviceInterface {
// Initialize the device with parameters generic to all kinds of decoding.
virtual void initialize(
const AVStream* avStream,
const UniqueDecodingAVFormatContext& avFormatCtx) = 0;
const UniqueDecodingAVFormatContext& avFormatCtx,
const SharedAVCodecContext& codecContext) = 0;

// Initialize the device with parameters specific to video decoding. There is
// a default empty implementation.
Expand All @@ -80,52 +81,47 @@ class DeviceInterface {
// Extension points for custom decoding paths
// ------------------------------------------

// Override to return true if this device interface can decode packets
// directly. This means that the following two member functions can both
// be called:
//
// 1. sendPacket()
// 2. receiveFrame()
virtual bool canDecodePacketDirectly() const {
return false;
}

// Moral equivalent of avcodec_send_packet()
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or
// other AVERROR on failure
virtual int sendPacket([[maybe_unused]] ReferenceAVPacket& avPacket) {
// Default implementation uses FFmpeg directly
virtual int sendPacket(ReferenceAVPacket& avPacket) {
TORCH_CHECK(
false,
"Send/receive packet decoding not implemented for this device interface");
return AVERROR(ENOSYS);
codecContext_ != nullptr,
"Codec context not available for default packet sending");
return avcodec_send_packet(codecContext_.get(), avPacket.get());
}

// Send an EOF packet to flush the decoder
// Returns AVSUCCESS on success, or other AVERROR on failure
// Default implementation uses FFmpeg directly
virtual int sendEOFPacket() {
TORCH_CHECK(
false, "Send EOF packet not implemented for this device interface");
return AVERROR(ENOSYS);
codecContext_ != nullptr,
"Codec context not available for default EOF packet sending");
return avcodec_send_packet(codecContext_.get(), nullptr);
}

// Moral equivalent of avcodec_receive_frame()
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready,
// AVERROR_EOF if end of stream, or other AVERROR on failure
virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) {
// Default implementation uses FFmpeg directly
virtual int receiveFrame(UniqueAVFrame& avFrame) {
TORCH_CHECK(
false,
"Send/receive packet decoding not implemented for this device interface");
return AVERROR(ENOSYS);
codecContext_ != nullptr,
"Codec context not available for default frame receiving");
return avcodec_receive_frame(codecContext_.get(), avFrame.get());
}

// Flush remaining frames from decoder
virtual void flush() {
// Default implementation is no-op for standard decoders
// Custom decoders can override this method
TORCH_CHECK(
codecContext_ != nullptr,
"Codec context not available for default flushing");
avcodec_flush_buffers(codecContext_.get());
}

protected:
torch::Device device_;
SharedAVCodecContext codecContext_;
};

using CreateDeviceInterfaceFn =
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/FFMPEGCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ int getNumChannels(const UniqueAVFrame& avFrame) {
#endif
}

int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
int getNumChannels(const SharedAVCodecContext& avCodecContext) {
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
return avCodecContext->ch_layout.nb_channels;
Expand Down
10 changes: 9 additions & 1 deletion src/torchcodec/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ using UniqueEncodingAVFormatContext = std::unique_ptr<
using UniqueAVCodecContext = std::unique_ptr<
AVCodecContext,
Deleterp<AVCodecContext, void, avcodec_free_context>>;
using SharedAVCodecContext = std::shared_ptr<AVCodecContext>;

// Creates a SharedAVCodecContext that takes ownership of `ctx`.
// A std::shared_ptr built from a bare pointer would release it with `delete`,
// so we explicitly pass the same Deleterp<..., avcodec_free_context> deleter
// type that UniqueAVCodecContext uses.
// The caller must not free `ctx` itself after handing it to this function.
inline SharedAVCodecContext makeSharedAVCodecContext(AVCodecContext* ctx) {
return SharedAVCodecContext(
ctx, Deleterp<AVCodecContext, void, avcodec_free_context>{});
}

using UniqueAVFrame =
std::unique_ptr<AVFrame, Deleterp<AVFrame, void, av_frame_free>>;
using UniqueAVFilterGraph = std::unique_ptr<
Expand Down Expand Up @@ -171,7 +179,7 @@ const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec);
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);

int getNumChannels(const UniqueAVFrame& avFrame);
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
int getNumChannels(const SharedAVCodecContext& avCodecContext);

void setDefaultChannelLayout(
UniqueAVCodecContext& avCodecContext,
Expand Down
46 changes: 15 additions & 31 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,6 @@ void SingleStreamDecoder::addStream(
TORCH_CHECK(
deviceInterface_ != nullptr,
"Failed to create device interface. This should never happen, please report.");
deviceInterface_->initialize(streamInfo.stream, formatContext_);

// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
// addStream() which is supposed to be generic
Expand All @@ -441,7 +440,7 @@ void SingleStreamDecoder::addStream(

AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
TORCH_CHECK(codecContext != nullptr);
streamInfo.codecContext.reset(codecContext);
streamInfo.codecContext = makeSharedAVCodecContext(codecContext);

int retVal = avcodec_parameters_to_context(
streamInfo.codecContext.get(), streamInfo.stream->codecpar);
Expand All @@ -453,14 +452,19 @@ void SingleStreamDecoder::addStream(
// Note that we must make sure to register the hardware device context
// with the codec context before calling avcodec_open2(). Otherwise, decoding
// will happen on the CPU and not the hardware device.
deviceInterface_->registerHardwareDeviceWithCodec(codecContext);
deviceInterface_->registerHardwareDeviceWithCodec(
streamInfo.codecContext.get());
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));

codecContext->time_base = streamInfo.stream->time_base;
streamInfo.codecContext->time_base = streamInfo.stream->time_base;

// Initialize the device interface with the codec context
deviceInterface_->initialize(
streamInfo.stream, formatContext_, streamInfo.codecContext);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pointing out, for myself and for other reviewers, that the call to initialize() was moved. This makes sense: we now pass the codecContext to initialize(), and the codecContext is only fully initialized after the call to avcodec_open2.

We should just make sure there was no other call to an interface method that depends on the interface being initialized. I think this is OK: the only call was deviceInterface_->findCodec(), and it's pretty much a stateless check.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think we're good here. My rationale with the device was to initialize it as soon as we had everything we needed to do it.


containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
std::string(avcodec_get_name(codecContext->codec_id));
std::string(avcodec_get_name(streamInfo.codecContext->codec_id));

// We will only need packets from the active stream, so we tell FFmpeg to
// discard packets from the other streams. Note that av_read_frame() may still
Expand Down Expand Up @@ -1149,8 +1153,6 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
getFFMPEGErrorStringFromErrorCode(status));

decodeStats_.numFlushes++;
avcodec_flush_buffers(streamInfo.codecContext.get());

deviceInterface_->flush();
}

Expand All @@ -1169,24 +1171,16 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
cursorWasJustSet_ = false;
}

StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
UniqueAVFrame avFrame(av_frame_alloc());
AutoAVPacket autoAVPacket;
int status = AVSUCCESS;
bool reachedEOF = false;

// TODONVDEC P2: Instead of calling canDecodePacketDirectly() and rely on
// if/else blocks to dispatch to the interface or to FFmpeg, consider *always*
// dispatching to the interface. The default implementation of the interface's
// receiveFrame and sendPacket could just be calling avcodec_receive_frame and
// avcodec_send_packet. This would make the decoding loop even more generic.
// The default implementation uses avcodec_receive_frame and
// avcodec_send_packet, while specialized interfaces can override for
// hardware-specific optimizations.
while (true) {
if (deviceInterface_->canDecodePacketDirectly()) {
status = deviceInterface_->receiveFrame(avFrame);
} else {
status =
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
}
status = deviceInterface_->receiveFrame(avFrame);

if (status != AVSUCCESS && status != AVERROR(EAGAIN)) {
// Non-retriable error
Expand Down Expand Up @@ -1222,13 +1216,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(

if (status == AVERROR_EOF) {
// End of file reached. We must drain the decoder
if (deviceInterface_->canDecodePacketDirectly()) {
status = deviceInterface_->sendEOFPacket();
} else {
status = avcodec_send_packet(
streamInfo.codecContext.get(),
/*avpkt=*/nullptr);
}
status = deviceInterface_->sendEOFPacket();
TORCH_CHECK(
status >= AVSUCCESS,
"Could not flush decoder: ",
Expand All @@ -1253,11 +1241,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(

// We got a valid packet. Send it to the decoder, and we'll receive it in
// the next iteration.
if (deviceInterface_->canDecodePacketDirectly()) {
status = deviceInterface_->sendPacket(packet);
} else {
status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
}
status = deviceInterface_->sendPacket(packet);
TORCH_CHECK(
status >= AVSUCCESS,
"Could not push packet to decoder: ",
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/SingleStreamDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class SingleStreamDecoder {
AVMediaType avMediaType = AVMEDIA_TYPE_UNKNOWN;

AVRational timeBase = {};
UniqueAVCodecContext codecContext;
SharedAVCodecContext codecContext;

// The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was
// called.
Expand Down
Loading