Skip to content

Conversation

@sywangyi
Copy link
Contributor

@sywangyi sywangyi commented Oct 30, 2025

fix the crash when testing CP for wan2.2-TI2V-5B

test script:

import random

import numpy as np
import torch
from torch import distributed as dist

from diffusers import AutoencoderKLWan, ContextParallelConfig, WanPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video


model_id="Wan-AI/Wan2.2-TI2V-5B-Diffusers"

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device


def set_seed_for_all_ranks(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    generator = torch.Generator(device="cuda")
    generator.manual_seed(seed)
    return generator


device = setup_distributed()
generator = set_seed_for_all_ranks(42)
onload_device = device
offload_device = torch.device("cpu")

vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
# group-offloading
pipe = WanPipeline.from_pretrained(
    model_id,
    vae=vae,
    torch_dtype=torch.bfloat16,
)
ulysses_degree = torch.distributed.get_world_size()
pipe.transformer.set_attention_backend("_native_cudnn")
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=ulysses_degree))
apply_group_offloading(pipe.text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)
pipe.vae.enable_group_offload(onload_device=onload_device, offload_type="leaf_level", use_stream=True)

pipe.vae.enable_tiling(tile_sample_min_height=480,tile_sample_min_width=960,tile_sample_stride_height=352,tile_sample_stride_width=640)
height = 704
width = 1280
num_frames = 121
num_inference_steps = 50
guidance_scale = 5.0


prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩>残留,丑陋的,残缺的>,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂>乱的背景,三条腿,背>景
人很多,倒着走"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    generator=generator,
).frames[0]
if torch.distributed.get_rank() == 0:
    export_to_video(output, "5bit2v_output.mp4", fps=24)
if dist.is_initialized():
    torch.distributed.destroy_process_group()

@sywangyi
Copy link
Contributor Author

torchrun --nproc-per-node 2 test.py

crash stack:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/mnt/disk3/wangyi/diffusers/test_14B_cp_offload.py", line 72, in <module>
[rank1]:     output = pipe(
[rank1]:              ^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/pipelines/wan/pipeline_wan.py", line 593, in __call__
[rank1]:     noise_pred = current_model(
[rank1]:                  ^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 680, in forward
[rank1]:     hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   [Previous line repeated 1 more time]
[rank1]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 482, in forward
[rank1]:     norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
[rank1]:                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~
[rank1]: RuntimeError: The size of tensor a (13640) must match the size of tensor b (27280) at non-singleton dimension 1
[rank1]:[W1030 12:11:14.705123356 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resource, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
  0%|                                                                                                                                                       | 0/1
[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk3/wangyi/diffusers/test_14B_cp_offload.py", line 72, in <module>
[rank0]:     output = pipe(
[rank0]:              ^^^^^
[rank0]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/pipelines/wan/pipeline_wan.py", line 593, in __call__
[rank0]:     noise_pred = current_model(
[rank0]:                  ^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 680, in forward
[rank0]:     hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk0/wangyi/miniforge3/envs/transformers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank0]:     output = function_reference.forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   [Previous line repeated 1 more time]
[rank0]:   File "/mnt/disk3/wangyi/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 482, in forward
[rank0]:     norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
[rank0]:                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~
[rank0]: RuntimeError: The size of tensor a (13640) must match the size of tensor b (27280) at non-singleton dimension 1

@sywangyi
Copy link
Contributor Author

@yiyixuxu @sayakpaul please help review

@sayakpaul
Copy link
Member

Could you also supplement an output with the fix?

)
if ts_seq_len is not None:
# Check if running under context parallel and split along seq_len dimension
if hasattr(self, '_parallel_config') and self._parallel_config is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate why this is needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when cp is enabled, seq_len is split, timestep_shape is [batch_size, seq_len, 6, inner_dim], so should be split in dim_1 as well since hidden state is split in seq_len dim as well. or else shape miss match will occur

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@sywangyi sywangyi Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean split timestep in forward? adding
"": {
"timestep": ContextParallelInput(split_dim=1, split_output=False)
}, to _cp_plan will make 5B work, but 14B fail since 5B timestep dims is 2. 14 timestep dims is 1.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this is an interesting situation. To tackle these, I think we might have to revisit the ContextParallelInput and ContextParallelOutput definitions a bit more.

If we had a way to tell the partitioner that the input might have "dynamic" dimensions depending on the model configs (like in this case), and what it should do if that's the case, it might be more flexible as a solution.

@DN6 curious to know what you think.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes a lot of things easier, for sure!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expand_timesteps is not passed in WanTransformer3DModel init, so, no way to judge if timestep.dim is 2 or 1 currently.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can add a config if we want to
i want to hear @DN6 's thoughts on this first though

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially, we can add cp_plan directly as a config, allow model owner to overridee it I think (in this case, we would send a PR into wan repo, i think it'd be ok)

It's also very in line with transformers does it btw.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yiyixuxu @DN6 we allow passing parallel_config through from_pretrained(), too. Wonder, if it could make sense to allow users to pass a custom _cp_plan through it.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Could you please explain the changes and also provide an example output?

@sayakpaul sayakpaul requested a review from DN6 October 30, 2025 05:50
@sywangyi
Copy link
Contributor Author

seems I can not attach the video here, blocked may be ....

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants