#include <unistd.h>

#include <chrono>
#include <sstream>

#include <level_zero/ze_api.h>
#include <va/va_drmcommon.h>

#include <ATen/DLConvertor.h>
#include <c10/xpu/XPUStream.h>

#include "src/torchcodec/_core/Cache.h"
#include "src/torchcodec/_core/FFMPEGCommon.h"
#include "src/torchcodec/_core/XpuDeviceInterface.h"

extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
#include <libavutil/hwcontext_vaapi.h>
#include <libavutil/pixdesc.h>
}

namespace facebook::torchcodec {
namespace {

static bool g_xpu = registerDeviceInterface(
    torch::kXPU,
    [](const torch::Device& device) { return new XpuDeviceInterface(device); });

const int MAX_XPU_GPUS = 128;
// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
// Set to a positive number to have a cache of that size.
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
PerGpuCache<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>
    g_cached_hw_device_ctxs(MAX_XPU_GPUS, MAX_CONTEXTS_PER_GPU_IN_CACHE);

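// Returns a VAAPI device context for the given torch device, reusing a
// cached context when one is available. The DRM render node is derived from
// the device's PCI address when SYCL exposes it; otherwise we fall back to
// /dev/dri/renderD128.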
UniqueAVBufferRef getVaapiContext(const torch::Device& device) {
  enum AVHWDeviceType type = av_hwdevice_find_type_by_name("vaapi");
  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find vaapi device");
  torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);

  UniqueAVBufferRef hw_device_ctx = g_cached_hw_device_ctxs.get(device);
  if (hw_device_ctx) {
    return hw_device_ctx;
  }

  std::string renderD = "/dev/dri/renderD128";

  sycl::device syclDevice = c10::xpu::get_raw_device(nonNegativeDeviceIndex);
  if (syclDevice.has(sycl::aspect::ext_intel_pci_address)) {
    auto BDF =
        syclDevice.get_info<sycl::ext::intel::info::device::pci_address>();
    renderD = "/dev/dri/by-path/pci-" + BDF + "-render";
  }

  AVBufferRef* ctx = nullptr;
  int err = av_hwdevice_ctx_create(&ctx, type, renderD.c_str(), nullptr, 0);
  TORCH_CHECK(
      err >= 0,
      "Failed to create specified HW device: ",
      getFFMPEGErrorStringFromErrorCode(err));
  return UniqueAVBufferRef(ctx);
}

} // namespace

XpuDeviceInterface::XpuDeviceInterface(const torch::Device& device)
    : DeviceInterface(device) {
  TORCH_CHECK(g_xpu, "XpuDeviceInterface was not registered!");
  TORCH_CHECK(
      device_.type() == torch::kXPU, "Unsupported device: ", device_.str());
}

XpuDeviceInterface::~XpuDeviceInterface() {
  if (ctx_) {
    g_cached_hw_device_ctxs.addIfCacheHasCapacity(device_, std::move(ctx_));
  }
}

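// Extracts the VADisplay that FFmpeg associated with the frame's hardware
// frames context.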
VADisplay getVaDisplayFromAV(AVFrame* avFrame) {
  AVHWFramesContext* hwfc = (AVHWFramesContext*)avFrame->hw_frames_ctx->data;
  AVHWDeviceContext* hwdc = hwfc->device_ctx;
  AVVAAPIDeviceContext* vactx = (AVVAAPIDeviceContext*)hwdc->hwctx;
  return vactx->display;
}

void XpuDeviceInterface::initializeContext(AVCodecContext* codecContext) {
  TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");

  // It is important for pytorch itself to create the xpu context. If ffmpeg
  // creates the context it may not be compatible with pytorch.
  // This is a dummy tensor to initialize the xpu context.
  torch::Tensor dummyTensorForXpuInitialization = torch::empty(
      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
  ctx_ = getVaapiContext(device_);
  codecContext->hw_device_ctx = av_buffer_ref(ctx_.get());
}

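// State that must outlive the tensor returned by AVFrameToTensor: the
// AVFrame reference keeping the VA surface alive, the Level Zero context
// that owns the imported allocation, and the shape array that the DLPack
// metadata points into.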
struct xpuManagerCtx {
  UniqueAVFrame avFrame;
  ze_context_handle_t zeCtx = nullptr;
  // Shape storage for the DLPack tensor; dl_tensor.shape points here so the
  // pointer stays valid for the lifetime of the tensor, not just the scope
  // of AVFrameToTensor.
  int64_t shape[3] = {0, 0, 0};
};

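// DLPack deleter: frees the imported Level Zero allocation and, by
// destroying the manager context, drops the AVFrame reference.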
void deleter(DLManagedTensor* self) {
  std::unique_ptr<DLManagedTensor> tensor(self);
  std::unique_ptr<xpuManagerCtx> context((xpuManagerCtx*)self->manager_ctx);
  zeMemFree(context->zeCtx, self->dl_tensor.data);
}

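// Wraps a VAAPI-decoded frame as a torch XPU tensor without copying pixel
// data: the VA surface is exported as a DMA-BUF, imported into Level Zero as
// a device USM pointer, and handed to PyTorch via DLPack.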
torch::Tensor AVFrameToTensor(
    const torch::Device& device,
    const UniqueAVFrame& frame) {
  TORCH_CHECK_EQ(frame->format, AV_PIX_FMT_VAAPI);

  VADRMPRIMESurfaceDescriptor desc{};

  VAStatus sts = vaExportSurfaceHandle(
      getVaDisplayFromAV(frame.get()),
      (VASurfaceID)(uintptr_t)frame->data[3],
      VA_SURFACE_ATTRIB_MEM_TYPE_DRM_PRIME_2,
      VA_EXPORT_SURFACE_READ_ONLY,
      &desc);
  TORCH_CHECK(
      sts == VA_STATUS_SUCCESS,
      "vaExportSurfaceHandle failed: ",
      vaErrorStr(sts));

  TORCH_CHECK(desc.num_objects == 1, "Expected 1 fd, got ", desc.num_objects);
  TORCH_CHECK(desc.num_layers == 1, "Expected 1 layer, got ", desc.num_layers);
  TORCH_CHECK(
      desc.layers[0].num_planes == 1,
      "Expected 1 plane, got ",
      desc.layers[0].num_planes);

  std::unique_ptr<xpuManagerCtx> context = std::make_unique<xpuManagerCtx>();
  ze_device_handle_t ze_device{};
  sycl::queue queue = c10::xpu::getCurrentXPUStream(device.index());

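  // Use a host_task to extract the native Level Zero context and device from
  // the SYCL queue; both are needed for the DMA-BUF import below. The wait()
  // guarantees the handles are populated before we use them.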
  queue
      .submit([&](sycl::handler& cgh) {
        cgh.host_task([&](const sycl::interop_handle& ih) {
          context->zeCtx =
              ih.get_native_context<sycl::backend::ext_oneapi_level_zero>();
          ze_device =
              ih.get_native_device<sycl::backend::ext_oneapi_level_zero>();
        });
      })
      .wait();

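  // Import the exported DMA-BUF fd into Level Zero. Chaining the import
  // descriptor through pNext makes zeMemAllocDevice map the existing buffer
  // instead of allocating new device memory.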
  ze_external_memory_import_fd_t import_fd_desc{};
  import_fd_desc.stype = ZE_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMPORT_FD;
  import_fd_desc.flags = ZE_EXTERNAL_MEMORY_TYPE_FLAG_DMA_BUF;
  import_fd_desc.fd = desc.objects[0].fd;

  ze_device_mem_alloc_desc_t alloc_desc{};
  alloc_desc.pNext = &import_fd_desc;
  void* usm_ptr = nullptr;

  ze_result_t res = zeMemAllocDevice(
      context->zeCtx,
      &alloc_desc,
      desc.objects[0].size,
      0,
      ze_device,
      &usm_ptr);
  TORCH_CHECK(
      res == ZE_RESULT_SUCCESS,
      "Failed to import fd=",
      desc.objects[0].fd,
      ", ze_result=",
      res);

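  // Level Zero holds its own reference to the underlying buffer now, so the
  // exported fd can be closed.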
  close(desc.objects[0].fd);

  std::unique_ptr<DLManagedTensor> dl_dst = std::make_unique<DLManagedTensor>();
  // Store the shape in the manager context rather than on the stack so that
  // dl_tensor.shape remains valid after this function returns.
  context->shape[0] = desc.height;
  context->shape[1] = desc.width;
  context->shape[2] = 4;

  context->avFrame.reset(av_frame_alloc());
  TORCH_CHECK(context->avFrame.get(), "Failed to allocate AVFrame");

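  // Hold an extra reference to the source frame so the VA surface stays
  // alive for as long as the imported memory is in use.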
  int status = av_frame_ref(context->avFrame.get(), frame.get());
  TORCH_CHECK(
      status >= 0,
      "Failed to reference AVFrame: ",
      getFFMPEGErrorStringFromErrorCode(status));

  dl_dst->dl_tensor.data = usm_ptr;
  dl_dst->dl_tensor.device.device_type = kDLOneAPI;
  dl_dst->dl_tensor.device.device_id = device.index();
  dl_dst->dl_tensor.ndim = 3;
  dl_dst->dl_tensor.dtype.code = kDLUInt;
  dl_dst->dl_tensor.dtype.bits = 8;
  dl_dst->dl_tensor.dtype.lanes = 1;
  dl_dst->dl_tensor.shape = context->shape;
  // strides == nullptr tells DLPack the tensor is compact; this assumes the
  // exported surface pitch is exactly width * 4 bytes.
  dl_dst->dl_tensor.strides = nullptr;
  dl_dst->dl_tensor.byte_offset = desc.layers[0].offset[0];
  dl_dst->deleter = deleter;
  dl_dst->manager_ctx = context.release();

  auto dst = at::fromDLPack(dl_dst.release());

  return dst;
}

VADisplay getVaDisplayFromAV(UniqueAVFrame& avFrame) {
  return getVaDisplayFromAV(avFrame.get());
}

void XpuDeviceInterface::convertAVFrameToFrameOutput(
    const VideoStreamOptions& videoStreamOptions,
    const AVRational& timeBase,
    UniqueAVFrame& avFrame,
    FrameOutput& frameOutput,
    std::optional<torch::Tensor> preAllocatedOutputTensor) {
  // TODO: consider copying the CPU-frame handling from the CUDA interface
  // TODO: consider copying the NV12 format check from the CUDA interface
  const char* actualFormatName =
      av_get_pix_fmt_name(static_cast<AVPixelFormat>(avFrame->format));
  TORCH_CHECK(
      avFrame->format == AV_PIX_FMT_VAAPI,
      "Expected format to be AV_PIX_FMT_VAAPI, got ",
      actualFormatName ? actualFormatName : "unknown");
  auto frameDims =
      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
  int height = frameDims.height;
  int width = frameDims.width;
  torch::Tensor& dst = frameOutput.data;
  if (preAllocatedOutputTensor.has_value()) {
    dst = preAllocatedOutputTensor.value();
    auto shape = dst.sizes();
    TORCH_CHECK(
        (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
            (shape[2] == 3),
        "Expected tensor of shape ",
        height,
        "x",
        width,
        "x3, got ",
        shape);
  } else {
    dst = allocateEmptyHWCTensor(height, width, device_);
  }

  auto start = std::chrono::high_resolution_clock::now();
  // We need to compare the current frame context with our previous frame
  // context. If they are different, then we need to re-create our colorspace
  // conversion objects. We create our colorspace conversion objects late so
  // that we don't have to depend on the unreliable metadata in the header.
  // And we sometimes re-create them because it's possible for frame
  // resolution to change mid-stream. Finally, we want to reuse the colorspace
  // conversion objects as much as possible for performance reasons.
  enum AVPixelFormat frameFormat =
      static_cast<enum AVPixelFormat>(avFrame->format);
  FiltersContext filtersContext;

  filtersContext.inputWidth = avFrame->width;
  filtersContext.inputHeight = avFrame->height;
  filtersContext.inputFormat = frameFormat;
  filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
  // Actual output color format will be set via filter options
  filtersContext.outputFormat = AV_PIX_FMT_VAAPI;
  filtersContext.timeBase = timeBase;
  filtersContext.hwFramesCtx.reset(av_buffer_ref(avFrame->hw_frames_ctx));

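  // scale_vaapi performs both the resize and the format conversion on the
  // GPU, so the frame never leaves device memory.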
  std::stringstream filters;
  filters << "scale_vaapi=" << width << ":" << height;
  // TODO: consider also passing ":out_color_matrix=bt709:out_range=tv".
  filters << ":format=rgba";

  filtersContext.filters = filters.str();

  if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
    filterGraphContext_ =
        std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
    prevFiltersContext_ = std::move(filtersContext);
  }

  // We use VAAPI to convert the input to RGBA, which yields an HxWx4 tensor
  // on the output.
  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);

  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_VAAPI);

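  // Import the filtered frame as an HxWx4 RGBA tensor, then drop the alpha
  // channel: narrow() selects the first 3 of the 4 channels and copy_()
  // performs the device-side copy into the output tensor.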
  torch::Tensor dst_rgb4 = AVFrameToTensor(device_, filteredAVFrame);
  dst.copy_(dst_rgb4.narrow(2, 0, 3));

  auto end = std::chrono::high_resolution_clock::now();

  std::chrono::duration<double, std::micro> duration = end - start;
  VLOG(9) << "Conversion of frame height=" << height << " width=" << width
          << " took: " << duration.count() << "us";
}

// Inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
// We have to do this because of an FFmpeg bug where hardware decoding is not
// appropriately set, so we just go off and find the matching codec for the
// VAAPI device.
std::optional<const AVCodec*> XpuDeviceInterface::findCodec(
    const AVCodecID& codecId) {
  void* i = nullptr;
  const AVCodec* codec = nullptr;
  while ((codec = av_codec_iterate(&i)) != nullptr) {
    if (codec->id != codecId || !av_codec_is_decoder(codec)) {
      continue;
    }

    const AVCodecHWConfig* config = nullptr;
    for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr;
         ++j) {
      if (config->device_type == AV_HWDEVICE_TYPE_VAAPI) {
        return codec;
      }
    }
  }

  return std::nullopt;
}

} // namespace facebook::torchcodec