|  | 
    C3AConfig,
    DeloraConfig,
    FourierFTConfig,
+    GraloraConfig,
    HRAConfig,
    IA3Config,
    LNTuningConfig,
|  | 
            "init_weights": True,
        },
    ),
+    ###########
+    # GraLoRA #
+    ###########
+    ("Vanilla MLP 1 GraLoRA", "MLP", GraloraConfig, {"target_modules": "lin0"}),
+    ("Vanilla MLP 2 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0"]}),
+    ("Vanilla MLP 3 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin1"]}),
+    ("Vanilla MLP 4 GraLoRA", "MLP", GraloraConfig, {"target_modules": ["lin0", "lin1"]}),
+    (
+        "Vanilla MLP 5 GraLoRA",
+        "MLP",
+        GraloraConfig,
+        {"target_modules": ["lin0"], "modules_to_save": ["lin1"]},
+    ),
+    (
+        "Embedding + transformers Conv1D 1 GraLoRA",
+        "EmbConv1D",
+        GraloraConfig,
+        {"target_modules": ["conv1d"], "gralora_k": 1},
+    ),
    ##########
    # VBLoRA #
    ##########
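
The cases added above parametrize GraLoRA over a toy MLP with `lin0`/`lin1` layers and over an embedding + `transformers` Conv1D model. As a rough illustration of what they exercise, here is a minimal sketch of applying such a config outside the test harness. It assumes `GraloraConfig` is exported from `peft` (as the import hunk above suggests) and accepts the keyword arguments seen in these cases (`target_modules`, `modules_to_save`, `gralora_k`, where `gralora_k` appears to control how many blocks the adapted weight is split into). The `MLP` class below is a stand-in modeled on the layer names the tests target, not the fixture used in the test suite.

```python
import torch
import torch.nn as nn

# GraloraConfig export assumed, per the import hunk above.
from peft import GraloraConfig, get_peft_model


class MLP(nn.Module):
    # Toy stand-in mirroring the lin0/lin1 names targeted by the test cases above.
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(32, 64)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(64, 32)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))


# Mirrors "Vanilla MLP 5 GraLoRA": adapt lin0, train lin1 fully.
config = GraloraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
peft_model = get_peft_model(MLP(), config)
peft_model.print_trainable_parameters()

# Forward pass works as usual; only the injected GraLoRA parameters (and lin1) are trainable.
out = peft_model(torch.randn(4, 32))
print(out.shape)
```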
|  | 
        {"n_frequency": 10, "target_modules": ["lin0"]},
        {"n_frequency": 10, "target_modules": ["lin1"]},
    ),
+    (
+        "GraLoRA Same",
+        "gralora",
+        GraloraConfig,
+        {"target_modules": ["lin0"], "init_weights": False},
+        {"target_modules": ["lin0"], "init_weights": False},
+    ),
+    (
+        "GraLoRA Different",
+        "gralora",
+        GraloraConfig,
+        {"target_modules": ["lin0"], "init_weights": False},
+        {"target_modules": ["lin1"], "init_weights": False},
+    ),
    (
        "SHiRA Same",
        "shira",
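
The "GraLoRA Same" / "GraLoRA Different" rows above supply two config dicts per case, which the multi-adapter tests attach to one base model as separate adapters (`init_weights=False` presumably makes each adapter non-identity at initialization, so their effects are distinguishable). A hedged sketch of that pattern, reusing the toy `MLP` and the assumed `GraloraConfig` export from the previous snippet:

```python
from peft import GraloraConfig, get_peft_model  # export assumed, see above

base = MLP()  # toy model defined in the previous sketch
cfg_same = GraloraConfig(target_modules=["lin0"], init_weights=False)
cfg_diff = GraloraConfig(target_modules=["lin1"], init_weights=False)

# Attach two adapters to the same base model, as the "GraLoRA Different" case does.
model = get_peft_model(base, cfg_same, adapter_name="adapter0")
model.add_adapter("adapter1", cfg_diff)

model.set_adapter("adapter0")  # activate one adapter at a time
print(model.active_adapter)
```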
|  | 
    VeraConfig: "vera_lambda_",
    RandLoraConfig: "randlora_",
    FourierFTConfig: "fourierft_",
+    GraloraConfig: "gralora_",
    C3AConfig: "c3a_",
    HRAConfig: "hra_",
    ShiraConfig: "shira_",
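
The mapping above associates each config class with the parameter-name prefix its tuner injects, and the new entry registers `"gralora_"` for `GraloraConfig`. As a sketch of how such a prefix is typically used (the exact parameter names are an assumption, not taken from the diff), the adapter weights can be located by filtering `named_parameters()` on the model built in the first sketch:

```python
# "gralora_" is the prefix the table above registers for GraloraConfig.
prefix = "gralora_"
gralora_param_names = [n for n, _ in peft_model.named_parameters() if prefix in n]
print(gralora_param_names)  # exact names are implementation-defined
```
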
@@ -3104,12 +3139,12 @@ def test_add_weighted_adapter_subtraction_with_negative_weights(self):
                cancelled_B = module.lora_B["cancelled"].weight.data

                # The weights should be approximately zero (they cancel out)
-                assert torch.allclose(cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5), (
-                    f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}"
-                )
-                assert torch.allclose(cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5), (
-                    f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}"
-                )
+                assert torch.allclose(
+                    cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5
+                ), f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}"
+                assert torch.allclose(
+                    cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5
+                ), f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}"

    def test_add_weighted_adapter_negative_weight_with_different_scaling(self):
        # Test negative weights with different scaling factors (lora_alpha)
@@ -3515,9 +3550,9 @@ def test_multirank_2(self):
                if isinstance(module, BaseTunerLayer):
                    rank_expected = rank_pattern.get(key, r)
                    rank_current = module.lora_A[adapter].weight.shape[0]
-                    assert rank_current == rank_expected, (
-                        f"Rank {rank_current} is not equal to expected {rank_expected}"
-                    )
+                    assert (
+                        rank_current == rank_expected
+                    ), f"Rank {rank_current} is not equal to expected {rank_expected}"


class TestLayerRepr:
|  | 