Commit 5ef74fd

fix norm not training in train_control_lora_flux.py (#11832)
1 parent 64a9210 commit 5ef74fd

File tree

1 file changed (+5 −5 lines)

examples/flux-control/train_control_lora_flux.py

Lines changed: 5 additions & 5 deletions
@@ -837,11 +837,6 @@ def main(args):
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
-    if args.train_norm_layers:
-        for name, param in flux_transformer.named_parameters():
-            if any(k in name for k in NORM_LAYER_PREFIXES):
-                param.requires_grad = True
-
     if args.lora_layers is not None:
         if args.lora_layers != "all-linear":
             target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
     )
     flux_transformer.add_adapter(transformer_lora_config)
 
+    if args.train_norm_layers:
+        for name, param in flux_transformer.named_parameters():
+            if any(k in name for k in NORM_LAYER_PREFIXES):
+                param.requires_grad = True
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
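
Why moving the block fixes training: diffusers' add_adapter goes through PEFT's adapter injection, which marks only the newly added LoRA weights as trainable and freezes every other parameter, so requires_grad flags set on the norm layers beforehand appear to be silently overwritten. Below is a minimal sketch of that ordering effect; the model, NORM_PREFIXES, enable_norms and norms_trainable are illustrative stand-ins (not from the training script) for flux_transformer, NORM_LAYER_PREFIXES and the loops in the diff.

# Minimal sketch, assuming PEFT's inject_adapter_in_model freezes non-adapter params
# (the same call diffusers' add_adapter uses under the hood).
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

model = nn.TransformerEncoderLayer(d_model=64, nhead=4)  # stand-in for flux_transformer
NORM_PREFIXES = ("norm1", "norm2")                        # stand-in for NORM_LAYER_PREFIXES

def enable_norms(m):
    # Mirror the diff: turn on gradients for parameters whose names match the norm prefixes.
    for name, param in m.named_parameters():
        if any(k in name for k in NORM_PREFIXES):
            param.requires_grad = True

def norms_trainable(m):
    return all(p.requires_grad for n, p in m.named_parameters() if any(k in n for k in NORM_PREFIXES))

lora_config = LoraConfig(r=8, target_modules=["linear1", "linear2"])

# Old order: enable norms first, then inject the adapter -> injection freezes them again.
enable_norms(model)
model = inject_adapter_in_model(lora_config, model)
print(norms_trainable(model))  # False: the earlier requires_grad flags were cleared

# New order (what this commit does): inject the adapter first, then enable the norms.
enable_norms(model)
print(norms_trainable(model))  # True: norm layers stay trainable alongside the LoRA weights

Enabling the norm parameters only after the adapter is attached keeps them trainable, which matches the intent of --train_norm_layers in the script.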
