@@ -291,9 +291,7 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):

         return modules_to_save

-    def check_if_adapters_added_correctly(
-        self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
-    ):
+    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -345,7 +343,7 @@ def test_simple_inference_with_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -428,7 +426,7 @@ def test_low_cpu_mem_usage_with_loading(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -484,7 +482,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -522,7 +520,7 @@ def test_simple_inference_with_text_lora_fused(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.fuse_lora()
         # Fusing should still keep the LoRA layers
@@ -554,7 +552,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -589,7 +587,7 @@ def test_simple_inference_with_text_lora_save_load(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -640,7 +638,7 @@ def test_simple_inference_with_partial_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         state_dict = {}
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -691,7 +689,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -734,7 +732,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -775,7 +773,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -819,7 +817,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

@@ -857,7 +855,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -893,7 +891,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
         self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1010,7 +1008,7 @@ def test_wrong_adapter_name_raises_error(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )

@@ -1032,7 +1030,7 @@ def test_multiple_wrong_adapter_name_raises_error(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )

@@ -1759,7 +1757,7 @@ def test_simple_inference_with_dora(self):
         output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_dora_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -1850,7 +1848,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1937,7 +1935,7 @@ def test_set_adapters_match_attention_kwargs(self):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         lora_scale = 0.5
         attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2119,7 +2117,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
             pipe = pipe.to(torch_device, dtype=compute_dtype)
             pipe.set_progress_bar_config(disable=None)

-            pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+            pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

             if storage_dtype is not None:
                 denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2237,7 +2235,7 @@ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
         )
         pipe = self.pipeline_class(**components)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )

@@ -2290,7 +2288,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2309,6 +2307,25 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
             np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
         )

+    def test_lora_unload_add_adapter(self):
+        """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # unload and then add.
+        pipe.unload_lora_weights()
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
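For readers outside the test suite, the end-to-end flow that the new test_lora_unload_add_adapter pins down looks roughly like the sketch below. This is a minimal sketch, not the test itself: the checkpoint id, prompt, and LoraConfig values are illustrative placeholders, while add_adapter() (the PEFT integration on diffusers/transformers models) and unload_lora_weights() are the actual entry points exercised in the diff.

# Minimal sketch of the unload -> add flow covered by test_lora_unload_add_adapter.
# Assumes a Stable Diffusion pipeline; the model id, prompt, and LoRA ranks/targets
# below are illustrative only.
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

text_lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
denoiser_lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])

# Attach fresh adapters to the text encoder and the denoiser, as the helper does.
pipe.text_encoder.add_adapter(text_lora_config)
pipe.unet.add_adapter(denoiser_lora_config)
_ = pipe("a prompt", num_inference_steps=2, generator=torch.manual_seed(0))

# Unloading removes the LoRA layers; adding adapters again should still work.
pipe.unload_lora_weights()
pipe.text_encoder.add_adapter(text_lora_config)
pipe.unet.add_adapter(denoiser_lora_config)
_ = pipe("a prompt", num_inference_steps=2, generator=torch.manual_seed(0))

In other words, the test asserts that unload_lora_weights() leaves the components in a clean enough state that new adapters can be attached and run without re-instantiating the pipeline.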