diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 1d0c83751d39..f09f0d8ecb89 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2400,7 +2400,6 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): components, _, _ = self.get_dummy_components(self.scheduler_classes[0]) pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet @@ -2416,6 +2415,10 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream): num_blocks_per_group=1, use_stream=use_stream, ) + # Place other model-level components on `torch_device`. + for _, component in pipe.components.items(): + if isinstance(component, torch.nn.Module): + component.to(torch_device) group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser) self.assertTrue(group_offload_hook_1 is not None) output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]