Commit df0e2a4

support latest few-step wan LoRA. (#12541)
* support latest few-step wan LoRA.
* up
* up
1 parent 303efd2 commit df0e2a4
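
For context, this converter is what lets the lightx2v few-step Wan LoRAs (the checkpoint referenced in the diff's notes) load through the standard diffusers LoRA entry point. A minimal usage sketch, assuming a Wan 2.2 pipeline; the base model id and `weight_name` below are illustrative placeholders, not taken from this commit:

```python
import torch
from diffusers import WanPipeline

# Assumed base checkpoint; any Wan pipeline that routes through the LoRA
# conversion utilities would exercise the code path changed in this commit.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16
)

# load_lora_weights() detects non-diffusers key names and converts them on the fly;
# the repo id comes from the diff's notes, the weight_name is a hypothetical placeholder.
pipe.load_lora_weights(
    "lightx2v/Wan2.2-Distill-Loras", weight_name="example.safetensors"
)
```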

File tree

1 file changed (+25, -5 lines)

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 25 additions & 5 deletions
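
Before the hunk itself, a sketch of the key mapping it hardens. `lora_down_key` and `lora_up_key` are variables in the converter (commonly "lora_down" and "lora_up" in non-diffusers checkpoints), so the literal key names here are assumptions inferred from the diff:

```python
# Assumed original -> converted key mapping for the output head, per the hunk below.
HEAD_KEY_MAP = {
    "head.head.lora_down.weight": "proj_out.lora_A.weight",
    "head.head.lora_up.weight": "proj_out.lora_B.weight",
    "head.head.diff_b": "proj_out.lora_B.bias",
    # New in this commit: a raw delta tensor, loaded as lora_A with an identity lora_B.
    "head.head.diff": "proj_out.lora_A.weight",
}
```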
```diff
@@ -1977,14 +1977,34 @@ def get_alpha_scales(down_weight, alpha_key):
             "time_projection.1.diff_b"
         )
 
-    if any("head.head" in k for k in state_dict):
-        converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
-            f"head.head.{lora_down_key}.weight"
-        )
-        converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+    if any("head.head" in k for k in original_state_dict):
+        if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                f"head.head.{lora_down_key}.weight"
+            )
+        if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
+                f"head.head.{lora_up_key}.weight"
+            )
     if "head.head.diff_b" in original_state_dict:
         converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
 
+    # Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
+    # This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
+    # Since for this particular LoRA, we don't have the corresponding up matrix, I will use
+    # an identity.
+    if any("head.head" in k and k.endswith(".diff") for k in state_dict):
+        if f"head.head.{lora_down_key}.weight" in state_dict:
+            logger.info(
+                f"The state dict seems to have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
+            )
+        converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
+        down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
+        up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
+        converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
+            *up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
+        ).T
+
     for text_time in ["text_embedding", "time_embedding"]:
         if any(text_time in k for k in original_state_dict):
             for b_n in [0, 2]:
```
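
The second added branch is the interesting one: some few-step checkpoints ship a raw `head.head.diff` tensor with no matching up matrix. The hunk loads it as `lora_A` and synthesizes an identity `lora_B`, so the effective update `lora_B @ lora_A` reproduces the raw delta. A minimal sketch of why that works, with illustrative shapes; this is my reading of the hunk, not code from the repo:

```python
import torch

# Suppose "head.head.diff" is a full delta-weight for proj_out with
# shape (out_features, in_features).
out_features, in_features = 16, 64
diff = torch.randn(out_features, in_features)

lora_A = diff  # loaded as the down matrix, so rank == out_features
# Mirrors the hunk: eye over (lora_A.shape[0], bias.shape[0]), transposed.
# When those two dimensions match, this is the square identity matrix.
lora_B = torch.eye(lora_A.shape[0], out_features, dtype=diff.dtype).T

# The effective LoRA update reproduces the raw delta exactly.
assert torch.equal(lora_B @ lora_A, diff)
```

With `lora_B` fixed to the identity, applying the LoRA amounts to adding `diff` (times the LoRA scale) directly to `proj_out.weight`, which is consistent with treating the key as a precomputed weight delta.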
