From 20cb3a84daf1ee942989871cb7da5f9a89a2f4fb Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 17 Jan 2025 14:44:38 +0000 Subject: [PATCH 1/5] enable dreambooth_lora on other devices Signed-off-by: jiqing-feng --- examples/dreambooth/train_dreambooth_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index e81fbe80576d..75a530daf1c3 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -151,14 +151,14 @@ def log_validation( if args.validation_images is None: images = [] for _ in range(args.num_validation_images): - with torch.cuda.amp.autocast(): + with torch.amp.autocast(pipeline.device.type): image = pipeline(**pipeline_args, generator=generator).images[0] images.append(image) else: images = [] for image in args.validation_images: image = Image.open(image) - with torch.cuda.amp.autocast(): + with torch.amp.autocast(pipeline.device.type): image = pipeline(**pipeline_args, image=image, generator=generator).images[0] images.append(image) From cd45e14db2b4d21b9ddcd60fca1b11a57743e2e9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 17 Jan 2025 14:46:41 +0000 Subject: [PATCH 2/5] enable xpu Signed-off-by: jiqing-feng --- examples/dreambooth/train_dreambooth_lora.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 75a530daf1c3..a0f5c85b2e82 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -178,6 +178,8 @@ def log_validation( del pipeline torch.cuda.empty_cache() + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() return images @@ -793,7 +795,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 elif args.prior_generation_precision == "fp16": @@ -831,6 +833,8 @@ def main(args): del pipeline if torch.cuda.is_available(): torch.cuda.empty_cache() + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() # Handle the repository creation if accelerator.is_main_process: @@ -1086,6 +1090,8 @@ def compute_text_embeddings(prompt): gc.collect() torch.cuda.empty_cache() + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() else: pre_computed_encoder_hidden_states = None validation_prompt_encoder_hidden_states = None From a3438487b9d964ab229f1bd96d90a9ba81fd75ec Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 17 Jan 2025 15:03:58 +0000 Subject: [PATCH 3/5] check cuda device before empty cache Signed-off-by: jiqing-feng --- examples/dreambooth/train_dreambooth_lora.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index a0f5c85b2e82..33f6812888e9 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -177,7 +177,8 @@ def log_validation( ) del pipeline - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() if hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.empty_cache() @@ -1089,7 +1090,8 @@ def compute_text_embeddings(prompt): tokenizer = None gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() if hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.empty_cache() else: From 660a0bfb9b70da56788cb763c276d369c3d04e11 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 21 Jan 2025 10:28:27 +0000 Subject: [PATCH 4/5] fix comment Signed-off-by: jiqing-feng --- examples/dreambooth/train_dreambooth_lora.py | 26 +++++++++----------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 33f6812888e9..610fcfc798a7 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -75,6 +75,13 @@ logger = get_logger(__name__) +def free_memory(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() + + def save_model_card( repo_id: str, images=None, @@ -151,14 +158,14 @@ def log_validation( if args.validation_images is None: images = [] for _ in range(args.num_validation_images): - with torch.amp.autocast(pipeline.device.type): + with torch.amp.autocast(accelerator.device.type): image = pipeline(**pipeline_args, generator=generator).images[0] images.append(image) else: images = [] for image in args.validation_images: image = Image.open(image) - with torch.amp.autocast(pipeline.device.type): + with torch.amp.autocast(accelerator.device.type): image = pipeline(**pipeline_args, image=image, generator=generator).images[0] images.append(image) @@ -177,10 +184,7 @@ def log_validation( ) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() - if hasattr(torch, "xpu") and torch.xpu.is_available(): - torch.xpu.empty_cache() + free_memory() return images @@ -832,10 +836,7 @@ def main(args): image.save(image_filename) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() - if hasattr(torch, "xpu") and torch.xpu.is_available(): - torch.xpu.empty_cache() + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -1090,10 +1091,7 @@ def compute_text_embeddings(prompt): tokenizer = None gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - if hasattr(torch, "xpu") and torch.xpu.is_available(): - torch.xpu.empty_cache() + free_memory() else: pre_computed_encoder_hidden_states = None validation_prompt_encoder_hidden_states = None From 83343defa5c6a1bd3d53ad889effde4198a54b49 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 21 Jan 2025 11:02:44 +0000 Subject: [PATCH 5/5] import free_memory Signed-off-by: jiqing-feng --- examples/dreambooth/train_dreambooth_lora.py | 13 +++++-------- src/diffusers/training_utils.py | 2 ++ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 610fcfc798a7..b3206d929f48 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -54,7 +54,11 @@ ) from diffusers.loaders import StableDiffusionLoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, + free_memory, +) from diffusers.utils import ( check_min_version, convert_state_dict_to_diffusers, @@ -75,13 +79,6 @@ logger = get_logger(__name__) -def free_memory(): - if torch.cuda.is_available(): - torch.cuda.empty_cache() - if hasattr(torch, "xpu") and torch.xpu.is_available(): - torch.xpu.empty_cache() - - def save_model_card( repo_id: str, images=None, diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 2474ed5c2114..082640f37a17 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -299,6 +299,8 @@ def free_memory(): torch.mps.empty_cache() elif is_torch_npu_available(): torch_npu.npu.empty_cache() + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14