Skip to content

Conversation

@tolgacangoz
Copy link
Contributor

@tolgacangoz tolgacangoz commented Aug 29, 2025

This PR is fixing #12257.

Comparison with the original repo

When I put with torch.amp.autocast('cuda', dtype=torch.bfloat16): onto the transformer only and converted the initial noise's dtype into torch.float32 from torch.bfloat16 in the original repo, the videos seem almost the same. As far as I can see, the original repo's video has an extra blink.

wan.mp4
diffusers.mp4
Try WanSpeechToVideoPipeline!
!git clone https://github.com/tolgacangoz/diffusers.git
%cd diffusers
#!git switch "integrations/wan2.2-s2v"  # This is constantly changing...
!git switch "wan2.2-s2v"
!pip install pip uv -qU
!uv pip install -e ".[dev]" -q
!uv pip install imageio-ffmpeg ftfy decord ninja packaging -q
# For Flash attention 2:
#!uv pip install flash-attn --no-build-isolation
# For Flash attention 3 in diffusers:
#import os
#os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "YES"


import numpy as np
import torch, os
from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline
from diffusers.utils import export_to_video, load_image, load_audio, load_video
from transformers import Wav2Vec2ForCTC

model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers"  # will be official
model_id = "tolgacangoz/Wan2.2-S2V-14B-Diffusers"
audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanSpeechToVideoPipeline.from_pretrained(
    model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16,
)#.to("cuda")
pipe.enable_model_cpu_offload()
#pipe.transformer.set_attention_backend("flash")  # FA 2
#pipe.transformer.set_attention_backend("_flash_3_hub")  # FA 3

first_frame = load_image("https://raw.githubusercontent.com/Wan-Video/Wan2.2/refs/heads/main/examples/i2v_input.JPG")
audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/talk.wav")

import math

