File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -837,11 +837,6 @@ def main(args):
837837 assert torch .all (flux_transformer .x_embedder .weight [:, initial_input_channels :].data == 0 )
838838 flux_transformer .register_to_config (in_channels = initial_input_channels * 2 , out_channels = initial_input_channels )
839839
840- if args .train_norm_layers :
841- for name , param in flux_transformer .named_parameters ():
842- if any (k in name for k in NORM_LAYER_PREFIXES ):
843- param .requires_grad = True
844-
845840 if args .lora_layers is not None :
846841 if args .lora_layers != "all-linear" :
847842 target_modules = [layer .strip () for layer in args .lora_layers .split ("," )]
@@ -879,6 +874,11 @@ def main(args):
879874 )
880875 flux_transformer .add_adapter (transformer_lora_config )
881876
877+ if args .train_norm_layers :
878+ for name , param in flux_transformer .named_parameters ():
879+ if any (k in name for k in NORM_LAYER_PREFIXES ):
880+ param .requires_grad = True
881+
882882 def unwrap_model (model ):
883883 model = accelerator .unwrap_model (model )
884884 model = model ._orig_mod if is_compiled_module (model ) else model
You can’t perform that action at this time.
0 commit comments