From a4821dc6eee401f97ff6e26b0953276a911f6ea1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 19 Jan 2025 18:53:12 +0530 Subject: [PATCH] set rest of the blocks with requires_grad False. --- examples/flux-control/train_control_flux.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py index 7d0e28069054..4449811ab747 100644 --- a/examples/flux-control/train_control_flux.py +++ b/examples/flux-control/train_control_flux.py @@ -812,6 +812,8 @@ def main(args): for name, module in flux_transformer.named_modules(): if "transformer_blocks" in name: module.requires_grad_(True) + else: + module.requirs_grad_(False) def unwrap_model(model): model = accelerator.unwrap_model(model)