@@ -38,7 +38,6 @@ def __init__(self, base_layer: nn.Module, **kwargs):
         self.scaling = {}
         self.gralora_dropout = nn.ModuleDict({})

-        # Set to `None` otherwise to avoid computation with random weight
         self.gralora_A = nn.ParameterDict({})
         self.gralora_B = nn.ParameterDict({})
         self.gralora_A_general = nn.ModuleDict({})
@@ -55,57 +54,13 @@ def __init__(self, base_layer: nn.Module, **kwargs):
             in_features, out_features = (
                 base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
             )
+        else:
+            raise NotImplementedError(f"Unsupported layer type {type(base_layer)}")

         self.in_features = in_features
         self.out_features = out_features
         self.kwargs = kwargs

-    def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
-        """
-        Move the adapter of the given name to the device of the base layer.
-        """
-        from peft.tuners._buffer_dict import BufferDict
-
-        if device is None:
-            # check weight and qweight (for GPTQ)
-            for weight_name in ("weight", "qweight"):
-                weight = getattr(self.get_base_layer(), weight_name, None)
-                if weight is not None:
-                    device = weight.device
-                    dtype = weight.dtype
-                    break
-            else:
-                # no break encountered: could not determine the device
-                return
-
-        # loop through all potential adapter layers and move them to the device of the base layer; be careful to only
-        # move this specific adapter to the device, as the other adapters could be on different devices
-        # see #1639
-        for adapter_layer_name in self.adapter_layer_names + self.other_param_names:
-            adapter_layer = getattr(self, adapter_layer_name, None)
-            if not isinstance(adapter_layer, (nn.ModuleDict, nn.ParameterDict, BufferDict)):
-                continue
-            if adapter_name not in adapter_layer:
-                continue
-            if weight.dtype.is_floating_point or weight.dtype.is_complex:
-                adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=dtype)
-            else:
-                adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device)
-
-    @property
-    def merged(self) -> bool:
-        return bool(self.merged_adapters)
-
-    @property
-    def bias(self) -> torch.Tensor:
-        base_layer = self.get_base_layer()
-        if isinstance(base_layer, nn.Linear):
-            return base_layer.bias
-        elif isinstance(base_layer, Conv1D):
-            return base_layer.bias
-        else:
-            return None
-
     def update_layer(
         self,
         adapter_name,
@@ -119,6 +74,8 @@ def update_layer(
     ):
         if r <= 0:
             raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
+        elif hybrid_r < 0:
+            raise ValueError(f"`hybrid_r` should be a non-negative integer value but the value passed is {hybrid_r}")

         self.r[adapter_name] = r
         self.gralora_alpha[adapter_name] = gralora_alpha
@@ -133,21 +90,31 @@ def update_layer(
         self.gralora_dropout.update(nn.ModuleDict({adapter_name: gralora_dropout_layer}))

         # Actual trainable parameters
+        if self.in_features % gralora_k != 0:
+            raise ValueError(
+                f"in_features should be divisible by gralora_k, but got {self.in_features} and {gralora_k}"
+            )
+        if self.out_features % gralora_k != 0:
+            raise ValueError(
+                f"out_features should be divisible by gralora_k, but got {self.out_features} and {gralora_k}"
+            )
         subblock_in_features = self.in_features // gralora_k
         subblock_out_features = self.out_features // gralora_k

-        gralora_r = r - hybrid_r  # gralora_r is the rank allocated to gralora method
-        assert gralora_r % gralora_k == 0, f"r should be divisible by gralora_k, but got {r} and {gralora_k}"
+        # gralora_r is the rank allocated to GraLoRA method; hybrid_r is the rank allocated to vanilla LoRA
+        gralora_r = r
+        if gralora_r % gralora_k != 0:
+            raise ValueError(f"r should be divisible by gralora_k, but got {r} and {gralora_k}")

-        gralora_A = nn.ParameterList()
-        gralora_B = nn.ParameterList()
+        gralora_A = []
+        gralora_B = []
         for _ in range(gralora_k):
-            new_A = nn.Parameter(torch.zeros(gralora_r, subblock_in_features))
-            new_B = nn.Parameter(torch.zeros(subblock_out_features, gralora_r))
+            new_A = nn.Parameter(torch.empty(gralora_r, subblock_in_features))
+            new_B = nn.Parameter(torch.empty(subblock_out_features, gralora_r))
             if init_weights:
                 # Initialize to identity: A is random, B is zero
                 nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
-                # new_B is already initialized to zeros
+                nn.init.zeros_(new_B)
             else:
                 # Initialize to random: both A and B are random (for testing)
                 nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
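For readers following the shapes: below is a minimal standalone sketch of the per-block parameter layout that the checks in this hunk enforce. All dimensions and variable values are made up for illustration and are not part of this patch.

import math

import torch
import torch.nn as nn

# Assumed example configuration (not from the patch): a 1024x1024 layer,
# total rank r=32, gralora_k=4 blocks, no hybrid LoRA rank.
in_features, out_features, r, gralora_k = 1024, 1024, 32, 4

for name, value in (("in_features", in_features), ("out_features", out_features), ("r", r)):
    if value % gralora_k != 0:
        raise ValueError(f"{name} should be divisible by gralora_k, but got {value} and {gralora_k}")

subblock_in = in_features // gralora_k    # 256 input features handled by each block
subblock_out = out_features // gralora_k  # 256 output features produced by each block

# One (A, B) pair per block; A is random, B is zero, so the initial delta weight is zero.
gralora_A = [nn.Parameter(torch.empty(r, subblock_in)) for _ in range(gralora_k)]
gralora_B = [nn.Parameter(torch.empty(subblock_out, r)) for _ in range(gralora_k)]
for new_A, new_B in zip(gralora_A, gralora_B):
    nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
    nn.init.zeros_(new_B)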
@@ -183,7 +150,7 @@ def update_layer(

         self.module_name = module_name

-        self.scaling[adapter_name] = gralora_alpha / r
+        self.scaling[adapter_name] = gralora_alpha / (gralora_r + hybrid_r)
         self._move_adapter_to_device_of_base_layer(adapter_name)
         self.set_adapter(self.active_adapters)

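A quick numeric illustration of the scaling assignment in this hunk, with arbitrary example values (not taken from the patch):

# scaling = gralora_alpha / (gralora_r + hybrid_r); example values only
gralora_alpha, gralora_r, hybrid_r = 64, 32, 8
scaling = gralora_alpha / (gralora_r + hybrid_r)  # 64 / 40 = 1.6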
@@ -305,30 +272,38 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         # Get dimensions
         in_features = self.in_features
         out_features = self.out_features
-        subblock_in = in_features // gralora_k
-        subblock_out = out_features // gralora_k
-        gralora_rank = r - hybrid_r
+        gralora_rank = r
+        if in_features % gralora_k != 0:
+            raise ValueError(f"in_features should be divisible by gralora_k, but got {in_features} and {gralora_k}")
+        elif out_features % gralora_k != 0:
+            raise ValueError(f"out_features should be divisible by gralora_k, but got {out_features} and {gralora_k}")
+        elif gralora_rank % gralora_k != 0:
+            raise ValueError(f"rank should be divisible by gralora_k, but got {gralora_rank} and {gralora_k}")
         subblock_gralora_rank = gralora_rank // gralora_k

         # scatter gralora_A to get the scattered weight matrix
         l_indices = torch.arange(in_features, device=device)
-        n_indices = (l_indices // (in_features // gralora_k))
-        i_indices = (l_indices % (in_features // gralora_k))
+        n_indices = l_indices // (in_features // gralora_k)
+        i_indices = l_indices % (in_features // gralora_k)
         gralora_A_scattered = torch.zeros(in_features, gralora_k, gralora_rank, device=device, dtype=dtype)
-        gralora_A_scattered.scatter_(1,
+        gralora_A_scattered.scatter_(
+            1,
             n_indices.unsqueeze(1).unsqueeze(2).expand(-1, 1, gralora_rank),
-            gralora_A[n_indices, i_indices, :].unsqueeze(1)
+            gralora_A[n_indices, i_indices, :].unsqueeze(1),
         )

         # compute the delta weight
-        delta_weight = torch.einsum(
-            "ikr, kro -> iko",
-            gralora_A_scattered
-            .view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
-            .permute(0, 2, 1, 3)
-            .reshape(in_features, gralora_k, gralora_rank),
-            gralora_B,
-        ).reshape(in_features, out_features).T
+        delta_weight = (
+            torch.einsum(
+                "ikr, kro -> iko",
+                gralora_A_scattered.view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
+                .permute(0, 2, 1, 3)
+                .reshape(in_features, gralora_k, gralora_rank),
+                gralora_B,
+            )
+            .reshape(in_features, out_features)
+            .T
+        )

         # Add hybrid LoRA component if present
         if hybrid_r > 0:
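To make the scatter and einsum in this hunk easier to follow, here is a loop-based sketch of the same block structure. It assumes (my reading, not stated explicitly in the hunk) that gralora_A is stacked as (k, subblock_in, rank) and gralora_B as (k, rank, subblock_out); under that assumption, block (m, n) of the delta weight couples the m-th rank chunk of input block n's A with the n-th rank chunk of output block m's B. Sizes are toy values.

import torch

k, sub_in, sub_out, rank = 2, 4, 3, 4   # toy sizes; rank is divisible by k
sub_rank = rank // k
A = torch.randn(k, sub_in, rank)         # one A per input block
B = torch.randn(k, rank, sub_out)        # one B per output block

# delta has shape (out_features, in_features) = (k * sub_out, k * sub_in),
# laid out as a k x k grid of sub-blocks, like the tensorized einsum above.
delta = torch.zeros(k * sub_out, k * sub_in)
for m in range(k):          # output block index
    for n in range(k):      # input block index
        A_chunk = A[n][:, m * sub_rank:(m + 1) * sub_rank]   # (sub_in, sub_rank)
        B_chunk = B[m][n * sub_rank:(n + 1) * sub_rank, :]   # (sub_rank, sub_out)
        delta[m * sub_out:(m + 1) * sub_out, n * sub_in:(n + 1) * sub_in] = (A_chunk @ B_chunk).T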
@@ -380,16 +355,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 gralora_B_general = self.gralora_B_general[active_adapter]

                 r = self.r[active_adapter]
+                gralora_rank = r
                 gralora_k = self.gralora_k[active_adapter]
                 hybrid_r = self.hybrid_r[active_adapter]

-                assert len(gralora_A) == len(gralora_B)
-
                 dropout = self.gralora_dropout[active_adapter]
                 scaling = self.scaling[active_adapter]

                 gralora_dtype = gralora_A.dtype
-                gralora_rank = r - hybrid_r

                 B, L, in_features = x.shape
                 N = gralora_k
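The names B, L, in_features and N set up a block-wise view of the input for the rest of forward, which falls outside this hunk. Purely as a hypothetical illustration (the actual computation may differ), splitting x along its feature dimension into N slices would look like this, with toy sizes:

import torch

B, L, in_features, N = 2, 5, 8, 4             # toy sizes; in_features divisible by N
x = torch.randn(B, L, in_features)
x_blocks = x.view(B, L, N, in_features // N)  # one slice per block, fed to that block's A
print(x_blocks.shape)                         # torch.Size([2, 5, 4, 2])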