@@ -14,6 +14,7 @@
 
 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata
 
 
 if is_transformers_available():
@@ -62,6 +64,7 @@
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
 
 
 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    metadata=None,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +240,14 @@ def _fetch_state_dict(
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                metadata = _load_sft_state_dict_metadata(model_file)
+
             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
                 model_file = None
+                metadata = None
                 pass
 
         if model_file is None:
@@ -261,10 +268,11 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            metadata = None
     else:
         state_dict = pretrained_model_name_or_path_or_dict
 
-    return state_dict
+    return state_dict, metadata
 
 
 def _best_guess_weight_name(
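# --- Editor's note (illustrative, not part of the diff): `_fetch_state_dict` now returns a
# `(state_dict, metadata)` pair instead of a bare state dict, so call sites unpack both values,
# e.g. `state_dict, metadata = _fetch_state_dict(...)` (arguments elided here). `metadata` is
# only non-None when the weights came from a safetensors file whose header carried LoRA adapter
# metadata; for `.bin` checkpoints and raw dict inputs it stays None.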
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
     return weight_name
 
 
+def _pack_dict_with_prefix(state_dict, prefix):
+    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+    return sd_with_prefix
+
+
 def _load_lora_into_text_encoder(
     state_dict,
     network_alphas,
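# --- Editor's note (illustrative, not part of the diff): a minimal sketch of what the new
# `_pack_dict_with_prefix` helper does; the tensor below is a placeholder value.
import torch

packed = _pack_dict_with_prefix({"lora_A.weight": torch.zeros(4, 4)}, "text_encoder")
assert list(packed) == ["text_encoder.lora_A.weight"]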
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
     _pipeline=None,
     low_cpu_mem_usage=False,
     hotswap: bool = False,
+    metadata=None,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")
 
+    if network_alphas and metadata:
+        raise ValueError("`network_alphas` and `metadata` cannot both be specified at the same time.")
+
     peft_kwargs = {}
     if low_cpu_mem_usage:
         if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
         state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        if metadata is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
 
     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
 
-        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+        if metadata is not None:
+            lora_config_kwargs = metadata
+        else:
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
 
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
                 if is_peft_version("<=", "0.13.2"):
                     lora_config_kwargs.pop("lora_bias")
 
-        lora_config = LoraConfig(**lora_config_kwargs)
+        try:
+            lora_config = LoraConfig(**lora_config_kwargs)
+        except TypeError as e:
+            raise TypeError("`LoraConfig` class could not be instantiated.") from e
 
         # adapter_name
         if adapter_name is None:
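# --- Editor's note (illustrative, not part of the diff): when metadata was recovered from the
# checkpoint it is used verbatim as the `LoraConfig` keyword arguments, bypassing
# `get_peft_kwargs`. A minimal sketch of what such a dict may look like (keys and values here
# are assumed for illustration only):
from peft import LoraConfig

lora_config_kwargs = {"r": 16, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
# The try/except above surfaces a clear TypeError if the stored kwargs are incompatible.
lora_config = LoraConfig(**lora_config_kwargs)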
@@ -889,8 +914,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-        return layers_state_dict
+        return _pack_dict_with_prefix(layers_weights, prefix)
 
     @staticmethod
     def write_lora_layers(
@@ -900,16 +924,32 @@ def write_lora_layers(
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        if lora_adapter_metadata and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+            raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
         if save_function is None:
             if safe_serialization:
 
                 def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                    # Inject framework format.
+                    metadata = {"format": "pt"}
+                    if lora_adapter_metadata:
+                        for key, value in lora_adapter_metadata.items():
+                            if isinstance(value, set):
+                                lora_adapter_metadata[key] = list(value)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )
+
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
             else:
                 save_function = torch.save
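# --- Editor's note (illustrative, not part of the diff): the adapter metadata serialized above
# lands in the safetensors file header, so it can be read back with the plain safetensors API.
# `_load_sft_state_dict_metadata` is assumed to do essentially this; the file name below is just
# the default LORA_WEIGHT_NAME_SAFE.
import json

from safetensors import safe_open

with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    raw = f.metadata() or {}

lora_adapter_metadata = (
    json.loads(raw["lora_adapter_metadata"]) if "lora_adapter_metadata" in raw else None
)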