
Commit 70941ff

Refactor GraLoRA code for clearer documentation, simplified inheritance, and more intuitive hybrid_r handling
1 parent a24d156 commit 70941ff

3 files changed: +86 -292 lines changed


src/peft/tuners/gralora/config.py

Lines changed: 36 additions & 5 deletions
@@ -21,23 +21,54 @@

 @dataclass
 class GraloraConfig(PeftConfig):
-    r: int = field(default=8, metadata={"help": "gralora attention dimension"})
+    r: int = field(
+        default=32,
+        metadata={
+            "help": (
+                "GraLoRA attention dimension determines the rank of the GraLoRA adapter. "
+                "The total parameter count of the GraLoRA adapter is the same as LoRA with the same rank r, while the expressivity is multiplied by gralora_k."
+            )
+        },
+    )
     hybrid_r: int = field(
-        default=0, metadata={"help": "hybrid_r is the rank allocated to vanilla LoRA method when using Hybrid GraLoRA"}
+        default=0,
+        metadata={
+            "help": (
+                "hybrid_r is the rank allocated to the vanilla LoRA method when using the Hybrid GraLoRA method. "
+                "Hybrid GraLoRA, a combination of GraLoRA and vanilla LoRA, becomes available when hybrid_r > 0. "
+                "r + hybrid_r determines the parameter count of the GraLoRA adapter."
+            )
+        },
     )
     target_modules: Optional[Union[list[str], str]] = field(
         default=None,
         metadata={
             "help": (
-                "List of module names or regex expression of the module names to replace with gralora."
+                "List of module names or regex expression of the module names to replace with gralora. "
                 "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
                 "Only linear layers are supported."
             )
         },
     )
-    gralora_alpha: int = field(default=8, metadata={"help": "gralora alpha"})
+    gralora_alpha: int = field(
+        default=64,
+        metadata={
+            "help": (
+                "gralora_alpha is the scaling factor for the GraLoRA adapter. "
+                "The scale becomes gralora_alpha / (r + hybrid_r)."
+            )
+        },
+    )
     gralora_dropout: float = field(default=0.0, metadata={"help": "gralora dropout"})
-    gralora_k: int = field(default=2, metadata={"help": "gralora k"})
+    gralora_k: int = field(
+        default=2,
+        metadata={
+            "help": (
+                "gralora_k determines the number of subblocks in the GraLoRA adapter. "
+                "The total parameter count is preserved regardless of gralora_k, while the expressivity is multiplied by gralora_k."
+            )
+        },
+    )
     fan_in_fan_out: bool = field(
         default=False,
         metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},

src/peft/tuners/gralora/layer.py

Lines changed: 46 additions & 73 deletions
@@ -38,7 +38,6 @@ def __init__(self, base_layer: nn.Module, **kwargs):
         self.scaling = {}
         self.gralora_dropout = nn.ModuleDict({})

-        # Set to `None` otherwise to avoid computation with random weight
         self.gralora_A = nn.ParameterDict({})
         self.gralora_B = nn.ParameterDict({})
         self.gralora_A_general = nn.ModuleDict({})
@@ -55,57 +54,13 @@ def __init__(self, base_layer: nn.Module, **kwargs):
             in_features, out_features = (
                 base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
             )
+        else:
+            raise NotImplementedError(f"Unsupported layer type {type(base_layer)}")

         self.in_features = in_features
         self.out_features = out_features
         self.kwargs = kwargs

-    def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
-        """
-        Move the adapter of the given name to the device of the base layer.
-        """
-        from peft.tuners._buffer_dict import BufferDict
-
-        if device is None:
-            # check weight and qweight (for GPTQ)
-            for weight_name in ("weight", "qweight"):
-                weight = getattr(self.get_base_layer(), weight_name, None)
-                if weight is not None:
-                    device = weight.device
-                    dtype = weight.dtype
-                    break
-            else:
-                # no break encountered: could not determine the device
-                return
-
-        # loop through all potential adapter layers and move them to the device of the base layer; be careful to only
-        # move this specific adapter to the device, as the other adapters could be on different devices
-        # see #1639
-        for adapter_layer_name in self.adapter_layer_names + self.other_param_names:
-            adapter_layer = getattr(self, adapter_layer_name, None)
-            if not isinstance(adapter_layer, (nn.ModuleDict, nn.ParameterDict, BufferDict)):
-                continue
-            if adapter_name not in adapter_layer:
-                continue
-            if weight.dtype.is_floating_point or weight.dtype.is_complex:
-                adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=dtype)
-            else:
-                adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device)
-
-    @property
-    def merged(self) -> bool:
-        return bool(self.merged_adapters)
-
-    @property
-    def bias(self) -> torch.Tensor:
-        base_layer = self.get_base_layer()
-        if isinstance(base_layer, nn.Linear):
-            return base_layer.bias
-        elif isinstance(base_layer, Conv1D):
-            return base_layer.bias
-        else:
-            return None
-
     def update_layer(
         self,
         adapter_name,
@@ -119,6 +74,8 @@ def update_layer(
     ):
         if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
+        elif hybrid_r < 0:
+            raise ValueError(f"`hybrid_r` should be a non-negative integer value but the value passed is {hybrid_r}")

         self.r[adapter_name] = r
         self.gralora_alpha[adapter_name] = gralora_alpha
@@ -133,21 +90,31 @@ def update_layer(
         self.gralora_dropout.update(nn.ModuleDict({adapter_name: gralora_dropout_layer}))

         # Actual trainable parameters
+        if self.in_features % gralora_k != 0:
+            raise ValueError(
+                f"in_features should be divisible by gralora_k, but got {self.in_features} and {gralora_k}"
+            )
+        if self.out_features % gralora_k != 0:
+            raise ValueError(
+                f"out_features should be divisible by gralora_k, but got {self.out_features} and {gralora_k}"
+            )
         subblock_in_features = self.in_features // gralora_k
         subblock_out_features = self.out_features // gralora_k

-        gralora_r = r - hybrid_r  # gralora_r is the rank allocated to gralora method
-        assert gralora_r % gralora_k == 0, f"r should be divisible by gralora_k, but got {r} and {gralora_k}"
+        # gralora_r is the rank allocated to GraLoRA method; hybrid_r is the rank allocated to vanilla LoRA
+        gralora_r = r
+        if gralora_r % gralora_k != 0:
+            raise ValueError(f"r should be divisible by gralora_k, but got {r} and {gralora_k}")

-        gralora_A = nn.ParameterList()
-        gralora_B = nn.ParameterList()
+        gralora_A = []
+        gralora_B = []
         for _ in range(gralora_k):
-            new_A = nn.Parameter(torch.zeros(gralora_r, subblock_in_features))
-            new_B = nn.Parameter(torch.zeros(subblock_out_features, gralora_r))
+            new_A = nn.Parameter(torch.empty(gralora_r, subblock_in_features))
+            new_B = nn.Parameter(torch.empty(subblock_out_features, gralora_r))
             if init_weights:
                 # Initialize to identity: A is random, B is zero
                 nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
-                # new_B is already initialized to zeros
+                nn.init.zeros_(new_B)
             else:
                 # Initialize to random: both A and B are random (for testing)
                 nn.init.kaiming_uniform_(new_A, a=math.sqrt(5))
@@ -183,7 +150,7 @@ def update_layer(

         self.module_name = module_name

-        self.scaling[adapter_name] = gralora_alpha / r
+        self.scaling[adapter_name] = gralora_alpha / (gralora_r + hybrid_r)
         self._move_adapter_to_device_of_base_layer(adapter_name)
         self.set_adapter(self.active_adapters)

@@ -305,30 +272,38 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
         # Get dimensions
         in_features = self.in_features
         out_features = self.out_features
-        subblock_in = in_features // gralora_k
-        subblock_out = out_features // gralora_k
-        gralora_rank = r - hybrid_r
+        gralora_rank = r
+        if in_features % gralora_k != 0:
+            raise ValueError(f"in_features should be divisible by gralora_k, but got {in_features} and {gralora_k}")
+        elif out_features % gralora_k != 0:
+            raise ValueError(f"out_features should be divisible by gralora_k, but got {out_features} and {gralora_k}")
+        elif gralora_rank % gralora_k != 0:
+            raise ValueError(f"rank should be divisible by gralora_k, but got {gralora_rank} and {gralora_k}")
         subblock_gralora_rank = gralora_rank // gralora_k

         # scatter gralora_A to get the scattered weight matrix
         l_indices = torch.arange(in_features, device=device)
-        n_indices = (l_indices // (in_features // gralora_k))
-        i_indices = (l_indices % (in_features // gralora_k))
+        n_indices = l_indices // (in_features // gralora_k)
+        i_indices = l_indices % (in_features // gralora_k)
         gralora_A_scattered = torch.zeros(in_features, gralora_k, gralora_rank, device=device, dtype=dtype)
-        gralora_A_scattered.scatter_(1,
+        gralora_A_scattered.scatter_(
+            1,
             n_indices.unsqueeze(1).unsqueeze(2).expand(-1, 1, gralora_rank),
-            gralora_A[n_indices, i_indices, :].unsqueeze(1)
+            gralora_A[n_indices, i_indices, :].unsqueeze(1),
         )

         # compute the delta weight
-        delta_weight = torch.einsum(
-            "ikr, kro -> iko",
-            gralora_A_scattered
-            .view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
-            .permute(0, 2, 1, 3)
-            .reshape(in_features, gralora_k, gralora_rank),
-            gralora_B,
-        ).reshape(in_features, out_features).T
+        delta_weight = (
+            torch.einsum(
+                "ikr, kro -> iko",
+                gralora_A_scattered.view(in_features, gralora_k, gralora_k, subblock_gralora_rank)
+                .permute(0, 2, 1, 3)
+                .reshape(in_features, gralora_k, gralora_rank),
+                gralora_B,
+            )
+            .reshape(in_features, out_features)
+            .T
+        )

         # Add hybrid LoRA component if present
         if hybrid_r > 0:
@@ -380,16 +355,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 gralora_B_general = self.gralora_B_general[active_adapter]

                 r = self.r[active_adapter]
+                gralora_rank = r
                 gralora_k = self.gralora_k[active_adapter]
                 hybrid_r = self.hybrid_r[active_adapter]

-                assert len(gralora_A) == len(gralora_B)
-
                 dropout = self.gralora_dropout[active_adapter]
                 scaling = self.scaling[active_adapter]

                 gralora_dtype = gralora_A.dtype
-                gralora_rank = r - hybrid_r

                 B, L, in_features = x.shape
                 N = gralora_k
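To make the bookkeeping behind these checks concrete, here is a small standalone sketch, not library code, of the constraints the refactored update_layer enforces and of the parameter budget described in the config help text. The layer sizes and the gralora_alpha value are assumed for illustration.

# Standalone sketch with assumed sizes: the k subblock pairs of shape
# (r, in/k) and (out/k, r) cost exactly as much as a single rank-r LoRA pair.
import torch
import torch.nn as nn

in_features, out_features = 4096, 4096
r, gralora_k, hybrid_r, gralora_alpha = 32, 2, 0, 64

# divisibility requirements mirrored from update_layer / get_delta_weight
assert in_features % gralora_k == 0 and out_features % gralora_k == 0
assert r % gralora_k == 0

sub_in, sub_out = in_features // gralora_k, out_features // gralora_k
gralora_A = [nn.Parameter(torch.empty(r, sub_in)) for _ in range(gralora_k)]
gralora_B = [nn.Parameter(torch.empty(sub_out, r)) for _ in range(gralora_k)]

gralora_params = sum(p.numel() for p in gralora_A + gralora_B)
lora_params = r * in_features + out_features * r  # plain LoRA at the same rank r
print(gralora_params == lora_params)  # True: same budget, k times more subblocks

# scaling after this commit uses the full adapter rank, including any hybrid LoRA rank
scaling = gralora_alpha / (r + hybrid_r)  # 64 / 32 = 2.0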
