2 changes: 2 additions & 0 deletions src/diffusers/loaders/peft.py
@@ -686,6 +686,8 @@ def unload_lora(self):
         recurse_remove_peft_layers(self)
         if hasattr(self, "peft_config"):
             del self.peft_config
+        if hasattr(self, "_hf_peft_config_loaded"):
+            self._hf_peft_config_loaded = None
 
     def disable_lora(self):
         """
67 changes: 43 additions & 24 deletions tests/lora/utils.py
@@ -290,9 +290,7 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
 
         return modules_to_save
 
-    def check_if_adapters_added_correctly(
-        self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
-    ):
+    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -344,7 +342,7 @@ def test_simple_inference_with_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -427,7 +425,7 @@ def test_low_cpu_mem_usage_with_loading(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -483,7 +481,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -521,7 +519,7 @@ def test_simple_inference_with_text_lora_fused(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         pipe.fuse_lora()
         # Fusing should still keep the LoRA layers
@@ -553,7 +551,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -588,7 +586,7 @@ def test_simple_inference_with_text_lora_save_load(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -639,7 +637,7 @@ def test_simple_inference_with_partial_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
 
         state_dict = {}
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -690,7 +688,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -733,7 +731,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -774,7 +772,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -818,7 +816,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
 
@@ -856,7 +854,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -892,7 +890,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
         self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1009,7 +1007,7 @@ def test_wrong_adapter_name_raises_error(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )
 
@@ -1031,7 +1029,7 @@ def test_multiple_wrong_adapter_name_raises_error(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )
 
@@ -1758,7 +1756,7 @@ def test_simple_inference_with_dora(self):
         output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
@@ -1849,7 +1847,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1936,7 +1934,7 @@ def test_set_adapters_match_attention_kwargs(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
         lora_scale = 0.5
         attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2118,7 +2116,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
             pipe = pipe.to(torch_device, dtype=compute_dtype)
             pipe.set_progress_bar_config(disable=None)
 
-            pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+            pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
 
             if storage_dtype is not None:
                 denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2236,7 +2234,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
         )
         pipe = self.pipeline_class(**components)
 
-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
 
@@ -2289,7 +2287,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2308,6 +2306,27 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
             np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
         )
 
+    def test_lora_unload_add_adapter(self):
+        """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # unload and then add.
+        pipe.unload_lora_weights()
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+
+        output_lora_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(np.allclose(output_lora, output_lora_2, atol=1e-3, rtol=1e-3), "Lora outputs should match.")
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
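For context, here is a minimal sketch of the user-facing sequence this PR repairs: adding an adapter, unloading LoRA weights, then adding a fresh adapter to the same model. It is illustrative only; the checkpoint id, LoraConfig values, and device are assumptions, not taken from the diff.

# Hypothetical repro sketch (not part of the diff). Exercises the
# unload_lora_weights() -> add_adapter() path; the checkpoint id and
# LoraConfig values below are assumptions for illustration.
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])

pipe.unet.add_adapter(config, adapter_name="default")
pipe.unload_lora_weights()  # removes the PEFT layers; with this PR it also resets `_hf_peft_config_loaded`

# Before this PR, the stale `_hf_peft_config_loaded` flag left over from the
# first add_adapter() could make this second call misbehave; after the fix it
# behaves like adding an adapter to a fresh model.
pipe.unet.add_adapter(config, adapter_name="default")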