Commit 694dcf2: refactor

Parent: 4b14ddd
2 files changed: +50 −25 lines

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 21 additions & 22 deletions
@@ -106,35 +106,38 @@ def forward(
     ) -> torch.Tensor:
         if self.proj_in is not None:
             control_hidden_states = self.proj_in(control_hidden_states)
-            hidden_states = hidden_states + control_hidden_states
-        else:
-            hidden_states = control_hidden_states
+            control_hidden_states = control_hidden_states + hidden_states

         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
             self.scale_shift_table + temb.float()
         ).chunk(6, dim=1)

         # 1. Self-attention
-        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+        norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
+            control_hidden_states
+        )
         attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
-        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+        control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)

         # 2. Cross-attention
-        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+        norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
         attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
-        hidden_states = hidden_states + attn_output
+        control_hidden_states = control_hidden_states + attn_output

         # 3. Feed-forward
-        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
-            hidden_states
+        norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+            control_hidden_states
         )
         ff_output = self.ffn(norm_hidden_states)
-        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+        control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
+            control_hidden_states
+        )

+        conditioning_states = None
         if self.proj_out is not None:
-            control_hidden_states = self.proj_out(hidden_states)
+            conditioning_states = self.proj_out(control_hidden_states)

-        return hidden_states, control_hidden_states
+        return conditioning_states, control_hidden_states


 class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
@@ -309,11 +312,9 @@ def forward(
         # 2. Patch embedding
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
-        print("hidden_states", hidden_states.shape)

         control_hidden_states = self.vace_patch_embedding(control_hidden_states)
         control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
-        print("control_hidden_states", control_hidden_states.shape)
         control_hidden_states_padding = control_hidden_states.new_zeros(
             batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
         )
@@ -333,12 +334,11 @@ def forward(
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             # Prepare VACE hints
             control_hidden_states_list = []
-            vace_hidden_states = hidden_states
             for i, block in enumerate(self.vace_blocks):
-                vace_hidden_states, control_hidden_states = self._gradient_checkpointing_func(
-                    block, vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+                conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
+                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]

             for i, block in enumerate(self.blocks):
@@ -351,12 +351,11 @@ def forward(
         else:
             # Prepare VACE hints
            control_hidden_states_list = []
-            vace_hidden_states = hidden_states
             for i, block in enumerate(self.vace_blocks):
-                vace_hidden_states, control_hidden_states = block(
-                    vace_hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+                conditioning_states, control_hidden_states = block(
+                    hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                 )
-                control_hidden_states_list.append((control_hidden_states, control_hidden_states_scale[i]))
+                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
             control_hidden_states_list = control_hidden_states_list[::-1]

             for i, block in enumerate(self.blocks):
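Taken together, the transformer changes invert the roles of the two streams inside each VACE block: hidden_states (the main video stream) becomes read-only, control_hidden_states is the stream the block normalizes, attends over, and carries forward, and the projected hint is returned separately as conditioning_states. A minimal sketch of the new contract, where ToyVACEBlock and mix are hypothetical stand-ins for the real modulated attention and feed-forward stack:

import torch

class ToyVACEBlock(torch.nn.Module):
    # Simplified sketch; the real block also applies scale/shift modulation,
    # cross-attention, and rotary embeddings.
    def __init__(self, dim: int, has_proj_in: bool):
        super().__init__()
        self.proj_in = torch.nn.Linear(dim, dim) if has_proj_in else None
        self.proj_out = torch.nn.Linear(dim, dim)
        self.mix = torch.nn.Linear(dim, dim)  # stand-in for attention + FFN

    def forward(self, hidden_states, control_hidden_states):
        # The video stream is only read here; the control stream is updated.
        if self.proj_in is not None:
            control_hidden_states = self.proj_in(control_hidden_states)
            control_hidden_states = control_hidden_states + hidden_states
        control_hidden_states = control_hidden_states + self.mix(control_hidden_states)
        # The projected hint goes to the main transformer blocks; the
        # unprojected stream feeds the next VACE block.
        conditioning_states = self.proj_out(control_hidden_states)
        return conditioning_states, control_hidden_states

Under this contract, the hints collected in control_hidden_states_list come from proj_out while the unprojected stream chains through the VACE blocks, which is what the loop changes above rely on: the removed vace_hidden_states alias is no longer needed because hidden_states is never mutated.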

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 29 additions & 3 deletions
@@ -23,7 +23,7 @@
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
 from ...loaders import WanLoraLoaderMixin
-from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...models import AutoencoderKLWan, WanVACETransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
@@ -137,7 +137,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanTransformer3DModel,
+        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
     ):
@@ -421,6 +421,13 @@ def preprocess_conditions(
                 f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
             )

+        ref_images_lengths = [len(reference_images_batch) for reference_images_batch in reference_images]
+        if any(l != ref_images_lengths[0] for l in ref_images_lengths):
+            raise ValueError(
+                f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}. Support for this "
+                "may be added in the future."
+            )
+
         reference_images_preprocessed = []
         for i, reference_images_batch in enumerate(reference_images):
             preprocessed_images = []
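The new validation rejects ragged reference_images up front with an explicit error instead of letting a shape mismatch surface later in preprocessing. A hypothetical repro, with img as a stand-in tensor:

import torch

img = torch.zeros(3, 480, 832)
reference_images = [[img, img], [img]]  # batch 0 has 2 refs, batch 1 has 1

ref_images_lengths = [len(batch) for batch in reference_images]
if any(l != ref_images_lengths[0] for l in ref_images_lengths):
    raise ValueError(f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}.")
# -> ValueError: ... but got [2, 1]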
@@ -449,7 +456,10 @@ def prepare_video_latents(
         mask: torch.Tensor,
         reference_images: Optional[List[List[torch.Tensor]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        device: Optional[torch.device] = None,
     ) -> torch.Tensor:
+        device = device or self._execution_device
+
         if isinstance(generator, list):
             # TODO: support this
             raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
@@ -473,15 +483,25 @@ def prepare_video_latents(
         vae_dtype = self.vae.dtype
         video = video.to(dtype=vae_dtype)

+        latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(
+            1, self.vae.config.z_dim, 1, 1, 1
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(
+            1, self.vae.config.z_dim, 1, 1, 1
+        )
+
         if mask is None:
             latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
+            latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
         else:
             mask = mask.to(dtype=vae_dtype)
             mask = torch.where(mask > 0.5, 1.0, 0.0)
             inactive = video * (1 - mask)
             reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
             reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
+            inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
+            reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
             latents = torch.cat([inactive, reactive], dim=1)

         latent_list = []
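The other substantive pipeline change: encoded latents are now shifted and scaled into the VAE's normalized latent space before being used as conditioning. Note that latents_std above is stored as a reciprocal, so the hot path multiplies instead of dividing. A standalone sketch with hypothetical statistics (the real per-channel values come from vae.config.latents_mean and vae.config.latents_std):

import torch

z_dim = 16
latents = torch.randn(1, z_dim, 21, 60, 104)       # [B, C, F, H, W]
latents_mean = torch.zeros(1, z_dim, 1, 1, 1)      # broadcasts over F, H, W
latents_std = 1.0 / torch.ones(1, z_dim, 1, 1, 1)  # stored as reciprocal

# Equivalent to (latents - mean) / std.
normalized = (latents.float() - latents_mean) * latents_std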
@@ -491,6 +511,7 @@ def prepare_video_latents(
                 reference_image = reference_image.to(dtype=vae_dtype)
                 reference_image = reference_image[None, :, None, :, :]  # [1, C, 1, H, W]
                 reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
+                reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype)
                 reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
                 latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # Concat across frame dimension
                 latent_list.append(latent)
@@ -790,7 +811,7 @@ def __call__(
             device,
         )

-        conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator)
+        conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
         mask = self.prepare_masks(mask, reference_images, generator)
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)
@@ -808,6 +829,11 @@ def __call__(
             latents,
         )

+        if conditioning_latents.shape[2] != latents.shape[2]:
+            logger.warning(
+                "The number of frames in the conditioning latents does not match the number of frames to be generated. Generation quality may be affected."
+            )
+
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
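The guard compares the frame axis (dim 2 in the [B, C, F, H, W] latent layout). Reference images are concatenated onto the conditioning stream across the frame dimension in prepare_video_latents above, so the conditioning latents can legitimately carry more frames than will be generated; the warning flags that case. An illustration with made-up shapes:

import torch

# Hypothetical shapes only; channel counts and sizes are illustrative.
conditioning_latents = torch.zeros(1, 96, 22, 60, 104)  # one extra reference frame
latents = torch.zeros(1, 16, 21, 60, 104)

if conditioning_latents.shape[2] != latents.shape[2]:
    print("Warning: conditioning frame count differs from frames to generate.")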
