@@ -2817,28 +2817,29 @@ def test_olora_with_quantized_model(self, bits):
 @pytest.mark.skipif(
     not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a hardware accelerator"
 )
+@pytest.mark.single_gpu_tests
 @require_bitsandbytes
 class TestLoftQ:
     r"""
     Tests for LoftQ to ensure that it reduces the quantization error compared to normal LoRA quantization.
     """

-    # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
-    # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
-    # conservative value to prevent flakiness, in practice most gains are > 1.5
-    device = infer_device()
-    error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
+    def get_error_factor(self, device):
+        # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
+        # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
+        # conservative value to prevent flakiness, in practice most gains are > 1.5
+        error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
+        return error_factor

     def get_input(self, model_id, device):
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer("All I want is", padding=True, return_tensors="pt")
-        inputs = inputs.to(self.device)
+        inputs = inputs.to(device)
         return inputs

     def get_base_model(self, model_id, device, **kwargs):
         cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM
-        model = cls.from_pretrained(model_id, **kwargs).eval()
-        model = model.to(self.device)
+        model = cls.from_pretrained(model_id, device_map=device, **kwargs).eval()
         return model

     def get_logits(self, model, inputs):
@@ -2882,7 +2883,7 @@ def get_errors(
             raise ValueError("bits must be 4 or 8")

         quantized_model = get_peft_model(
-            self.get_base_model(model_id, device=None, **kwargs),
+            self.get_base_model(model_id, device, **kwargs),
             lora_config,
         )
         torch.manual_seed(0)
@@ -2901,10 +2902,10 @@ def get_errors(
         )
         model = self.get_base_model(model_id, device)
         if device != "cpu":
-            model = model.to(torch_device)
+            model = model.to(device)
         loftq_model = get_peft_model(model, lora_config)
         if device != "cpu":
-            loftq_model = loftq_model.to(torch_device)
+            loftq_model = loftq_model.to(device)

         # save LoRA weights, they should be initialized such that they minimize the quantization error
         loftq_model.base_model.peft_config["default"].init_lora_weights = True
@@ -2917,7 +2918,7 @@ def get_errors(
         clear_device_cache(garbage_collection=True)

         # now load quantized model and apply LoftQ-initialized weights on top
-        base_model = self.get_base_model(tmp_path / "base_model", device=None, **kwargs, torch_dtype=torch.float32)
+        base_model = self.get_base_model(tmp_path / "base_model", device=device, **kwargs, torch_dtype=torch.float32)
         loftq_model = PeftModel.from_pretrained(base_model, tmp_path / "loftq_model", is_trainable=True)

         # TODO sanity check: model is quantized
@@ -2966,8 +2967,9 @@ def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):
         assert mse_loftq > 0.0

         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)

     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit(self, device, tmp_path):
@@ -2981,8 +2983,9 @@ def test_bloomz_loftq_8bit(self, device, tmp_path):
         assert mse_loftq > 0.0

         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)

     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
@@ -2998,8 +3001,9 @@ def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
         assert mse_loftq > 0.0

         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)

     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_4bit(self, device, tmp_path):
@@ -3013,8 +3017,9 @@ def test_t5_loftq_4bit(self, device, tmp_path):
         assert mse_loftq > 0.0

         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)

     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_8bit(self, device, tmp_path):
@@ -3028,8 +3033,9 @@ def test_t5_loftq_8bit(self, device, tmp_path):
         assert mse_loftq > 0.0

         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)

     @pytest.mark.xfail  # failing for now, but having DoRA pass is only a nice-to-have, not a must, so we're good
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
@@ -3063,8 +3069,9 @@ def test_bloomz_loftq_8bit_dora(self, device, tmp_path):
         assert mse_loftq > 0.0

         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mae_loftq < (mae_quantized / self.error_factor)
-        assert mse_loftq < (mse_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mae_loftq < (mae_quantized / error_factor)
+        assert mse_loftq < (mse_quantized / error_factor)

     def test_replace_lora_weights_with_loftq_using_callable(self):
         """
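For context on the initialization path these tests exercise: in PEFT, LoftQ is enabled by passing a `LoftQConfig` to `LoraConfig(init_lora_weights="loftq", ...)` and calling `get_peft_model` on the full-precision base model (LoftQ handles quantization internally during init). The sketch below is illustrative only; the model id, rank, and `target_modules` are assumptions, not values taken from this diff.

```python
# Minimal sketch of LoftQ-initialized LoRA with PEFT.
# Assumptions: "bigscience/bloomz-560m", r=16, and target_modules are
# illustrative; they are not taken from the diff above.
from transformers import AutoModelForCausalLM
from peft import LoftQConfig, LoraConfig, get_peft_model

# Load the *non-quantized* base model; LoftQ quantizes during initialization.
base_model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")

# loftq_bits: bit width used when computing the quantization error;
# loftq_iter: number of alternating quantize / low-rank refinement steps.
loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=1)

lora_config = LoraConfig(
    init_lora_weights="loftq",  # initialize LoRA A/B to reduce quantization error
    loftq_config=loftq_config,
    r=16,
    lora_alpha=16,
    target_modules=["query_key_value"],  # BLOOM-style attention projection; adjust per model
)

model = get_peft_model(base_model, lora_config)
```

The margin the tests assert is modest by design: dividing by `error_factor = 1.03` only requires the LoftQ error to be about 3% below the plain-quantization error, far under the >1.5x improvement the in-code comment says is typical, which keeps the tests robust against flakiness.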