@@ -35,6 +35,7 @@ void CpuDeviceInterface::initializeVideo(
3535 const VideoStreamOptions& videoStreamOptions,
3636 const std::vector<std::unique_ptr<Transform>>& transforms,
3737 const std::optional<FrameDims>& resizedOutputDims) {
38+ avMediaType_ = AVMEDIA_TYPE_VIDEO;
3839 videoStreamOptions_ = videoStreamOptions;
3940 resizedOutputDims_ = resizedOutputDims;
4041
@@ -86,6 +87,13 @@ void CpuDeviceInterface::initializeVideo(
8687 initialized_ = true ;
8788}
8889
90+ void CpuDeviceInterface::initializeAudio (
91+ const AudioStreamOptions& audioStreamOptions) {
92+ avMediaType_ = AVMEDIA_TYPE_AUDIO;
93+ audioStreamOptions_ = audioStreamOptions;
94+ initialized_ = true ;
95+ }
96+
8997ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary (
9098 const FrameDims& outputDims) const {
9199 // swscale requires widths to be multiples of 32:
@@ -114,6 +122,20 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
114122 }
115123}
116124
125+ void CpuDeviceInterface::convertAVFrameToFrameOutput (
126+ UniqueAVFrame& avFrame,
127+ FrameOutput& frameOutput,
128+ std::optional<torch::Tensor> preAllocatedOutputTensor) {
129+ TORCH_CHECK (initialized_, " CpuDeviceInterface was not initialized." );
130+
131+ if (avMediaType_ == AVMEDIA_TYPE_AUDIO) {
132+ convertAudioAVFrameToFrameOutput (avFrame, frameOutput);
133+ } else {
134+ convertVideoAVFrameToFrameOutput (
135+ avFrame, frameOutput, preAllocatedOutputTensor);
136+ }
137+ }
138+
117139// Note [preAllocatedOutputTensor with swscale and filtergraph]:
118140// Callers may pass a pre-allocated tensor, where the output.data tensor will
119141// be stored. This parameter is honored in any case, but it only leads to a
@@ -123,12 +145,10 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
123145// TODO: Figure out whether that's possible!
124146// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
125147// `dimension_order` parameter. It's up to callers to re-shape it if needed.
126- void CpuDeviceInterface::convertAVFrameToFrameOutput (
148+ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput (
127149 UniqueAVFrame& avFrame,
128150 FrameOutput& frameOutput,
129151 std::optional<torch::Tensor> preAllocatedOutputTensor) {
130- TORCH_CHECK (initialized_, " CpuDeviceInterface was not initialized." );
131-
132152 // Note that we ignore the dimensions from the metadata; we don't even bother
133153 // storing them. The resized dimensions take priority. If we don't have any,
134154 // then we use the dimensions from the actual decoded frame. We use the actual
@@ -278,6 +298,123 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
278298 return rgbAVFrameToTensor (filterGraph_->convert (avFrame));
279299}
280300
301+ void CpuDeviceInterface::convertAudioAVFrameToFrameOutput (
302+ UniqueAVFrame& srcAVFrame,
303+ FrameOutput& frameOutput) {
304+ AVSampleFormat srcSampleFormat =
305+ static_cast <AVSampleFormat>(srcAVFrame->format );
306+ AVSampleFormat outSampleFormat = AV_SAMPLE_FMT_FLTP;
307+
308+ int srcSampleRate = srcAVFrame->sample_rate ;
309+ int outSampleRate = audioStreamOptions_.sampleRate .value_or (srcSampleRate);
310+
311+ int srcNumChannels = getNumChannels (codecContext_);
312+ TORCH_CHECK (
313+ srcNumChannels == getNumChannels (srcAVFrame),
314+ " The frame has " ,
315+ getNumChannels (srcAVFrame),
316+ " channels, expected " ,
317+ srcNumChannels,
318+ " . If you are hitting this, it may be because you are using "
319+ " a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
320+ " valid scenarios. Try to upgrade FFmpeg?" );
321+ int outNumChannels = audioStreamOptions_.numChannels .value_or (srcNumChannels);
322+
323+ bool mustConvert =
324+ (srcSampleFormat != outSampleFormat || srcSampleRate != outSampleRate ||
325+ srcNumChannels != outNumChannels);
326+
327+ UniqueAVFrame convertedAVFrame;
328+ if (mustConvert) {
329+ if (!swrContext_) {
330+ swrContext_.reset (createSwrContext (
331+ srcSampleFormat,
332+ outSampleFormat,
333+ srcSampleRate,
334+ outSampleRate,
335+ srcAVFrame,
336+ outNumChannels));
337+ }
338+
339+ convertedAVFrame = convertAudioAVFrameSamples (
340+ swrContext_,
341+ srcAVFrame,
342+ outSampleFormat,
343+ outSampleRate,
344+ outNumChannels);
345+ }
346+ const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
347+
348+ AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
349+ TORCH_CHECK (
350+ format == outSampleFormat,
351+ " Something went wrong, the frame didn't get converted to the desired format. " ,
352+ " Desired format = " ,
353+ av_get_sample_fmt_name (outSampleFormat),
354+ " source format = " ,
355+ av_get_sample_fmt_name (format));
356+
357+ int numChannels = getNumChannels (avFrame);
358+ TORCH_CHECK (
359+ numChannels == outNumChannels,
360+ " Something went wrong, the frame didn't get converted to the desired " ,
361+ " number of channels = " ,
362+ outNumChannels,
363+ " . Got " ,
364+ numChannels,
365+ " instead." );
366+
367+ auto numSamples = avFrame->nb_samples ;
368+
369+ frameOutput.data = torch::empty ({numChannels, numSamples}, torch::kFloat32 );
370+
371+ if (numSamples > 0 ) {
372+ uint8_t * outputChannelData =
373+ static_cast <uint8_t *>(frameOutput.data .data_ptr ());
374+ auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
375+ for (auto channel = 0 ; channel < numChannels;
376+ ++channel, outputChannelData += numBytesPerChannel) {
377+ std::memcpy (
378+ outputChannelData,
379+ avFrame->extended_data [channel],
380+ numBytesPerChannel);
381+ }
382+ }
383+ }
384+
385+ std::optional<torch::Tensor> CpuDeviceInterface::maybeFlushAudioBuffers () {
386+ // When sample rate conversion is involved, swresample buffers some of the
387+ // samples in-between calls to swr_convert (see the libswresample docs).
388+ // That's because the last few samples in a given frame require future
389+ // samples from the next frame to be properly converted. This function
390+ // flushes out the samples that are stored in swresample's buffers.
391+ if (!swrContext_) {
392+ return std::nullopt ;
393+ }
394+ auto numRemainingSamples = // this is an upper bound
395+ swr_get_out_samples (swrContext_.get (), 0 );
396+
397+ if (numRemainingSamples == 0 ) {
398+ return std::nullopt ;
399+ }
400+
401+ int numChannels =
402+ audioStreamOptions_.numChannels .value_or (getNumChannels (codecContext_));
403+ torch::Tensor lastSamples =
404+ torch::empty ({numChannels, numRemainingSamples}, torch::kFloat32 );
405+
406+ std::vector<uint8_t *> outputBuffers (numChannels);
407+ for (auto i = 0 ; i < numChannels; i++) {
408+ outputBuffers[i] = static_cast <uint8_t *>(lastSamples[i].data_ptr ());
409+ }
410+
411+ auto actualNumRemainingSamples = swr_convert (
412+ swrContext_.get (), outputBuffers.data (), numRemainingSamples, nullptr , 0 );
413+
414+ return lastSamples.narrow (
415+ /* dim=*/ 1 , /* start=*/ 0 , /* length=*/ actualNumRemainingSamples);
416+ }
417+
281418std::string CpuDeviceInterface::getDetails () {
282419 return std::string (" CPU Device Interface." );
283420}
0 commit comments