     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     require_torch_version_greater,
     run_test_in_subprocess,
@@ -1829,8 +1828,8 @@ def test_wrong_device_map_raises_error(self, device_map, msg_substring):
 
         assert msg_substring in str(err_ctx.exception)
 
-    @parameterized.expand([0, "cuda", torch.device("cuda")])
-    @require_torch_gpu
+    @parameterized.expand([0, torch_device, torch.device(torch_device)])
+    @require_torch_accelerator
     def test_passing_non_dict_device_map_works(self, device_map):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).eval()
@@ -1839,8 +1838,8 @@ def test_passing_non_dict_device_map_works(self, device_map):
             loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
             _ = loaded_model(**inputs_dict)
 
-    @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
-    @require_torch_gpu
+    @parameterized.expand([("", torch_device), ("", torch.device(torch_device))])
+    @require_torch_accelerator
     def test_passing_dict_device_map_works(self, name, device):
         # There are other valid dict-based `device_map` values too. It's best to refer to
         # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
@@ -1945,7 +1944,7 @@ def test_push_to_hub_library_name(self):
         delete_repo(self.repo_id, token=TOKEN)
 
 
-@require_torch_gpu
+@require_torch_accelerator
 @require_torch_2
 @is_torch_compile
 @slow
@@ -2013,7 +2012,7 @@ def test_compile_with_group_offloading(self):
         model.eval()
         # TODO: Can test for other group offloading kwargs later if needed.
         group_offload_kwargs = {
-            "onload_device": "cuda",
+            "onload_device": torch_device,
             "offload_device": "cpu",
             "offload_type": "block_level",
            "num_blocks_per_group": 1,
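The hunks above swap the CUDA-only `require_torch_gpu` decorator and hard-coded `"cuda"` strings for the accelerator-agnostic `require_torch_accelerator` decorator and the `torch_device` string from the testing utilities, so the same tests can run on CUDA, MPS, XPU, or CPU-only backends. The sketch below is illustrative only: the resolution order and the helper name `resolve_torch_device` are assumptions for this note, not the repo's actual implementation.

```python
# Illustrative sketch, not the repo's code: how a device-agnostic `torch_device`
# string could be resolved, and what the parameterized device_map values mean.
import torch


def resolve_torch_device() -> str:
    # Assumed resolution order: prefer an available accelerator, fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return "mps"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"


torch_device = resolve_torch_device()

# Non-dict device_map values target a single device: an index, a string, or a torch.device.
non_dict_device_maps = [0, torch_device, torch.device(torch_device)]

# The dict form with an empty-string key places the whole model on that device
# (see the accelerate big-model-inference guide linked in the test comment above).
dict_device_map = {"": torch_device}
```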