File tree Expand file tree Collapse file tree 4 files changed +30
-20
lines changed Expand file tree Collapse file tree 4 files changed +30
-20
lines changed Original file line number Diff line number Diff line change @@ -866,15 +866,17 @@ def test_fp4_double_safe(self):
866866
867867@require_torch_version_greater ("2.7.1" )
868868class Bnb4BitCompileTests (QuantCompileTests ):
869- quantization_config = PipelineQuantizationConfig (
870- quant_backend = "bitsandbytes_8bit" ,
871- quant_kwargs = {
872- "load_in_4bit" : True ,
873- "bnb_4bit_quant_type" : "nf4" ,
874- "bnb_4bit_compute_dtype" : torch .bfloat16 ,
875- },
876- components_to_quantize = ["transformer" , "text_encoder_2" ],
877- )
869+ @property
870+ def quantization_config (self ):
871+ return PipelineQuantizationConfig (
872+ quant_backend = "bitsandbytes_8bit" ,
873+ quant_kwargs = {
874+ "load_in_4bit" : True ,
875+ "bnb_4bit_quant_type" : "nf4" ,
876+ "bnb_4bit_compute_dtype" : torch .bfloat16 ,
877+ },
878+ components_to_quantize = ["transformer" , "text_encoder_2" ],
879+ )
878880
879881 def test_torch_compile (self ):
880882 torch ._dynamo .config .capture_dynamic_output_shape_ops = True
Original file line number Diff line number Diff line change @@ -831,11 +831,13 @@ def test_serialization_sharded(self):
831831
832832@require_torch_version_greater_equal ("2.6.0" )
833833class Bnb8BitCompileTests (QuantCompileTests ):
834- quantization_config = PipelineQuantizationConfig (
835- quant_backend = "bitsandbytes_8bit" ,
836- quant_kwargs = {"load_in_8bit" : True },
837- components_to_quantize = ["transformer" , "text_encoder_2" ],
838- )
834+ @property
835+ def quantization_config (self ):
836+ return PipelineQuantizationConfig (
837+ quant_backend = "bitsandbytes_8bit" ,
838+ quant_kwargs = {"load_in_8bit" : True },
839+ components_to_quantize = ["transformer" , "text_encoder_2" ],
840+ )
839841
840842 def test_torch_compile (self ):
841843 torch ._dynamo .config .capture_dynamic_output_shape_ops = True
Original file line number Diff line number Diff line change 2424@require_torch_gpu
2525@slow
2626class QuantCompileTests (unittest .TestCase ):
27- quantization_config = None
27+ @property
28+ def quantization_config (self ):
29+ raise NotImplementedError (
30+ "This property should be implemented in the subclass to return the appropriate quantization config."
31+ )
2832
2933 def setUp (self ):
3034 super ().setUp ()
Original file line number Diff line number Diff line change @@ -631,11 +631,13 @@ def test_int_a16w8_cpu(self):
631631
632632@require_torchao_version_greater_or_equal ("0.7.0" )
633633class TorchAoCompileTest (QuantCompileTests ):
634- quantization_config = PipelineQuantizationConfig (
635- quant_mapping = {
636- "transformer" : TorchAoConfig (quant_type = "int8_weight_only" ),
637- },
638- )
634+ @property
635+ def quantization_config (self ):
636+ return PipelineQuantizationConfig (
637+ quant_mapping = {
638+ "transformer" : TorchAoConfig (quant_type = "int8_weight_only" ),
639+ },
640+ )
639641
640642 def test_torch_compile (self ):
641643 super ()._test_torch_compile (quantization_config = self .quantization_config )
You can’t perform that action at this time.
0 commit comments