@@ -1977,14 +1977,34 @@ def get_alpha_scales(down_weight, alpha_key):
19771977 "time_projection.1.diff_b"
19781978 )
19791979
1980- if any ("head.head" in k for k in state_dict ):
1981- converted_state_dict ["proj_out.lora_A.weight" ] = original_state_dict .pop (
1982- f"head.head.{ lora_down_key } .weight"
1983- )
1984- converted_state_dict ["proj_out.lora_B.weight" ] = original_state_dict .pop (f"head.head.{ lora_up_key } .weight" )
1980+ if any ("head.head" in k for k in original_state_dict ):
1981+ if any (f"head.head.{ lora_down_key } .weight" in k for k in state_dict ):
1982+ converted_state_dict ["proj_out.lora_A.weight" ] = original_state_dict .pop (
1983+ f"head.head.{ lora_down_key } .weight"
1984+ )
1985+ if any (f"head.head.{ lora_up_key } .weight" in k for k in state_dict ):
1986+ converted_state_dict ["proj_out.lora_B.weight" ] = original_state_dict .pop (
1987+ f"head.head.{ lora_up_key } .weight"
1988+ )
19851989 if "head.head.diff_b" in original_state_dict :
19861990 converted_state_dict ["proj_out.lora_B.bias" ] = original_state_dict .pop ("head.head.diff_b" )
19871991
1992+ # Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
1993+ # This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
1994+ # Since for this particular LoRA, we don't have the corresponding up matrix, I will use
1995+ # an identity.
1996+ if any ("head.head" in k and k .endswith (".diff" ) for k in state_dict ):
1997+ if f"head.head.{ lora_down_key } .weight" in state_dict :
1998+ logger .info (
1999+ f"The state dict seems to be have both `head.head.diff` and `head.head.{ lora_down_key } .weight` keys, which is unexpected."
2000+ )
2001+ converted_state_dict ["proj_out.lora_A.weight" ] = original_state_dict .pop ("head.head.diff" )
2002+ down_matrix_head = converted_state_dict ["proj_out.lora_A.weight" ]
2003+ up_matrix_shape = (down_matrix_head .shape [0 ], converted_state_dict ["proj_out.lora_B.bias" ].shape [0 ])
2004+ converted_state_dict ["proj_out.lora_B.weight" ] = torch .eye (
2005+ * up_matrix_shape , dtype = down_matrix_head .dtype , device = down_matrix_head .device
2006+ ).T
2007+
19882008 for text_time in ["text_embedding" , "time_embedding" ]:
19892009 if any (text_time in k for k in original_state_dict ):
19902010 for b_n in [0 , 2 ]:
0 commit comments