Description
🐛 Describe the bug
When decoding videos in parallel across multiple processes, I sometimes (non-deterministically) get green frames. If I re-run the decode on the problematic frame enough times, it eventually returns the expected image. I've included sample code below. With the configuration shown, I do not always see the green frames; however, the frequency of incorrect decodes increases with the number of workers. Some more testing details:
- Tested on machines with A100 and T4 GPUs
- Used the 1280x720 and 1920x1080 videos available here
- For the above videos, the green frame is always (0, 76, 0). However, with other videos, I've also seen (0, 124, 0).
Sample code:
import torch
from torchcodec.decoders import VideoDecoder
from torch.multiprocessing import Process, Queue
import torchvision
import uuid

NUM_WORKERS = 2
NUM_VIDEOS = 5
BATCH_SIZE = 30
WORKING_DIR = "/path/to/directory/"

def worker_fn(in_queue):
    # Each worker pulls file paths from the queue until it receives None.
    while True:
        file = in_queue.get()
        if file is None:
            return
        else:
            decoder = VideoDecoder(file, seek_mode="exact", device="cuda:0")
            frame_count = decoder.metadata.num_frames
            # Decode the video in consecutive batches of BATCH_SIZE frames.
            for i in range(0, frame_count - BATCH_SIZE + 1, BATCH_SIZE):
                indices = range(i, i+BATCH_SIZE)
                frames = decoder.get_frames_at(indices).data / 255
                # Per-frame, per-channel spatial std; take the max over channels,
                # then find the frame in the batch with the smallest value.
                worst_std, index = torch.min(torch.max(torch.std(frames, dim=[2,3]), dim = 1).values, dim=0)
                if worst_std < 1e-4:
                    print("{} contains monochromatic frames".format(file))
                    id = uuid.uuid4()
                    torchvision.utils.save_image(frames[index,:], WORKING_DIR+str(id)+".png")
                    # Decode the same batch again and save the same frame for comparison.
                    frames = decoder.get_frames_at(indices).data / 255
                    torchvision.utils.save_image(frames[index,:], WORKING_DIR+str(id)+"_new.png")

if __name__ == "__main__":
    videos = [WORKING_DIR + "/filename.mp4"] * NUM_VIDEOS
    in_queue = Queue(NUM_WORKERS)
    processes = [Process(target=worker_fn, daemon=True, args=(in_queue,)) for _ in range(NUM_WORKERS)]
    for p in processes:
        p.start()
    for video in videos:
        in_queue.put(video)
    # One None sentinel per worker so that every worker exits.
    for p in processes:
        in_queue.put(None)
    for p in processes:
        p.join()
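Since re-running the decode on a bad frame eventually returns the expected image, a retry wrapper might serve as a stopgap until the underlying bug is fixed. The following is only a minimal sketch of that idea, not a fix and not part of torchcodec: `is_solid_color`, `decode_with_retry`, and `max_attempts` are hypothetical names, and frames are represented as plain lists of (R, G, B) tuples to keep the example dependency-free. In the script above, the decode callable would wrap `decoder.get_frames_at(...)` and the check would be the std-based test.

```python
from typing import Callable, Sequence, Tuple

Pixel = Tuple[int, int, int]
Frame = Sequence[Pixel]  # hypothetical stand-in for a decoded frame tensor


def is_solid_color(frame: Frame) -> bool:
    """Return True if the frame is non-empty and every pixel has the same RGB value."""
    return len(frame) > 0 and all(pixel == frame[0] for pixel in frame)


def decode_with_retry(decode_fn: Callable[[], Frame], max_attempts: int = 5) -> Frame:
    """Call decode_fn until it returns a non-solid frame, up to max_attempts times."""
    for _ in range(max_attempts):
        frame = decode_fn()
        if not is_solid_color(frame):
            return frame
    raise RuntimeError(f"frame still solid-colored after {max_attempts} attempts")


if __name__ == "__main__":
    attempts = {"count": 0}

    def flaky_decode() -> Frame:
        # Simulated decoder: returns a green frame twice, then a real frame.
        attempts["count"] += 1
        if attempts["count"] < 3:
            return [(0, 76, 0)] * 4
        return [(10, 20, 30), (40, 50, 60), (0, 76, 0), (1, 2, 3)]

    frame = decode_with_retry(flaky_decode)
    print("succeeded after", attempts["count"], "attempts")
```

Note that a video containing genuinely solid-colored frames (for example, a fade to black) would exhaust the retries and raise, so a real workaround would need to handle that case separately.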
Versions
PyTorch version: 2.7.1+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.8.0-1027-azure-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.77
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 560.35.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.5.1
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 4
On-line CPU(s) list: 0-3
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7V12 64-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 1
Core(s) per socket: 4
Socket(s): 1
Stepping: 0
BogoMIPS: 4890.87
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core ssbd vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr rdpru arat umip rdpid
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 128 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 2 MiB (4 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-3
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT disabled
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.1.3
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnxruntime==1.20.0
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.7.1
[pip3] torchaudio==2.5.1
[pip3] torchcodec==0.5+cu126
[pip3] torchmetrics==1.6.0
[pip3] torchvision==0.22.1
[pip3] triton==3.3.1
[conda] Could not collect