@@ -1605,9 +1605,18 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     if diff_keys:
         for diff_k in diff_keys:
             param = original_state_dict[diff_k]
+            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are up to 1e-3,
+            # and 2 of them are about 1.6e-2 [the case with the AccVideo LoRA]). The low magnitudes mostly correspond
+            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
+            # is okay to ignore them because they do not affect the model output in a significant manner.
+            threshold = 1.6e-2
+            absdiff = param.abs().max() - param.abs().min()
             all_zero = torch.all(param == 0).item()
-            if all_zero:
-                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+            all_absdiff_lower_than_threshold = absdiff < threshold
+            if all_zero or all_absdiff_lower_than_threshold:
+                logger.debug(
+                    f"Removed {diff_k} key from the state dict as it's all zeros or its values are below the hardcoded threshold."
+                )
                 original_state_dict.pop(diff_k)
 
     # For the `diff_b` keys, we treat them as lora_bias.
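As a side note, here is a minimal standalone sketch (toy tensors and hypothetical key names, not from the actual change) of how this magnitude-based pruning behaves: a `.diff` tensor is dropped when it is all zeros or when the spread of its absolute values stays below the hardcoded 1.6e-2 threshold.

```python
import torch

threshold = 1.6e-2
toy_state_dict = {
    "blocks.0.norm3.diff": torch.full((4,), 1e-4),  # tiny norm-layer offsets -> pruned
    "blocks.0.attn1.diff": torch.zeros(4),  # all zeros -> pruned
    "blocks.0.ffn.0.diff": torch.tensor([0.0, 0.05, -0.03, 0.02]),  # spread > threshold -> kept
}

for k in list(toy_state_dict):
    param = toy_state_dict[k]
    absdiff = param.abs().max() - param.abs().min()
    if torch.all(param == 0).item() or absdiff < threshold:
        toy_state_dict.pop(k)

print(sorted(toy_state_dict))  # ['blocks.0.ffn.0.diff']
```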
@@ -1655,12 +1664,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
16551664
16561665 # FFN
16571666 for o , c in zip (["ffn.0" , "ffn.2" ], ["net.0.proj" , "net.2" ]):
1658- converted_state_dict [f"blocks.{ i } .ffn.{ c } .lora_A.weight" ] = original_state_dict .pop (
1659- f"blocks.{ i } .{ o } .{ lora_down_key } .weight"
1660- )
1661- converted_state_dict [f"blocks.{ i } .ffn.{ c } .lora_B.weight" ] = original_state_dict .pop (
1662- f"blocks.{ i } .{ o } .{ lora_up_key } .weight"
1663- )
1667+ original_key = f"blocks.{ i } .{ o } .{ lora_down_key } .weight"
1668+ converted_key = f"blocks.{ i } .ffn.{ c } .lora_A.weight"
1669+ if original_key in original_state_dict :
1670+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1671+
1672+ original_key = f"blocks.{ i } .{ o } .{ lora_up_key } .weight"
1673+ converted_key = f"blocks.{ i } .ffn.{ c } .lora_B.weight"
1674+ if original_key in original_state_dict :
1675+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1676+
16641677 if f"blocks.{ i } .{ o } .diff_b" in original_state_dict :
16651678 converted_state_dict [f"blocks.{ i } .ffn.{ c } .lora_B.bias" ] = original_state_dict .pop (
16661679 f"blocks.{ i } .{ o } .diff_b"
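Review-style note: the same guarded pop-and-rename pattern now repeats across several hunks and could be factored into a tiny helper. A sketch under that assumption (the `_maybe_pop` name is hypothetical and not part of this change; `i`, `lora_down_key`, and `lora_up_key` come from the enclosing loop in the real function):

```python
def _maybe_pop(src: dict, dst: dict, original_key: str, converted_key: str) -> None:
    # Move a weight only if the checkpoint actually contains it, so LoRAs
    # trained without a given module convert cleanly instead of raising KeyError.
    if original_key in src:
        dst[converted_key] = src.pop(original_key)


for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
    for lora_key, lora_name in [(lora_down_key, "lora_A"), (lora_up_key, "lora_B")]:
        _maybe_pop(
            original_state_dict,
            converted_state_dict,
            f"blocks.{i}.{o}.{lora_key}.weight",
            f"blocks.{i}.ffn.{c}.{lora_name}.weight",
        )
```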
@@ -1669,12 +1682,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     # Remaining.
     if original_state_dict:
         if any("time_projection" in k for k in original_state_dict):
-            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_down_key}.weight"
-            )
-            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_up_key}.weight"
-            )
+            original_key = f"time_projection.1.{lora_down_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"time_projection.1.{lora_up_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
         if "time_projection.1.diff_b" in original_state_dict:
             converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
                 "time_projection.1.diff_b"
@@ -1709,6 +1726,20 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
                     original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
                 )
 
+        for img_ours, img_theirs in [
+            ("ff.net.0.proj", "img_emb.proj.1"),
+            ("ff.net.2", "img_emb.proj.3"),
+        ]:
+            original_key = f"{img_theirs}.{lora_down_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"{img_theirs}.{lora_up_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
     if len(original_state_dict) > 0:
         diff = all(".diff" in k for k in original_state_dict)
         if diff:
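For quick verification of the new image-embedder mapping in isolation — `img_emb.proj.1` and `img_emb.proj.3` are the two Linear layers of Wan's image-conditioning MLP, surfaced in diffusers as `condition_embedder.image_embedder.ff.net.0.proj` and `...ff.net.2` — a toy round trip (the tensor shapes and the `lora_down`/`lora_up` key names here are illustrative only):

```python
import torch

lora_down_key, lora_up_key = "lora_down", "lora_up"  # assumed non-diffusers naming
original_state_dict = {
    "img_emb.proj.1.lora_down.weight": torch.zeros(4, 8),
    "img_emb.proj.1.lora_up.weight": torch.zeros(8, 4),
    "img_emb.proj.3.lora_down.weight": torch.zeros(4, 8),
    "img_emb.proj.3.lora_up.weight": torch.zeros(8, 4),
}
converted_state_dict = {}

for img_ours, img_theirs in [("ff.net.0.proj", "img_emb.proj.1"), ("ff.net.2", "img_emb.proj.3")]:
    for lora_key, lora_name in [(lora_down_key, "lora_A"), (lora_up_key, "lora_B")]:
        original_key = f"{img_theirs}.{lora_key}.weight"
        if original_key in original_state_dict:
            converted_state_dict[
                f"condition_embedder.image_embedder.{img_ours}.{lora_name}.weight"
            ] = original_state_dict.pop(original_key)

assert not original_state_dict  # every image-embedder key was consumed
print(sorted(converted_state_dict))
```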