def get_size_less_than_area(height,
                            width,
                            target_area=1024 * 704,
                            divisor=64):
    if height * width <= target_area:
        # If the original image area is already less than or equal to the target,
        # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target.
        max_upper_area = target_area
        min_scale = 0.1
        max_scale = 1.0
    else:
        # Resize to fit within the target area and then pad to multiples of `divisor`
        max_upper_area = target_area  # Maximum allowed total pixel count after padding
        d = divisor - 1
        b = d * (height + width)
        a = height * width
        c = d**2 - max_upper_area

        # Calculate scale boundaries using quadratic equation
        min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (
            2 * a)  # Scale when maximum padding is applied
        max_scale = math.sqrt(max_upper_area /
                                (height * width))  # Scale without any padding

    # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area
    # Use binary search-like iteration to find this scale
    find_it = False
    for i in range(100):
        scale = max_scale - (max_scale - min_scale) * i / 100
        new_height, new_width = int(height * scale), int(width * scale)

        # Pad to make dimensions divisible by 64
        pad_height = (64 - new_height % 64) % 64
        pad_width = (64 - new_width % 64) % 64
        pad_top = pad_height // 2
        pad_bottom = pad_height - pad_top
        pad_left = pad_width // 2
        pad_right = pad_width - pad_left

        padded_height, padded_width = new_height + pad_height, new_width + pad_width

        if padded_height * padded_width <= max_upper_area:
            find_it = True
            break

    if find_it:
        return padded_height, padded_width
    else:
        # Fallback: calculate target dimensions based on aspect ratio and divisor alignment
        aspect_ratio = width / height
        target_width = int(
            (target_area * aspect_ratio)**0.5 // divisor * divisor)
        target_height = int(
            (target_area / aspect_ratio)**0.5 // divisor * divisor)

        # Ensure the result is not larger than the original resolution
        if target_width >= width or target_height >= height:
            target_width = int(width // divisor * divisor)
            target_height = int(height // divisor * divisor)

        return target_height, target_width

height, width = get_size_less_than_area(first_frame.height, first_frame.width, target_area=480*832)

prompt = "Einstein singing a song."

output = pipe(
    image=first_frame, audio=audio, sampling_rate=sampling_rate,
    prompt=prompt, height=height, width=width, num_frames_per_chunk=80,
).frames[0]
export_to_video(output, "video.mp4", fps=16)

import logging, shutil, subprocess

def merge_video_audio(video_path: str, audio_path: str):
    """
    Merge the video and audio into a new video, with the duration set to the shorter of the two,
    and overwrite the original video file.

    Parameters:
    video_path (str): Path to the original video file
    audio_path (str): Path to the audio file
    """
    # set logging
    logging.basicConfig(level=logging.INFO)

    # check
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"video file {video_path} does not exist")
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"audio file {audio_path} does not exist")

    base, ext = os.path.splitext(video_path)
    temp_output = f"{base}_temp{ext}"

    try:
        # create ffmpeg command
        command = [
            'ffmpeg',
            '-y',  # overwrite
            '-i',
            video_path,
            '-i',
            audio_path,
            '-c:v',
            'copy',  # copy video stream
            '-c:a',
            'aac',  # use AAC audio encoder
            '-b:a',
            '192k',  # set audio bitrate (optional)
            '-map',
            '0:v:0',  # select the first video stream
            '-map',
            '1:a:0',  # select the first audio stream
            '-shortest',  # choose the shortest duration
            temp_output
        ]

        # execute the command
        logging.info("Start merging video and audio...")
        result = subprocess.run(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # check result
        if result.returncode != 0:
            error_msg = f"FFmpeg execute failed: {result.stderr}"
            logging.error(error_msg)
            raise RuntimeError(error_msg)

        shutil.move(temp_output, video_path)
        logging.info(f"Merge completed, saved to {video_path}")

    except Exception as e:
        if os.path.exists(temp_output):
            os.remove(temp_output)
        logging.error(f"merge_video_audio failed with error: {e}")

import requests, tempfile
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT

response = requests.get(audio, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
with tempfile.NamedTemporaryFile(delete=False) as talk:
    for chunk in response.iter_content(chunk_size=8192):
        talk.write(chunk)
    talk_file = talk.name

merge_video_audio("video.mp4", talk_file)

@yiyixuxu @sayakpaul @asomoza @dg845 @stevhliu
@WanX-Video-1 @Steven-SWZhang @kelseyee
@SHYuanBest @J4BEZ @okaris @xziayro-ai @teith @luke14free @lopho @arnold408

…date example imports

Add unit tests for WanSpeechToVideoPipeline and WanS2VTransformer3DModel and gguf
The previous audio encoding logic was a placeholder. It is now replaced with a `Wav2Vec2ForCTC` model and processor, including the full implementation for processing audio inputs. This involves resampling and aligning audio features with video frames to ensure proper synchronization.

Additionally, utility functions for loading audio from files or URLs are added, and the `audio_processor` module is refactored to correctly handle audio data types instead of image types.
Introduces support for audio and pose conditioning, replacing the previous image conditioning mechanism. The model now accepts audio embeddings and pose latents as input.

This change also adds two new, mutually exclusive motion processing modules:
- `MotionerTransformers`: A transformer-based module for encoding motion.
- `FramePackMotioner`: A module that packs frames from different temporal buckets for motion representation.

Additionally, an `AudioInjector` module is implemented to fuse audio features into specific transformer blocks using cross-attention.
The `MotionerTransformers` module is removed and its functionality is replaced by a `FramePackMotioner` module and a simplified standard motion processing pipeline.

The codebase is refactored to remove the `einops` dependency, replacing `rearrange` operations with standard PyTorch tensor manipulations for better code consistency.

Additionally, `AdaLayerNorm` is introduced for improved conditioning, and helper functions for Rotary Positional Embeddings (RoPE) are added (probably temporarily) and refactored for clarity and flexibility. The audio injection mechanism is also updated to align with the new model structure.
Removes the calculation of several unused variables and an unnecessary `deepcopy` operation on the latents tensor.

This change also removes the now-unused `deepcopy` import, simplifying the overall logic.
Refactors the `WanS2VTransformer3DModel` for clarity and better handling of various conditioning inputs like audio, pose, and motion.

Key changes:
- Simplifies the `WanS2VTransformerBlock` by removing projection layers and streamlining the forward pass.
- Introduces `after_transformer_block` to cleanly inject audio information after each transformer block, improving code organization.
- Enhances the main `forward` method to better process and combine multiple conditioning signals (image, audio, motion) before the transformer blocks.
- Adds support for a zero-value timestep to differentiate between image and video latents.
- Generalizes temporal embedding logic to support multiple model variations.
Introduces the necessary configurations and state dictionary key mappings to enable the conversion of S2V model checkpoints to the Diffusers format.

This includes:
- A new transformer configuration for the S2V model architecture, including parameters for audio and pose conditioning.
- A comprehensive rename dictionary to map the original S2V layer names to their Diffusers equivalents.
Comment on lines +891 to +901
pose_video = None
if pose_video_path_or_url is not None:
pose_video = load_video(
pose_video_path_or_url,
n_frames=num_frames_per_chunk * num_chunks,
target_fps=sampling_fps,
reverse=True,
)
pose_video = self.video_processor.preprocess_video(
pose_video, height=height, width=width, resize_mode="resize_min_center_crop"
).to(device, dtype=torch.float32)
Copy link
Contributor Author

@tolgacangoz tolgacangoz Sep 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Giving pose info as pose_video_path_or_url doesn't seem diffusers friendly, right? load_video is usually run before the pipeline is called. But in this case, we need num_chunks after it might have been updated in the lines 881-882. Is there a better way to do this?

audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
return audio_embed_bucket, num_repeat

# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@J4BEZ
Copy link
Contributor

J4BEZ commented Oct 16, 2025

Dear @tolgacangoz
I appreciate for your hard work again!

While trying to use pipe.enable_sequential_cpu_offload() instead of pipe.enable_model_cpu_offload(),
I encountered an error like below:

error stack trace
File .../torch/utils/_contextlib.py:120, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    117 @functools.wraps(func)
    118 def decorate_context(*args, **kwargs):
    119     with ctx_factory():
--> 120         return func(*args, **kwargs)

File .../diffusers/pipelines/wan/pipeline_wan_s2v.py:882, in WanSpeechToVideoPipeline.__call__(self, image, audio, sampling_rate, prompt, negative_prompt, pose_video_path_or_url, height, width, num_frames_per_chunk, num_inference_steps, guidance_scale, num_videos_per_prompt, generator, latents, prompt_embeds, negative_prompt_embeds, image_embeds, audio_embeds, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length, init_first_frame, sampling_fps, num_chunks)
    879     negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
    881 if audio_embeds is None:
--> 882     audio_embeds, num_chunks_audio = self.encode_audio(
    883         audio, sampling_rate, num_frames_per_chunk, sampling_fps, device
    884     )
    885 if num_chunks is None or num_chunks > num_chunks_audio:
    886     num_chunks = num_chunks_audio

File .../diffusers/pipelines/wan/pipeline_wan_s2v.py:351, in WanSpeechToVideoPipeline.encode_audio(self, audio, sampling_rate, num_frames, fps, device)
    349 input_values = self.audio_processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values
    350 # retrieve logits & take argmax
--> 351 res = self.audio_encoder(input_values.to(self.audio_encoder.device), output_hidden_states=True)
    352 feat = torch.cat(res.hidden_states)
    354 feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)

File .../torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File .../torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File .../accelerate/hooks.py:175, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    173         output = module._old_forward(*args, **kwargs)
    174 else:
--> 175     output = module._old_forward(*args, **kwargs)
    176 return module._hf_hook.post_forward(module, output)

File .../transformers/models/wav2vec2/modeling_wav2vec2.py:1862, in Wav2Vec2ForCTC.forward(self, input_values, attention_mask, output_attentions, output_hidden_states, return_dict, labels)
   1859 if labels is not None and labels.max() >= self.config.vocab_size:
   1860     raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
-> 1862 outputs = self.wav2vec2(
   1863     input_values,
   1864     attention_mask=attention_mask,
   1865     output_attentions=output_attentions,
   1866     output_hidden_states=output_hidden_states,
   1867     return_dict=return_dict,
   1868 )
   1870 hidden_states = outputs[0]
   1871 hidden_states = self.dropout(hidden_states)

File .../torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File .../torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File .../accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    169 def new_forward(module, *args, **kwargs):
--> 170     args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
    171     if module._hf_hook.no_grad:
    172         with torch.no_grad():

File .../accelerate/hooks.py:369, in AlignDevicesHook.pre_forward(self, module, *args, **kwargs)
    358             self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
    360         set_module_tensor_to_device(
    361             module,
    362             name,
   (...)    366             tied_params_map=self.tied_params_map,
    367         )
--> 369 return send_to_device(args, self.execution_device), send_to_device(
    370     kwargs, self.execution_device, skip_keys=self.skip_keys
    371 )

File .../accelerate/utils/operations.py:169, in send_to_device(tensor, device, non_blocking, skip_keys)
    167         return tensor.to(device)
    168 elif isinstance(tensor, (tuple, list)):
--> 169     return honor_type(
    170         tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
    171     )
    172 elif isinstance(tensor, Mapping):
    173     if isinstance(skip_keys, str):

File .../accelerate/utils/operations.py:81, in honor_type(obj, generator)
     79     return type(obj)(*list(generator))
     80 else:
---> 81     return type(obj)(generator)

File .../accelerate/utils/operations.py:170, in <genexpr>(.0)
    167         return tensor.to(device)
    168 elif isinstance(tensor, (tuple, list)):
    169     return honor_type(
--> 170         tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
    171     )
    172 elif isinstance(tensor, Mapping):
    173     if isinstance(skip_keys, str):

File .../accelerate/utils/operations.py:153, in send_to_device(tensor, device, non_blocking, skip_keys)
    151     device = "npu:0"
    152 try:
--> 153     return tensor.to(device, non_blocking=non_blocking)
    154 except TypeError:  # .to() doesn't accept non_blocking as kwarg
    155     return tensor.to(device)

NotImplementedError: Cannot copy out of meta tensor; no data!

After some investigation, I found a workaround that resolved the issue on my end, so I wanted to share the changes I made in case they’re helpful.

In def encode_audio() in the pipeline_wan_s2v.py

    def encode_audio(
        self,
        audio: PipelineAudioInput,
        sampling_rate: int,
        num_frames: int,
        fps: int = 16,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device
        video_rate = 30
        audio_sample_m = 0

        input_values = self.audio_processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values

        # retrieve logits & take argmax
-        res = self.audio_encoder(input_values.to(self.audio_encoder.device), output_hidden_states=True)
+        res = self.audio_encoder(input_values.to(device), output_hidden_states=True)
        feat = torch.cat(res.hidden_states)
...

and in def load_pose_condition()

    def load_pose_condition(
        self, pose_video, num_chunks, num_frames_per_chunk, height, width, latents_mean, latents_std
    ):
+        device = self._execution_device
+        dtype = self.vae.dtype
        if pose_video is not None:
            padding_frame_num = num_chunks * num_frames_per_chunk - pose_video.shape[2]
-            pose_video = pose_video.to(dtype=self.vae.dtype, device=self.vae.device)
+           pose_video = pose_video.to(dtype=dtype, device=device)
            pose_video = torch.cat(
                [
                    pose_video,
                    -torch.ones(
-                        [1, 3, padding_frame_num, height, width], dtype=self.vae.dtype, device=self.vae.device
+                        [1, 3, padding_frame_num, height, width], dtype=dtype, device=device
                    ),
                ],
                dim=2,
            )

            pose_video = torch.chunk(pose_video, num_chunks, dim=2)
        else:
            pose_video = [
-                -torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=self.vae.dtype, device=self.vae.device)
+                -torch.ones([1, 3, num_frames_per_chunk, height, width], dtype=dtype, device=device)
            ]

I hope this would be a little help!
Thanks for your dedication and hope you stay healthy and have a peaceful day!

- Updated device references in audio encoding and pose video loading to use a unified device variable.
- Enhanced image preprocessing to include a resize mode option for better handling of input dimensions.

Co-authored-by: Ju Hoon Park <[email protected]>
@tolgacangoz
Copy link
Contributor Author

Thanks @J4BEZ, fixed it.

@J4BEZ
Copy link
Contributor

J4BEZ commented Oct 18, 2025

@tolgacangoz Thanks! I am delighted to help☺️

Have a peaceful day!

Added contributor information and enhanced model description.
Added project page link for Wan-S2V model and improved context.

The project page: https://humanaigc.github.io/wan-s2v-webpage/

This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
Copy link
Contributor Author

@tolgacangoz tolgacangoz Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tolgacangoz
Copy link
Contributor Author

This will be my second official pipeline contribution and my fourth overall, yay 🥳

@tin2tin
Copy link

tin2tin commented Nov 7, 2025

Just a word of encouragement. This technology is actually quite good, and I hope it'll be priotized for review soonish. Here's a video I did with it: https://m.youtube.com/watch?v=N7ARyKKwGfc

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants