diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index a63654bc998b..44a58fa2a815 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -51,7 +51,7 @@
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import resolve_interpolation_mode
+from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
@@ -860,10 +860,8 @@ def main(args):
 
     # Make sure the trainable params are in float32.
     if args.mixed_precision == "fp16":
-        for param in unet.parameters():
-            # only upcast trainable parameters (LoRA) into fp32
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
 
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index c59036d13beb..a995eb3043dc 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -53,7 +53,7 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,
@@ -1086,11 +1086,8 @@ def load_model_hook(models, input_dir):
             models = [unet_]
             if args.train_text_encoder:
                 models.extend([text_encoder_one_, text_encoder_two_])
-            for model in models:
-                for param in model.parameters():
-                    # only upcast trainable parameters (LoRA) into fp32
-                    if param.requires_grad:
-                        param.data = param.to(torch.float32)
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1110,11 +1107,9 @@ def load_model_hook(models, input_dir):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
 
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 27dedc8f7fd1..959d6d362d1e 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -43,7 +43,7 @@
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
@@ -466,10 +466,8 @@ def main():
     # Add adapter and make sure the trainable params are in float32.
     unet.add_adapter(unet_lora_config)
     if args.mixed_precision == "fp16":
-        for param in unet.parameters():
-            # only upcast trainable parameters (LoRA) into fp32
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 606a88f55b32..32345347f11a 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -51,7 +51,7 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
@@ -634,11 +634,8 @@ def main(args):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 8ff904305242..596e5c4868fe 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -1,7 +1,7 @@
 import contextlib
 import copy
 import random
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import numpy as np
 import torch
@@ -121,6 +121,16 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     return lora_state_dict
 
 
+def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    if not isinstance(model, list):
+        model = [model]
+    for m in model:
+        for param in m.parameters():
+            # only upcast trainable parameters into fp32
+            if param.requires_grad:
+                param.data = param.to(dtype)
+
+
 def _set_state_dict_into_text_encoder(
     lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
 ):
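
Usage note: a minimal sketch of how the new `cast_training_params` helper is meant to be called in a mixed-precision LoRA setup. The checkpoint id and the LoRA rank/target modules below are illustrative assumptions, not values taken from the scripts in this patch.

import torch
from peft import LoraConfig
from diffusers import UNet2DConditionModel
from diffusers.training_utils import cast_training_params

# Load the base model in fp16 (assumed checkpoint; any SD UNet works the same way).
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# Attach LoRA adapters; only the adapter parameters have requires_grad=True.
unet_lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_k", "to_q", "to_v", "to_out.0"])
unet.add_adapter(unet_lora_config)

# Upcast only the trainable (LoRA) parameters to fp32 for numerically stable training,
# while the frozen base weights stay in fp16. A list of models
# (e.g. [unet, text_encoder]) is also accepted.
cast_training_params(unet, dtype=torch.float32)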