From 6f6aa1b470bc861bd41d5635852a516e4e0d0b6e Mon Sep 17 00:00:00 2001
From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com>
Date: Thu, 7 Nov 2024 00:37:39 +0530
Subject: [PATCH 1/3] fix use_dora

---
 .../dreambooth/train_dreambooth_lora_sdxl.py | 35 +++++++++++++------
 1 file changed, 25 insertions(+), 10 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 6e621b3caee3..fbdea366ebf6 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -68,6 +68,7 @@
     convert_state_dict_to_kohya,
     convert_unet_state_dict_to_peft,
     is_wandb_available,
+    is_peft_version
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
@@ -1183,25 +1184,39 @@ def main(args):
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()

+    def get_lora_config(rank, use_dora, target_modules):
+        base_config = {
+            "r": rank,
+            "lora_alpha": rank,
+            "init_lora_weights": "gaussian",
+            "target_modules": target_modules,
+        }
+        if use_dora and is_peft_version("<", "0.9.0"):
+            raise ValueError(
+            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+        )
+        else:
+            base_config["use_dora"] = True
+
+        return LoraConfig(**base_config)
+
     # now we will add new LoRA weights to the attention layers
-    unet_lora_config = LoraConfig(
-        r=args.rank,
+    unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+    unet_lora_config = get_lora_config(
+        rank=args.rank,
         use_dora=args.use_dora,
-        lora_alpha=args.rank,
-        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=unet_target_modules
     )
     unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
+        text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+        text_lora_config = get_lora_config(
+            rank=args.rank,
             use_dora=args.use_dora,
-            lora_alpha=args.rank,
-            init_lora_weights="gaussian",
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            target_modules=text_target_modules
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
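Patch 1/3 factors the duplicated LoraConfig construction into a shared get_lora_config helper and gates DoRA support on the installed peft version. As written, though, the guard falls through incorrectly: when use_dora is False, the else branch still sets use_dora=True, so every run would train DoRA. Patch 3/3 below corrects this. A minimal sketch of the fallthrough, stubbing is_peft_version so it runs without diffusers and returning a plain dict instead of a peft LoraConfig:

# Sketch only: stub is_peft_version so this runs without diffusers installed.
def is_peft_version(op, version):
    return False  # pretend peft >= 0.9.0 is installed, so "<" checks are False

def get_lora_config_patch1(rank, use_dora, target_modules):
    # Same control flow as patch 1, returning a dict instead of a LoraConfig.
    base_config = {"r": rank, "lora_alpha": rank, "target_modules": target_modules}
    if use_dora and is_peft_version("<", "0.9.0"):
        raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")
    else:
        base_config["use_dora"] = True  # bug: reached even when use_dora is False

    return base_config

print(get_lora_config_patch1(4, use_dora=False, target_modules=["to_k"]))
# -> {'r': 4, 'lora_alpha': 4, 'target_modules': ['to_k'], 'use_dora': True}
#    DoRA ends up enabled even though the caller asked for plain LoRA.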
From 936cbeba3ee7dfbe5137bb5e26bf5ee6a3ad8923 Mon Sep 17 00:00:00 2001
From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com>
Date: Thu, 7 Nov 2024 00:43:06 +0530
Subject: [PATCH 2/3] fix style and quality

---
 .../dreambooth/train_dreambooth_lora_sdxl.py | 22 ++++++-------------
 1 file changed, 7 insertions(+), 15 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index fbdea366ebf6..cde8669e1dcc 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -67,8 +67,8 @@
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_unet_state_dict_to_peft,
+    is_peft_version,
     is_wandb_available,
-    is_peft_version
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
@@ -1193,31 +1193,23 @@ def get_lora_config(rank, use_dora, target_modules):
         }
         if use_dora and is_peft_version("<", "0.9.0"):
             raise ValueError(
-            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-        )
+                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+            )
         else:
             base_config["use_dora"] = True
-
+
         return LoraConfig(**base_config)
-
+
     # now we will add new LoRA weights to the attention layers
     unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
-    unet_lora_config = get_lora_config(
-        rank=args.rank,
-        use_dora=args.use_dora,
-        target_modules=unet_target_modules
-    )
+    unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
     unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-        text_lora_config = get_lora_config(
-            rank=args.rank,
-            use_dora=args.use_dora,
-            target_modules=text_target_modules
-        )
+        text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
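Patch 2/3 is style-only: it sorts is_peft_version into the import list with a trailing comma, re-indents the error message, strips trailing whitespace, and collapses the get_lora_config calls onto single lines. The is_peft_version utility these patches rely on compares the installed peft version against a target. A rough stand-in, assuming that behavior (the real implementation lives in diffusers.utils; this sketch needs peft and packaging installed):

# Rough stand-in for diffusers' is_peft_version, assuming it applies the given
# comparison operator to the installed peft version and a target version.
import operator
from importlib.metadata import version as installed

from packaging.version import parse

_OPS = {"<": operator.lt, "<=": operator.le, "==": operator.eq,
        ">=": operator.ge, ">": operator.gt}

def is_peft_version(op: str, target: str) -> bool:
    # e.g. is_peft_version("<", "0.9.0") -> True only if peft is older than 0.9.0
    return _OPS[op](parse(installed("peft")), parse(target))

print(is_peft_version("<", "0.9.0"))  # False on any peft >= 0.9.0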
From 2b843349b8af29d6422e8d1f3c55082b2ccc1dad Mon Sep 17 00:00:00 2001
From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com>
Date: Thu, 7 Nov 2024 01:26:26 +0530
Subject: [PATCH 3/3] fix use_dora with peft version

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index cde8669e1dcc..9cd321f6d055 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1191,12 +1191,13 @@ def get_lora_config(rank, use_dora, target_modules):
             "init_lora_weights": "gaussian",
             "target_modules": target_modules,
         }
-        if use_dora and is_peft_version("<", "0.9.0"):
-            raise ValueError(
-                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-            )
-        else:
-            base_config["use_dora"] = True
+        if use_dora:
+            if is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+            else:
+                base_config["use_dora"] = True

         return LoraConfig(**base_config)
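Patch 3/3 nests the version check under if use_dora:, so plain-LoRA runs never touch the gate and use_dora is only written into the config when the caller actually requests DoRA. A usage sketch of the final helper (illustrative only: in the script the function is defined inside main() and fed args.rank / args.use_dora; this assumes peft >= 0.9.0 is installed so neither call raises):

# Usage sketch of the final get_lora_config, assuming peft >= 0.9.0.
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]

# Plain LoRA: the version gate is skipped and use_dora is never set,
# leaving LoraConfig's default (False).
lora_config = get_lora_config(rank=8, use_dora=False, target_modules=unet_target_modules)

# DoRA: only enabled when the installed peft supports it; on older peft
# the helper raises ValueError asking for an upgrade.
dora_config = get_lora_config(rank=8, use_dora=True, target_modules=unet_target_modules)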