Commit ffa971a

FIX LoftQ 8-bit bnb error, support XPU (#2797)
1 parent 4469af5 commit ffa971a
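
For context, the scenario this commit targets is LoftQ initialization of LoRA weights with 8-bit bitsandbytes settings. The sketch below is illustrative only and is not taken from the commit: the model id, target modules, and LoftQ hyperparameters are placeholders. Before this fix, the 8-bit bitsandbytes path could fail inside dequantize_bnb_weight (patched below) and did not account for XPU devices.

import torch
from transformers import AutoModelForCausalLM
from peft import LoftQConfig, LoraConfig, get_peft_model

# Load the base model in full precision; LoftQ initialization computes LoRA
# weights that compensate for the quantization error of the base weights.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m", torch_dtype=torch.float32)

# Illustrative 8-bit LoftQ settings (placeholder values).
loftq_config = LoftQConfig(loftq_bits=8, loftq_iter=5)
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    init_lora_weights="loftq",
    loftq_config=loftq_config,
    target_modules=["query_key_value"],
)
peft_model = get_peft_model(model, lora_config)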

2 files changed (+38, -28 lines)

src/peft/utils/integrations.py

Lines changed: 7 additions & 4 deletions
@@ -93,14 +93,19 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     """
     import bitsandbytes as bnb
 
-    # BNB requires CUDA weights
+    if state.SCB is None:
+        state.SCB = weight.SCB
+
+    # BNB requires accelerator weights
     device = weight.device
     is_cpu = device.type == torch.device("cpu").type
     if is_cpu:
         if torch.cuda.is_available():
             weight = weight.to(torch.device("cuda"))
+            state.SCB = state.SCB.to(torch.device("cuda"))
         elif is_xpu_available():
             weight = weight.to(torch.device("xpu"))
+            state.SCB = state.SCB.to(torch.device("xpu"))
 
     cls_name = weight.__class__.__name__
     if cls_name == "Params4bit":
@@ -109,9 +114,6 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
             dequantized = dequantized.to(device)
         return dequantized
 
-    if state.SCB is None:
-        state.SCB = weight.SCB
-
     if hasattr(bnb.functional, "int8_vectorwise_dequant"):
         # Use bitsandbytes API if available (requires v0.45.0+)
         dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
@@ -121,6 +123,7 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
 
     if is_cpu:
         dequantized = dequantized.to(device)
+        state.SCB = state.SCB.to(device)
     return dequantized

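For orientation, here is a minimal sketch (not part of the commit) of how the patched helper can be exercised for an 8-bit bitsandbytes layer. The layer sizes are placeholders, and it assumes a CUDA or XPU accelerator plus a bitsandbytes build that supports it.

import torch
import bitsandbytes as bnb
from peft.utils.integrations import dequantize_bnb_weight

device = "cuda" if torch.cuda.is_available() else "xpu"

# Quantize an 8-bit layer by moving it to the accelerator; this populates weight.SCB.
layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
layer = layer.to(device)

# With this commit, state.SCB is filled in from weight.SCB before any device move
# and travels with the weight, so the int8 dequantization path also works when the
# weights currently live on CPU.
dequantized = dequantize_bnb_weight(layer.weight, state=layer.state)
print(dequantized.shape, dequantized.dtype)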
tests/test_gpu_examples.py

Lines changed: 31 additions & 24 deletions
@@ -2817,28 +2817,29 @@ def test_olora_with_quantized_model(self, bits):
 @pytest.mark.skipif(
     not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a hardware accelerator"
 )
+@pytest.mark.single_gpu_tests
 @require_bitsandbytes
 class TestLoftQ:
     r"""
     Tests for LoftQ to ensure that it reduces the quantization error compared to normal LoRA quantization.
     """
 
-    # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
-    # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
-    # conservative value to prevent flakiness, in practice most gains are > 1.5
-    device = infer_device()
-    error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
+    def get_error_factor(self, device):
+        # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
+        # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
+        # conservative value to prevent flakiness, in practice most gains are > 1.5
+        error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
+        return error_factor
 
     def get_input(self, model_id, device):
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer("All I want is", padding=True, return_tensors="pt")
-        inputs = inputs.to(self.device)
+        inputs = inputs.to(device)
         return inputs
 
     def get_base_model(self, model_id, device, **kwargs):
         cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM
-        model = cls.from_pretrained(model_id, **kwargs).eval()
-        model = model.to(self.device)
+        model = cls.from_pretrained(model_id, device_map=device, **kwargs).eval()
         return model
 
     def get_logits(self, model, inputs):
@@ -2882,7 +2883,7 @@ def get_errors(
             raise ValueError("bits must be 4 or 8")
 
         quantized_model = get_peft_model(
-            self.get_base_model(model_id, device=None, **kwargs),
+            self.get_base_model(model_id, device, **kwargs),
             lora_config,
         )
         torch.manual_seed(0)
@@ -2901,10 +2902,10 @@ def get_errors(
         )
         model = self.get_base_model(model_id, device)
         if device != "cpu":
-            model = model.to(torch_device)
+            model = model.to(device)
         loftq_model = get_peft_model(model, lora_config)
         if device != "cpu":
-            loftq_model = loftq_model.to(torch_device)
+            loftq_model = loftq_model.to(device)
 
         # save LoRA weights, they should be initialized such that they minimize the quantization error
         loftq_model.base_model.peft_config["default"].init_lora_weights = True
@@ -2917,7 +2918,7 @@ def get_errors(
         clear_device_cache(garbage_collection=True)
 
         # now load quantized model and apply LoftQ-initialized weights on top
-        base_model = self.get_base_model(tmp_path / "base_model", device=None, **kwargs, torch_dtype=torch.float32)
+        base_model = self.get_base_model(tmp_path / "base_model", device=device, **kwargs, torch_dtype=torch.float32)
         loftq_model = PeftModel.from_pretrained(base_model, tmp_path / "loftq_model", is_trainable=True)
 
         # TODO sanity check: model is quantized
@@ -2966,8 +2967,9 @@ def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit(self, device, tmp_path):
@@ -2981,8 +2983,9 @@ def test_bloomz_loftq_8bit(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
@@ -2998,8 +3001,9 @@ def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_4bit(self, device, tmp_path):
@@ -3013,8 +3017,9 @@ def test_t5_loftq_4bit(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_8bit(self, device, tmp_path):
@@ -3028,8 +3033,9 @@ def test_t5_loftq_8bit(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.xfail  # failing for now, but having DoRA pass is only a nice-to-have, not a must, so we're good
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
@@ -3063,8 +3069,9 @@ def test_bloomz_loftq_8bit_dora(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mae_loftq < (mae_quantized / self.error_factor)
-        assert mse_loftq < (mse_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mae_loftq < (mae_quantized / error_factor)
+        assert mse_loftq < (mse_quantized / error_factor)
 
     def test_replace_lora_weights_with_loftq_using_callable(self):
         """

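To make the test logic concrete, the following is a condensed, illustrative sketch (not the test code itself) of the comparison these LoftQ tests perform: the logits of the LoftQ-initialized model should be closer to the full-precision logits than those of a plainly quantized LoRA model, by at least the per-device error factor introduced above.

import torch

def check_loftq_improvement(logits_base, logits_quantized, logits_loftq, device):
    # Conservative per-device margin, mirroring get_error_factor from the diff.
    error_factor = 1.005 if device in ("xpu", "cpu") else 1.03

    mae_quantized = (logits_base - logits_quantized).abs().mean()
    mse_quantized = torch.pow(logits_base - logits_quantized, 2).mean()
    mae_loftq = (logits_base - logits_loftq).abs().mean()
    mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()

    # sanity check: quantization introduces some error in both cases
    assert mse_quantized > 0.0 and mse_loftq > 0.0

    # LoftQ should shrink the error by at least the error factor
    assert mse_loftq < (mse_quantized / error_factor)
    assert mae_loftq < (mae_quantized / error_factor)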