Closed
Labels
bug (Something isn't working)
Description
The Flux pipeline inference fails with
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
when enable_sequential_cpu_offload() is used.
I cannot test other memory-management settings because my 3090 won't run the pipeline with them.
It fails at commit 5588725e8e7be497839432e5328c596169385f16.
It works fine at commit ded3db164bb3c090871647f30ff9988c9c17fd83 (its parent commit).
Below is my venv:
accelerate==0.34.2
albucore==0.0.17
albumentations==1.4.16
annotated-types==0.7.0
asarPy==1.0.1
bsrgan==0.1.5
certifi==2024.8.30
charset-normalizer==3.3.2
cmake==3.30.4
compel==2.0.3
contourpy==1.3.0
cycler==0.12.1
Cython==3.0.11
-e git+https://github.com/huggingface/diffusers.git@main#egg=diffusers
easydict==1.13
eval_type_backport==0.2.0
filelock==3.16.1
fonttools==4.54.1
fsspec==2024.9.0
huggingface-hub==0.25.1
idna==3.10
imageio==2.35.1
importlib_metadata==8.5.0
insightface==0.7.3
Jinja2==3.1.4
joblib==1.4.2
kiwisolver==1.4.7
lazy_loader==0.4
MarkupSafe==2.1.5
matplotlib==3.9.2
mpmath==1.3.0
networkx==3.3
numpy==2.1.1
nvidia-cublas-cu11==11.11.3.6
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu11==9.4.0.58
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.68
nvidia-nvtx-cu12==12.1.105
onnx==1.16.2
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
packaging==24.1
peft==0.12.0
pillow==10.4.0
prettytable==3.11.0
protobuf==5.28.2
psutil==6.0.0
pydantic==2.9.2
pydantic_core==2.23.4
pyparsing==3.1.4
python-dateutil==2.9.0.post0
PyYAML==6.0.2
regex==2024.9.11
requests==2.32.3
safetensors==0.4.5
scikit-image==0.24.0
scikit-learn==1.5.2
scipy==1.14.1
sentencepiece==0.2.0
six==1.16.0
style==1.1.0
sympy==1.13.3
threadpoolctl==3.5.0
tifffile==2024.9.20
timm==1.0.11
tokenizers==0.19.1
torch==2.4.1
torchaudio==2.4.1
torchsde==0.2.6
torchvision==0.19.1
tqdm==4.66.5
trampoline==0.1.2
transformers==4.44.2
triton==3.0.0
typing_extensions==4.12.2
update==0.0.1
urllib3==2.2.3
wcwidth==0.2.13
websockets==13.1
zipp==3.20.2
Reproduction
import torch
import diffusers

p = diffusers.FluxPipeline.from_pretrained('/home/rico/yastade/models/Flux/FLUX.1-schnell', torch_dtype=torch.bfloat16, use_safetensors=True)
p.enable_sequential_cpu_offload()
p(prompt="whatever", num_inference_steps=1)
Logs
>>> p=diffusers.FluxPipeline.from_pretrained('/home/rico/yastade/models/Flux/FLUX.1-schnell',torch_dtype=torch.bfloat16,use_safetensors=True)
Loading pipeline components...: 43%|███████████████████████████████████▏ | 3/7 [00:00<00:00, 6.14it/s]
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 11.71it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 7.28it/s]
>>> p.enable_sequential_cpu_offload()
>>> p(prompt="whatever",num_inference_steps=1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/rico/yastade/.yastade/lib64/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/rico/yastade/.yastade/src/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 684, in __call__
latents, latent_image_ids = self.prepare_latents(
^^^^^^^^^^^^^^^^^^^^^
File "/home/rico/yastade/.yastade/src/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 522, in prepare_latents
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rico/yastade/.yastade/src/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 431, in _prepare_latent_image_ids
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
System Info
- 🤗 Diffusers version: 0.32.0.dev0
- Platform: Linux-6.10.11-1-default-x86_64-with-glibc2.40
- Running on Google Colab?: No
- Python version: 3.11.10
- PyTorch version (GPU?): 2.4.1+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.25.1
- Transformers version: 4.44.2
- Accelerate version: 0.34.2
- PEFT version: 0.12.0
- Bitsandbytes version: not installed
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 3090, 24576 MiB
- Using GPU in script?: yes, via
- Using distributed or parallel set-up in script?: no
Who can help?
The author of the following commit, which is the first one where the failure appears:
commit 5588725e8e7be497839432e5328c596169385f16
Author: Sayak Paul <[email protected]>
Date: Thu Nov 7 03:33:39 2024 +0100
[Flux] reduce explicit device transfers and typecasting in flux. (#9817)
reduce explicit device transfers and typecasting in flux.
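The traceback ends in `_prepare_latent_image_ids`, where `torch.arange(height)[:, None]` is added to `latent_image_ids`. A bare `torch.arange` allocates on the default device (the CPU unless configured otherwise), so the error suggests `latent_image_ids` itself was already on `cuda:0`. Below is a minimal sketch of one way to keep both operands on one device, assuming it is fine to create every index tensor directly on the target device. `prepare_latent_image_ids_sketch` is a hypothetical stand-in, not the diffusers implementation; it drops the `batch_size` argument seen in the traceback, and the sizes in the usage block are arbitrary.

```python
import torch


def prepare_latent_image_ids_sketch(
    height: int, width: int, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """Hypothetical helper: builds a (height * width, 3) grid of position IDs.

    Channel 0 is left at zero, channel 1 holds the row index and channel 2
    holds the column index. Every tensor is created on `device`, so no CPU
    tensor is mixed into arithmetic with a CUDA tensor.
    """
    # Work in float32 first, then cast once at the end.
    ids = torch.zeros(height, width, 3, device=device)
    # Passing device= is the key difference from a bare torch.arange(height),
    # which allocates on the default device (the CPU unless changed).
    rows = torch.arange(height, device=device)
    cols = torch.arange(width, device=device)
    ids[..., 1] = ids[..., 1] + rows[:, None]  # broadcast rows across columns
    ids[..., 2] = ids[..., 2] + cols[None, :]  # broadcast columns across rows
    # Flatten the grid in row-major order into one ID triple per latent position.
    return ids.reshape(height * width, 3).to(dtype=dtype)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    latent_ids = prepare_latent_image_ids_sketch(64, 64, device, torch.bfloat16)
    print(latent_ids.shape, latent_ids.device, latent_ids.dtype)
```

Building the grid on the CPU and moving the finished tensor once with `.to(device=device, dtype=dtype)` would be an equally valid alternative; either way, the point is that no single arithmetic operation mixes a CPU tensor with a CUDA one.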