
Commit 3605918

vladmandic, DN6, and github-actions[bot] authored and committed
fix scale_shift_factor being on cpu for wan and ltx (#12347)
* wan fix scale_shift_factor being on cpu
* apply device cast to ltx transformer
* Apply style fixes

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 9169e81 commit 3605918
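The failure this commit addresses is a device mismatch: with CPU offloading or a multi-GPU device_map, the learned `scale_shift_table` parameter can remain on the CPU while `temb` already lives on an accelerator, and PyTorch refuses to add tensors across devices. A minimal sketch of the symptom and of the `.to(temb.device)` fix applied here (tensor names and shapes are illustrative, not the actual module code):

import torch

# Illustrative stand-ins: a learned table left on the CPU and a timestep
# embedding that follows the hidden states onto the accelerator.
scale_shift_table = torch.nn.Parameter(torch.randn(6, 64))  # stays on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
temb = torch.randn(2, 6 * 64, device=device)

# Without a cast, this addition raises
# "Expected all tensors to be on the same device" whenever device != "cpu".
try:
    ada_values = scale_shift_table + temb.reshape(2, 6, 64)
except RuntimeError as err:
    print(f"device mismatch: {err}")

# The fix: move the table onto temb's device before the elementwise add.
ada_values = scale_shift_table.to(temb.device) + temb.reshape(2, 6, 64)
print(ada_values.device)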

File tree

3 files changed (+7 / −5 lines)


src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 3 additions & 1 deletion
@@ -350,7 +350,9 @@ def forward(
         norm_hidden_states = self.norm1(hidden_states)

         num_ada_params = self.scale_shift_table.shape[0]
-        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+        ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+            batch_size, temb.size(1), num_ada_params, -1
+        )
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
         norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
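For reference, a minimal standalone sketch of what the adjusted LTX block computes, with made-up sizes (the real code runs inside the block's forward and uses its learned table):

import torch

batch_size, seq_len, dim, num_ada_params = 2, 8, 32, 6
scale_shift_table = torch.randn(num_ada_params, dim)            # may sit on the CPU
temb = torch.randn(batch_size, seq_len, num_ada_params * dim)   # follows the hidden states' device

# Cast the table onto temb's device, broadcast-add, then split into the six
# per-token modulation tensors used around attention and the feed-forward.
ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(
    batch_size, temb.size(1), num_ada_params, -1
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
print(shift_msa.shape)  # torch.Size([2, 8, 32])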

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 2 additions & 2 deletions
@@ -665,12 +665,12 @@ def forward(
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
-            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
             shift = shift.squeeze(2)
             scale = scale.squeeze(2)
         else:
             # batch_size, inner_dim
-            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
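The Wan output-norm path handles two temb layouts, and both branches now receive the same cast; a minimal sketch with illustrative sizes (output_shift_scale is a made-up helper name, not a function in diffusers):

import torch

batch_size, seq_len, inner_dim = 2, 8, 32
scale_shift_table = torch.randn(1, 2, inner_dim)  # learned parameter, possibly left on CPU

def output_shift_scale(temb: torch.Tensor):
    # Mirrors the patched branches: cast the table onto temb's device before adding.
    if temb.ndim == 3:
        # temb: (batch_size, seq_len, inner_dim) -- the wan 2.2 ti2v layout
        shift, scale = (scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
        return shift.squeeze(2), scale.squeeze(2)
    # temb: (batch_size, inner_dim)
    return (scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

shift, scale = output_shift_scale(torch.randn(batch_size, seq_len, inner_dim))
print(shift.shape, scale.shape)  # torch.Size([2, 8, 32]) for both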

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 2 additions & 2 deletions
@@ -103,7 +103,7 @@ def forward(
         control_hidden_states = control_hidden_states + hidden_states

         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
+            self.scale_shift_table.to(temb.device) + temb.float()
         ).chunk(6, dim=1)

         # 1. Self-attention
@@ -359,7 +359,7 @@ def forward(
         hidden_states = hidden_states + control_hint * scale

         # 6. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
