Skip to content

Commit 169484d

Browse files
mollyxuMolly Xu
andauthored
Generalize DeviceInterface to include audio decoding (#1010)
Co-authored-by: Molly Xu <[email protected]>
1 parent f991110 commit 169484d

File tree

7 files changed

+187
-148
lines changed

7 files changed

+187
-148
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ class BetaCudaDeviceInterface : public DeviceInterface {
4646
void convertAVFrameToFrameOutput(
4747
UniqueAVFrame& avFrame,
4848
FrameOutput& frameOutput,
49-
std::optional<torch::Tensor> preAllocatedOutputTensor =
50-
std::nullopt) override;
49+
std::optional<torch::Tensor> preAllocatedOutputTensor) override;
5150

5251
int sendPacket(ReferenceAVPacket& packet) override;
5352
int sendEOFPacket() override;

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 140 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ void CpuDeviceInterface::initializeVideo(
3535
const VideoStreamOptions& videoStreamOptions,
3636
const std::vector<std::unique_ptr<Transform>>& transforms,
3737
const std::optional<FrameDims>& resizedOutputDims) {
38+
avMediaType_ = AVMEDIA_TYPE_VIDEO;
3839
videoStreamOptions_ = videoStreamOptions;
3940
resizedOutputDims_ = resizedOutputDims;
4041

@@ -86,6 +87,13 @@ void CpuDeviceInterface::initializeVideo(
8687
initialized_ = true;
8788
}
8889

90+
void CpuDeviceInterface::initializeAudio(
91+
const AudioStreamOptions& audioStreamOptions) {
92+
avMediaType_ = AVMEDIA_TYPE_AUDIO;
93+
audioStreamOptions_ = audioStreamOptions;
94+
initialized_ = true;
95+
}
96+
8997
ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
9098
const FrameDims& outputDims) const {
9199
// swscale requires widths to be multiples of 32:
@@ -114,6 +122,20 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
114122
}
115123
}
116124

