@@ -236,7 +236,7 @@ def test_quantization(self):
236236 ("uint7wo" , np .array ([0.4648 , 0.5195 , 0.5547 , 0.4219 , 0.4414 , 0.6445 , 0.4316 , 0.4531 , 0.5625 ])),
237237 ]
238238
239- if TorchAoConfig ._is_cuda_capability_atleast_8_9 ():
239+ if TorchAoConfig ._is_xpu_or_cuda_capability_atleast_8_9 ():
240240 QUANTIZATION_TYPES_TO_TEST .extend ([
241241 ("float8wo_e5m2" , np .array ([0.4590 , 0.5273 , 0.5547 , 0.4219 , 0.4375 , 0.6406 , 0.4316 , 0.4512 , 0.5625 ])),
242242 ("float8wo_e4m3" , np .array ([0.4648 , 0.5234 , 0.5547 , 0.4219 , 0.4414 , 0.6406 , 0.4316 , 0.4531 , 0.5625 ])),
@@ -753,7 +753,7 @@ def test_quantization(self):
753753 ("int8dq" , np .array ([0.0546 , 0.0761 , 0.1386 , 0.0488 , 0.0644 , 0.1425 , 0.0605 , 0.0742 , 0.1406 , 0.0625 , 0.0722 , 0.1523 , 0.0625 , 0.0742 , 0.1503 , 0.0605 , 0.3886 , 0.7968 , 0.5507 , 0.4492 , 0.7890 , 0.5351 , 0.4316 , 0.8007 , 0.5390 , 0.4179 , 0.8281 , 0.5820 , 0.4531 , 0.7812 , 0.5703 , 0.4921 ])),
754754 ]
755755
756- if TorchAoConfig ._is_cuda_capability_atleast_8_9 ():
756+ if TorchAoConfig ._is_xpu_or_cuda_capability_atleast_8_9 ():
757757 QUANTIZATION_TYPES_TO_TEST .extend ([
758758 ("float8wo_e4m3" , np .array ([0.0546 , 0.0722 , 0.1328 , 0.0468 , 0.0585 , 0.1367 , 0.0605 , 0.0703 , 0.1328 , 0.0625 , 0.0703 , 0.1445 , 0.0585 , 0.0703 , 0.1406 , 0.0605 , 0.3496 , 0.7109 , 0.4843 , 0.4042 , 0.7226 , 0.5000 , 0.4160 , 0.7031 , 0.4824 , 0.3886 , 0.6757 , 0.4667 , 0.3710 , 0.6679 , 0.4902 , 0.4238 ])),
759759 ("fp5_e3m1" , np .array ([0.0527 , 0.0762 , 0.1309 , 0.0449 , 0.0645 , 0.1328 , 0.0566 , 0.0723 , 0.125 , 0.0566 , 0.0703 , 0.1328 , 0.0566 , 0.0742 , 0.1348 , 0.0566 , 0.3633 , 0.7617 , 0.5273 , 0.4277 , 0.7891 , 0.5469 , 0.4375 , 0.8008 , 0.5586 , 0.4336 , 0.7383 , 0.5156 , 0.3906 , 0.6992 , 0.5156 , 0.4375 ])),
0 commit comments