@@ -424,6 +424,17 @@ def _load_lora_into_text_encoder(
 
 
 def _func_optionally_disable_offloading(_pipeline):
427+ """
428+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
429+
430+ Args:
431+ _pipeline (`DiffusionPipeline`):
432+ The pipeline to disable offloading for.
433+
434+ Returns:
435+ tuple:
436+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
437+ """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
 
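For context, a minimal usage sketch of this helper around a weight-loading step; it assumes a `DiffusionPipeline` instance `pipeline` and the standard `enable_model_cpu_offload` / `enable_sequential_cpu_offload` methods:

```py
from diffusers.loaders.lora_base import _func_optionally_disable_offloading

# Capture and disable any active offloading before touching modules.
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline=pipeline)

# ... load or patch LoRA weights while the offload hooks are removed ...

# Restore whichever offload mode was active beforehand.
if is_model_cpu_offload:
    pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
    pipeline.enable_sequential_cpu_offload()
```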
@@ -453,6 +464,24 @@ class LoraBaseMixin:
     _lora_loadable_modules = []
     _merged_adapters = set()
 
+    @property
+    def lora_scale(self) -> float:
+        """
+        Returns the LoRA scale, which can be set at runtime by the pipeline. If `_lora_scale` has not been set,
+        returns 1.0.
+        """
+        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+    @property
+    def num_fused_loras(self):
+        """Returns the number of LoRAs that have been fused."""
+        return len(self._merged_adapters)
+
+    @property
+    def fused_loras(self):
+        """Returns the names of the LoRAs that have been fused."""
+        return self._merged_adapters
+
     def load_lora_weights(self, **kwargs):
         raise NotImplementedError("`load_lora_weights()` is not implemented.")
 
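A hedged sketch of the three relocated properties in use, reusing the SDXL setup from the docstring examples below; the printed values assume a single fused adapter:

```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.fuse_lora()

print(pipeline.lora_scale)       # 1.0 unless `_lora_scale` was set at runtime
print(pipeline.num_fused_loras)  # 1
print(pipeline.fused_loras)      # {"cinematic"} -- a set of adapter names
```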
@@ -464,33 +493,6 @@ def save_lora_weights(cls, **kwargs):
     def lora_state_dict(cls, **kwargs):
         raise NotImplementedError("`lora_state_dict()` is not implemented.")
 
-    @classmethod
-    def _optionally_disable_offloading(cls, _pipeline):
-        """
-        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
-        Args:
-            _pipeline (`DiffusionPipeline`):
-                The pipeline to disable offloading for.
-
-        Returns:
-            tuple:
-                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
-        """
-        return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
-    @classmethod
-    def _fetch_state_dict(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
-        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
-        return _fetch_state_dict(*args, **kwargs)
-
-    @classmethod
-    def _best_guess_weight_name(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
-        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
-        return _best_guess_weight_name(*args, **kwargs)
-
     def unload_lora_weights(self):
         """
         Unloads the LoRA parameters.
@@ -661,19 +663,37 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
                             self._merged_adapters = self._merged_adapters - {adapter}
                         module.unmerge()
 
-    @property
-    def num_fused_loras(self):
-        return len(self._merged_adapters)
-
-    @property
-    def fused_loras(self):
-        return self._merged_adapters
-
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
671+ """
672+ Set the currently active adapters for use in the pipeline.
673+
674+ Args:
675+ adapter_names (`List[str]` or `str`):
676+ The names of the adapters to use.
677+ adapter_weights (`Union[List[float], float]`, *optional*):
678+ The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
679+ adapters.
680+
681+ Example:
682+
683+ ```py
684+ from diffusers import AutoPipelineForText2Image
685+ import torch
686+
687+ pipeline = AutoPipelineForText2Image.from_pretrained(
688+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
689+ ).to("cuda")
690+ pipeline.load_lora_weights(
691+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
692+ )
693+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
694+ pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
695+ ```
696+ """
         if isinstance(adapter_weights, dict):
             components_passed = set(adapter_weights.keys())
             lora_components = set(self._lora_loadable_modules)
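Since the signature and the `isinstance(adapter_weights, dict)` branch also accept per-component dicts, here is a hedged sketch; the keys `"unet"` and `"text_encoder"` are assumptions that must match the pipeline's `_lora_loadable_modules`:

```py
# Hypothetical per-component weighting; valid keys come from
# `pipeline._lora_loadable_modules` (e.g. "unet", "text_encoder" on SDXL).
pipeline.set_adapters(
    ["cinematic", "pixel"],
    adapter_weights=[{"unet": 1.0, "text_encoder": 0.5}, 0.7],
)
```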
@@ -743,6 +763,24 @@ def set_adapters(
                 set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
 
     def disable_lora(self):
766+ """
767+ Disables the active LoRA layers of the pipeline.
768+
769+ Example:
770+
771+ ```py
772+ from diffusers import AutoPipelineForText2Image
773+ import torch
774+
775+ pipeline = AutoPipelineForText2Image.from_pretrained(
776+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
777+ ).to("cuda")
778+ pipeline.load_lora_weights(
779+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
780+ )
781+ pipeline.disable_lora()
782+ ```
783+ """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -755,6 +793,24 @@ def disable_lora(self):
                 disable_lora_for_text_encoder(model)
 
     def enable_lora(self):
796+ """
797+ Enables the active LoRA layers of the pipeline.
798+
799+ Example:
800+
801+ ```py
802+ from diffusers import AutoPipelineForText2Image
803+ import torch
804+
805+ pipeline = AutoPipelineForText2Image.from_pretrained(
806+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
807+ ).to("cuda")
808+ pipeline.load_lora_weights(
809+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
810+ )
811+ pipeline.enable_lora()
812+ ```
813+ """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
@@ -768,10 +824,26 @@ def enable_lora(self):
 
     def delete_adapters(self, adapter_names: Union[List[str], str]):
         """
+        Deletes the LoRA layers of the given adapter(s) from the pipeline.
+
         Args:
-        Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
             adapter_names (`Union[List[str], str]`):
-                The names of the adapter to delete. Can be a single string or a list of strings
+                The names of the adapters to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
775847 """
776848 if not USE_PEFT_BACKEND :
777849 raise ValueError ("PEFT backend is required for this method." )
@@ -872,6 +944,24 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                         adapter_name
                     ].to(device)
 
+    def enable_lora_hotswap(self, **kwargs) -> None:
+        """
+        Enables hotswapping of LoRA adapters without triggering recompilation of a compiled model, even when the
+        ranks of the loaded adapters differ.
+
+        Args:
+            target_rank (`int`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle the case when the model is already compiled. The options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
+        """
+        for key, component in self.components.items():
+            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+                component.enable_lora_hotswap(**kwargs)
+
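A hedged end-to-end sketch of the hotswap workflow this method enables; the `hotswap=True` flag on `load_lora_weights` follows the hotswapping recipe in the diffusers docs, and the rank value is illustrative:

```py
# Call before loading the first adapter; 64 is an illustrative upper bound on rank.
pipeline.enable_lora_hotswap(target_rank=64, check_compiled="warn")

pipeline.load_lora_weights(
    "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="default"
)
pipeline.unet = torch.compile(pipeline.unet)  # compile once

# Swap a different LoRA into the same adapter slot without recompiling.
pipeline.load_lora_weights(
    "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="default", hotswap=True
)
```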
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
@@ -887,6 +977,7 @@ def write_lora_layers(
         safe_serialization: bool,
         lora_adapter_metadata: Optional[dict] = None,
     ):
+        """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
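For orientation, a hedged sketch of how `pack_weights` and `write_lora_layers` combine in a typical `save_lora_weights` implementation; `unet_lora_layers` is a hypothetical layers dict, and `is_main_process` is assumed from the part of the signature not shown in this hunk:

```py
# Hypothetical caller: namespace each component's layers, then write once.
state_dict = {}
state_dict.update(LoraBaseMixin.pack_weights(unet_lora_layers, prefix="unet"))

LoraBaseMixin.write_lora_layers(
    state_dict=state_dict,
    save_directory="./my-lora",
    is_main_process=True,  # assumed parameter, not shown in this hunk
    weight_name="pytorch_lora_weights.safetensors",
    save_function=None,  # falls back to the default `save_function` defined above
    safe_serialization=True,
    lora_adapter_metadata=None,
)
```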
@@ -927,28 +1018,18 @@ def save_function(weights, filename):
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
 
-    @property
-    def lora_scale(self) -> float:
-        # property function that returns the lora scale which can be set at run time by the pipeline.
-        # if _lora_scale has not been set, return 1
-        return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
-
-    def enable_lora_hotswap(self, **kwargs) -> None:
-        """Enables the possibility to hotswap LoRA adapters.
+    @classmethod
+    def _optionally_disable_offloading(cls, _pipeline):
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
-        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
-        the loaded adapters differ.
+    @classmethod
+    def _fetch_state_dict(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+        return _fetch_state_dict(*args, **kwargs)
 
-        Args:
-            target_rank (`int`):
-                The highest rank among all the adapters that will be loaded.
-            check_compiled (`str`, *optional*, defaults to `"error"`):
-                How to handle the case when the model is already compiled, which should generally be avoided. The
-                options are:
-                - "error" (default): raise an error
-                - "warn": issue a warning
-                - "ignore": do nothing
-        """
-        for key, component in self.components.items():
-            if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
-                component.enable_lora_hotswap(**kwargs)
+    @classmethod
+    def _best_guess_weight_name(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+        return _best_guess_weight_name(*args, **kwargs)
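As the deprecation messages above spell out, the module-level helpers are the forward-compatible entry points; a minimal sketch of the migration:

```py
# Preferred, per the deprecation messages (the classmethod shims delegate here
# and are slated for removal after 0.35.0).
from diffusers.loaders.lora_base import _best_guess_weight_name, _fetch_state_dict
```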