125+
void CpuDeviceInterface::convertAVFrameToFrameOutput(
126+
UniqueAVFrame& avFrame,
127+
FrameOutput& frameOutput,
128+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
129+
TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
130+
131+
if (avMediaType_ == AVMEDIA_TYPE_AUDIO) {
132+
convertAudioAVFrameToFrameOutput(avFrame, frameOutput);
133+
} else {
134+
convertVideoAVFrameToFrameOutput(
135+
avFrame, frameOutput, preAllocatedOutputTensor);
136+
}
137+
}
138+
117139
// Note [preAllocatedOutputTensor with swscale and filtergraph]:
118140
// Callers may pass a pre-allocated tensor, where the output.data tensor will
119141
// be stored. This parameter is honored in any case, but it only leads to a
@@ -123,12 +145,10 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
123145
// TODO: Figure out whether that's possible!
124146
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
125147
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
126-
void CpuDeviceInterface::convertAVFrameToFrameOutput(
148+
void CpuDeviceInterface::convertVideoAVFrameToFrameOutput(
127149
UniqueAVFrame& avFrame,
128150
FrameOutput& frameOutput,
129151
std::optional<torch::Tensor> preAllocatedOutputTensor) {
130-
TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
131-
132152
// Note that we ignore the dimensions from the metadata; we don't even bother
133153
// storing them. The resized dimensions take priority. If we don't have any,
134154
// then we use the dimensions from the actual decoded frame. We use the actual
@@ -278,6 +298,123 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
278298
return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
279299
}
280300

301+
void CpuDeviceInterface::convertAudioAVFrameToFrameOutput(
302+
UniqueAVFrame& srcAVFrame,
303+
FrameOutput& frameOutput) {
304+
AVSampleFormat srcSampleFormat =
305+
static_cast<AVSampleFormat>(srcAVFrame->format);
306+
AVSampleFormat outSampleFormat = AV_SAMPLE_FMT_FLTP;
307+
308+
int srcSampleRate = srcAVFrame->sample_rate;
309+
int outSampleRate = audioStreamOptions_.sampleRate.value_or(srcSampleRate);
310+
311+
int srcNumChannels = getNumChannels(codecContext_);
312+
TORCH_CHECK(
313+
srcNumChannels == getNumChannels(srcAVFrame),
314+
"The frame has ",
315+
getNumChannels(srcAVFrame),
316+
" channels, expected ",
317+
srcNumChannels,
318+
". If you are hitting this, it may be because you are using "
319+
"a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
320+
"valid scenarios. Try to upgrade FFmpeg?");
321+
int outNumChannels = audioStreamOptions_.numChannels.value_or(srcNumChannels);
322+
323+
bool mustConvert =
324+
(srcSampleFormat != outSampleFormat || srcSampleRate != outSampleRate ||
325+
srcNumChannels != outNumChannels);
326+
327+
UniqueAVFrame convertedAVFrame;
328+
if (mustConvert) {
329+
if (!swrContext_) {
330+
swrContext_.reset(createSwrContext(
331+
srcSampleFormat,
332+
outSampleFormat,
333+
srcSampleRate,
334+
outSampleRate,
335+
srcAVFrame,
336+
outNumChannels));
337+
}
338+
339+
convertedAVFrame = convertAudioAVFrameSamples(
340+
swrContext_,
341+
srcAVFrame,
342+
outSampleFormat,
343+
outSampleRate,
344+
outNumChannels);
345+
}
346+
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
347+
348+
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
349+
TORCH_CHECK(
350+
format == outSampleFormat,
351+
"Something went wrong, the frame didn't get converted to the desired format. ",
352+
"Desired format = ",
353+
av_get_sample_fmt_name(outSampleFormat),
354+
"source format = ",
355+
av_get_sample_fmt_name(format));
356+
357+
int numChannels = getNumChannels(avFrame);
358+
TORCH_CHECK(
359+
numChannels == outNumChannels,
360+
"Something went wrong, the frame didn't get converted to the desired ",
361+
"number of channels = ",
362+
outNumChannels,
363+
". Got ",
364+
numChannels,
365+
" instead.");
366+
367+
auto numSamples = avFrame->nb_samples;
368+
369+
frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32);
370+
371+
if (numSamples > 0) {
372+
uint8_t* outputChannelData =
373+
static_cast<uint8_t*>(frameOutput.data.data_ptr());
374+
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
375+
for (auto channel = 0; channel < numChannels;
376+
++channel, outputChannelData += numBytesPerChannel) {
377+
std::memcpy(
378+
outputChannelData,
379+
avFrame->extended_data[channel],
380+
numBytesPerChannel);
381+
}
382+
}
383+
}
384+
385+
std::optional<torch::Tensor> CpuDeviceInterface::maybeFlushAudioBuffers() {
386+
// When sample rate conversion is involved, swresample buffers some of the
387+
// samples in-between calls to swr_convert (see the libswresample docs).
388+
// That's because the last few samples in a given frame require future
389+
// samples from the next frame to be properly converted. This function
390+
// flushes out the samples that are stored in swresample's buffers.
391+
if (!swrContext_) {
392+
return std::nullopt;
393+
}
394+
auto numRemainingSamples = // this is an upper bound
395+
swr_get_out_samples(swrContext_.get(), 0);
396+
397+
if (numRemainingSamples == 0) {
398+
return std::nullopt;
399+
}
400+
401+
int numChannels =
402+
audioStreamOptions_.numChannels.value_or(getNumChannels(codecContext_));
403+
torch::Tensor lastSamples =
404+
torch::empty({numChannels, numRemainingSamples}, torch::kFloat32);
405+
406+
std::vector<uint8_t*> outputBuffers(numChannels);
407+
for (auto i = 0; i < numChannels; i++) {
408+
outputBuffers[i] = static_cast<uint8_t*>(lastSamples[i].data_ptr());
409+
}
410+
411+
auto actualNumRemainingSamples = swr_convert(
412+
swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0);
413+
414+
return lastSamples.narrow(
415+
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
416+
}
417+
281418
std::string CpuDeviceInterface::getDetails() {
282419
return std::string("CPU Device Interface.");
283420
}

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,28 @@ class CpuDeviceInterface : public DeviceInterface {
3333
const std::vector<std::unique_ptr<Transform>>& transforms,
3434
const std::optional<FrameDims>& resizedOutputDims) override;
3535

36+
virtual void initializeAudio(
37+
const AudioStreamOptions& audioStreamOptions) override;
38+
39+
virtual std::optional<torch::Tensor> maybeFlushAudioBuffers() override;
40+
3641
void convertAVFrameToFrameOutput(
3742
UniqueAVFrame& avFrame,
3843
FrameOutput& frameOutput,
39-
std::optional<torch::Tensor> preAllocatedOutputTensor =
40-
std::nullopt) override;
44+
std::optional<torch::Tensor> preAllocatedOutputTensor) override;
4145

