
Commit c53ffce: Update test code for the GraLoRA method
1 parent 70941ff

File tree: 2 files changed (+50, -46 lines)


tests/test_custom_models.py (44 additions, 9 deletions)
@@ -38,6 +38,7 @@
     C3AConfig,
     DeloraConfig,
     FourierFTConfig,
+    GraloraConfig,
     HRAConfig,
     IA3Config,
     LNTuningConfig,
@@ -665,6 +666,25 @@
             "init_weights": True,
         },
     ),
+    ###########
+    # GraLoRA #
+    ###########
+    ("Vanilla MLP 1 GraLoRA", "MLP", GraloraConfig, {"target_modules": "lin0"}),
+    ("Vanilla MLP 2 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0"]}),
+    ("Vanilla MLP 3 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin1"]}),
+    ("Vanilla MLP 4 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0", "lin1"]}),
+    (
+        "Vanilla MLP 5 GraLoRA",
+        "MLP",
+        GraloraConfig,
+        {"target_modules": ["lin0"], "modules_to_save": ["lin1"]},
+    ),
+    (
+        "Embedding + transformers Conv1D 1 GraLoRA",
+        "EmbConv1D",
+        GraloraConfig,
+        {"target_modules": ["conv1d"], "gralora_k": 1},
+    ),
     ##########
     # VBLoRA #
     ##########
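
For context, here is a minimal sketch of the pattern these new cases exercise: wrapping a small model with GraLoRA through the standard PEFT entry point. The MLP class and the r/gralora_k values below are illustrative assumptions, not taken from the test file; only the GraloraConfig import and the target-module names come from the diff.

```python
# Minimal sketch, assuming GraloraConfig is exported from peft as in the
# import hunk above. The MLP and hyperparameters here are hypothetical.
import torch.nn as nn
from peft import GraloraConfig, get_peft_model

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(8, 8)
        self.lin1 = nn.Linear(8, 8)

    def forward(self, x):
        return self.lin1(self.lin0(x))

# Mirrors "Vanilla MLP 4 GraLoRA": both linear layers adapted.
# r must be divisible by gralora_k (see the ValueError test further down).
config = GraloraConfig(target_modules=["lin0", "lin1"], r=8, gralora_k=2)
model = get_peft_model(MLP(), config)
model.print_trainable_parameters()
```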
@@ -973,6 +993,20 @@
         {"n_frequency": 10, "target_modules": ["lin0"]},
         {"n_frequency": 10, "target_modules": ["lin1"]},
     ),
+    (
+        "GraLoRA Same",
+        "gralora",
+        GraloraConfig,
+        {"target_modules": ["lin0"], "init_weights": False},
+        {"target_modules": ["lin0"], "init_weights": False},
+    ),
+    (
+        "GraLoRA Different",
+        "gralora",
+        GraloraConfig,
+        {"target_modules": ["lin0"], "init_weights": False},
+        {"target_modules": ["lin1"], "init_weights": False},
+    ),
     (
         "SHiRA Same",
         "shira",
@@ -1159,6 +1193,7 @@
     VeraConfig: "vera_lambda_",
     RandLoraConfig: "randlora_",
     FourierFTConfig: "fourierft_",
+    GraloraConfig: "gralora_",
     C3AConfig: "c3a_",
     HRAConfig: "hra_",
     ShiraConfig: "shira_",
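
This new entry registers "gralora_" as the parameter-name prefix for GraloraConfig, which the shared test machinery uses to locate adapter parameters. A quick hedged illustration of the convention, continuing the sketch above (the check itself is hypothetical, not from the suite):

```python
# All trainable parameters introduced by GraLoRA should carry the
# "gralora_" prefix; base-model weights stay frozen.
for name, param in model.named_parameters():
    if param.requires_grad:
        assert "gralora_" in name, f"unexpected trainable parameter: {name}"
```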
@@ -3104,12 +3139,12 @@ def test_add_weighted_adapter_subtraction_with_negative_weights(self):
         cancelled_B = module.lora_B["cancelled"].weight.data

         # The weights should be approximately zero (they cancel out)
-        assert torch.allclose(cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5), (
-            f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}"
-        )
-        assert torch.allclose(cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5), (
-            f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}"
-        )
+        assert torch.allclose(
+            cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5
+        ), f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}"
+        assert torch.allclose(
+            cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5
+        ), f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}"

     def test_add_weighted_adapter_negative_weight_with_different_scaling(self):
         # Test negative weights with different scaling factors (lora_alpha)
@@ -3515,9 +3550,9 @@ def test_multirank_2(self):
             if isinstance(module, BaseTunerLayer):
                 rank_expected = rank_pattern.get(key, r)
                 rank_current = module.lora_A[adapter].weight.shape[0]
-                assert rank_current == rank_expected, (
-                    f"Rank {rank_current} is not equal to expected {rank_expected}"
-                )
+                assert (
+                    rank_current == rank_expected
+                ), f"Rank {rank_current} is not equal to expected {rank_expected}"


 class TestLayerRepr:
tests/test_gralora.py (6 additions, 37 deletions)
@@ -112,7 +112,7 @@ def test_gralora_parameter_shapes(self, mlp_gralora_hybrid):
         in_features = module.in_features
         out_features = module.out_features
         k = 4
-        gralora_rank = 16 - 4  # r - hybrid_r
+        gralora_rank = 16

         # Check GraLoRA block shapes
         # Each block has full gralora_rank, not gralora_rank // k
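
The corrected expectation: gralora_rank is simply r (here 16), no longer r - hybrid_r, and per the pre-existing comment each of the k blocks carries the full gralora_rank. A hedged way to eyeball the resulting layouts on any GraLoRA-wrapped model; exact tensor shapes depend on the PEFT implementation and are deliberately not asserted here:

```python
# Print GraLoRA parameter shapes to inspect how gralora_rank and gralora_k
# show up per block. Purely illustrative; reuses `model` from the earlier
# sketch.
for name, param in model.named_parameters():
    if "gralora_" in name:
        print(name, tuple(param.shape))
```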
@@ -203,7 +203,7 @@ def test_gralora_pure_vs_hybrid_params(self):
         mlp_hybrid = MLP()
         config_hybrid = GraloraConfig(
             target_modules=["lin1", "lin2"],
-            r=16,
+            r=12,
             gralora_k=4,
             hybrid_r=4,
         )
@@ -217,9 +217,9 @@ def count_trainable_params(model):

         # Pure and hybrid should have same total parameters (r is constant)
         # but distributed differently between block-diagonal and full-rank components
-        assert params_pure == params_hybrid, (
-            f"Pure ({params_pure}) and Hybrid ({params_hybrid}) should have same parameter count"
-        )
+        assert (
+            params_pure == params_hybrid
+        ), f"Pure ({params_pure}) and Hybrid ({params_hybrid}) should have same parameter count"

         # Check that hybrid has general components
         has_general = False
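
Why the equality can still hold after the r=16 to r=12 change: if GraLoRA at rank r costs roughly r * (d_in + d_out) parameters per adapted layer (the same budget as LoRA at rank r, which is what the GraLoRA method targets) and the hybrid component adds a plain low-rank pair of rank hybrid_r, then the hybrid configuration (r=12, hybrid_r=4) matches a pure one at r=16, hybrid_r=0 (the pure config sits outside this hunk, so its values are presumed). This cost model is an assumption for illustration, not read from the PEFT source:

```python
# Back-of-envelope parameter count under the assumed cost model.
d_in = d_out = 64  # hypothetical layer width

pure = 16 * (d_in + d_out)                          # r=16, hybrid_r=0
hybrid = 12 * (d_in + d_out) + 4 * (d_in + d_out)   # r=12 block part + hybrid_r=4
assert pure == hybrid == 2048
```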
@@ -444,7 +444,7 @@ def test_gralora_rank_divisibility_check(self):
             hybrid_r=0,
         )

-        with pytest.raises(AssertionError, match="r should be divisible by gralora_k"):
+        with pytest.raises(ValueError, match="r should be divisible by gralora_k"):
             get_peft_model(mlp, config)

     def test_gralora_trainable_parameters_only(self, mlp_gralora_hybrid):
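
The validation now surfaces as a ValueError rather than an AssertionError, the conventional choice for user-facing config checks. A self-contained hedged sketch of triggering it, reusing the hypothetical MLP from the first snippet:

```python
# r=10 is not divisible by gralora_k=4, so adapter construction should fail
# with ValueError, matching the updated test above.
import pytest

bad_config = GraloraConfig(target_modules=["lin0"], r=10, gralora_k=4, hybrid_r=0)
with pytest.raises(ValueError, match="r should be divisible by gralora_k"):
    get_peft_model(MLP(), bad_config)
```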
@@ -827,37 +827,6 @@ def test_gralora_unload_without_merge(self):
         # Should match base model output (no merge)
         assert torch.allclose(base_output, unloaded_output, atol=1e-5)

-    def test_gralora_get_peft_config_as_dict(self):
-        """Test get_peft_config_as_dict method"""
-        torch.manual_seed(0)
-        mlp = MLP()
-        config = GraloraConfig(
-            target_modules=["lin1"],
-            r=8,
-            gralora_k=2,
-            hybrid_r=4,
-            gralora_alpha=16,
-        )
-        model = get_peft_model(mlp, config)
-
-        config_dict = model.get_peft_config_as_dict(inference=False)
-
-        assert "default" in config_dict
-        assert config_dict["default"]["r"] == 8
-        assert config_dict["default"]["gralora_k"] == 2
-        assert config_dict["default"]["hybrid_r"] == 4
-
-    def test_gralora_get_peft_config_as_dict_inference_mode(self):
-        """Test get_peft_config_as_dict with inference=True"""
-        torch.manual_seed(0)
-        mlp = MLP()
-        config = GraloraConfig(target_modules=["lin1"], r=8, gralora_k=2)
-        model = get_peft_model(mlp, config)
-
-        config_dict = model.get_peft_config_as_dict(inference=True)
-
-        assert config_dict["default"]["inference_mode"] is True
-
     def test_gralora_merge_with_hybrid_component(self):
         """Test that merge works correctly with hybrid component"""
         torch.manual_seed(0)