@@ -973,3 +973,178 @@ def swap_scale_shift(weight):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
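+    # The original checkpoint stores the final AdaLN modulation as (shift, scale);
+    # diffusers' norm_out.linear expects (scale, shift), so the two halves are swapped.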
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
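+    # Token-refiner ("txt_in") keys need renaming plus a split of the fused
+    # self_attn_qkv projection into separate to_q / to_k / to_v entries.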
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
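+    # For the fused image-stream QKV, the LoRA down projection (lora_A) is shared by
+    # Q, K and V and is copied to all three targets; the up projection (lora_B) spans
+    # the concatenated Q|K|V output dimension and is chunked into three equal parts.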
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
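+    # Same scheme for the fused text-stream QKV, mapped onto the added (context)
+    # projections add_q_proj / add_k_proj / add_v_proj.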
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
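+    # In single blocks, linear1 fuses Q, K, V and the MLP input projection, so a
+    # lora_B weight/bias is split into (hidden, hidden, hidden, remainder) while the
+    # shared lora_A (down) projection is copied to all four targets.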
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+                q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
+                state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
+                state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
+                state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
+
+        elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+                q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
+                state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
+                state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
+                state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
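+    # Plain substring renames applied to every key.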
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
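+    # Keys containing these substrings are routed to the in-place remap helpers above.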
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    # Some folks attempt to make their state dict compatible with diffusers by adding a "transformer." prefix to
+    # all keys and using custom code. To make sure both "original" and "attempted diffusers" LoRAs work as expected,
+    # we normalize both to the same initial format by stripping off the "transformer." prefix.
+    for key in list(converted_state_dict.keys()):
+        if key.startswith("transformer."):
+            converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+        if key.startswith("diffusion_model."):
+            converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+    # Rename and remap the state dict keys
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    # Add back the "transformer." prefix
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
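
For reference, a minimal sketch of how this converter could be exercised on its own. The checkpoint filename is illustrative, and in practice the function is presumably invoked by the LoRA loading path rather than called directly:

    from safetensors.torch import load_file

    # Hypothetical local file holding a HunyuanVideo LoRA in the original key layout.
    original_lora = load_file("hunyuan_video_lora.safetensors")
    diffusers_lora = _convert_hunyuan_video_lora_to_diffusers(original_lora)

    # Every key now uses diffusers module names under the "transformer." prefix, e.g.
    # "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight".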