4246
std::string getDetails() override;
4347

4448
private:
49+
void convertAudioAVFrameToFrameOutput(
50+
UniqueAVFrame& srcAVFrame,
51+
FrameOutput& frameOutput);
52+
53+
void convertVideoAVFrameToFrameOutput(
54+
UniqueAVFrame& avFrame,
55+
FrameOutput& frameOutput,
56+
std::optional<torch::Tensor> preAllocatedOutputTensor);
57+
4558
int convertAVFrameToTensorUsingSwScale(
4659
const UniqueAVFrame& avFrame,
4760
torch::Tensor& outputTensor,
@@ -108,6 +121,10 @@ class CpuDeviceInterface : public DeviceInterface {
108121
bool userRequestedSwScale_;
109122

110123
bool initialized_ = false;
124+
125+
// Audio-specific members
126+
AudioStreamOptions audioStreamOptions_;
127+
UniqueSwrContext swrContext_;
111128
};
112129

113130
} // namespace facebook::torchcodec

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ class CudaDeviceInterface : public DeviceInterface {
3737
void convertAVFrameToFrameOutput(
3838
UniqueAVFrame& avFrame,
3939
FrameOutput& frameOutput,
40-
std::optional<torch::Tensor> preAllocatedOutputTensor =
41-
std::nullopt) override;
40+
std::optional<torch::Tensor> preAllocatedOutputTensor) override;
4241

4342
std::string getDetails() override;
4443

src/torchcodec/_core/DeviceInterface.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,21 @@ class DeviceInterface {
6565
transforms,
6666
[[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {}
6767

68+
// Initialize the device with parameters specific to audio decoding. There is
69+
// a default empty implementation.
70+
virtual void initializeAudio(
71+
[[maybe_unused]] const AudioStreamOptions& audioStreamOptions) {}
72+
73+
// Flush any remaining samples from the audio resampler buffer.
74+
// When sample rate conversion is involved, some samples may be buffered
75+
// between frames for proper interpolation. This function flushes those
76+
// buffered samples.
77+
// Returns an optional tensor containing the flushed samples, or std::nullopt
78+
// if there are no buffered samples or audio is not supported.
79+
virtual std::optional<torch::Tensor> maybeFlushAudioBuffers() {
80+
return std::nullopt;
81+
}
82+
6883
// In order for decoding to actually happen on an FFmpeg managed hardware
6984
// device, we need to register the DeviceInterface managed
7085
// AVHardwareDeviceContext with the AVCodecContext. We don't need to do this
@@ -126,6 +141,7 @@ class DeviceInterface {
126141
protected:
127142
torch::Device device_;
128143
SharedAVCodecContext codecContext_;
144+
AVMediaType avMediaType_;
129145
};
130146

131147
using CreateDeviceInterfaceFn =

0 commit comments

Comments
 (0)