 # limitations under the License.

 from contextlib import contextmanager, nullcontext
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union

 import torch

@@ -55,7 +55,7 @@ def __init__(
         parameters: Optional[List[torch.nn.Parameter]] = None,
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
-        stream: Optional[torch.cuda.Stream] = None,
+        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
@@ -115,8 +115,13 @@ def _pinned_memory_tensors(self):

     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
-        context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
-        current_stream = torch.cuda.current_stream() if self.record_stream else None
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
+        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
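
Reviewer note: the `torch_accelerator_module` lookup above is the device-agnostic idiom for resolving `torch.cuda` vs. `torch.xpu` at runtime. `torch.accelerator.current_accelerator()` (available in recent PyTorch releases) returns the active accelerator's `torch.device`, whose `.type` string names the matching `torch.<backend>` namespace. A minimal standalone sketch of the same pattern, assuming an accelerator is actually present:

```python
import torch

# Resolve the backend namespace (torch.cuda, torch.xpu, ...) generically.
# torch.accelerator only exists in newer PyTorch releases, so older
# versions fall back to torch.cuda, exactly as the hunk above does.
torch_accelerator_module = (
    getattr(torch, torch.accelerator.current_accelerator().type)
    if hasattr(torch, "accelerator")
    else torch.cuda
)

# torch.cuda and torch.xpu expose the same stream helpers, which is
# what lets onload_/offload_ stay backend-agnostic:
print(torch_accelerator_module.current_stream())
```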
@@ -162,9 +167,15 @@ def onload_(self):

     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
         if self.stream is not None:
             if not self.record_stream:
-                torch.cuda.current_stream().synchronize()
+                torch_accelerator_module.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
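
Reviewer note on why the synchronize matters: `offload_` points `param.data` back at the pinned CPU copy, dropping the last reference to the on-device tensor. When `record_stream` is off, the caching allocator has not been told about cross-stream use, so the compute stream must be drained before that device memory can safely be recycled. A hedged illustration of the underlying hazard (all names here are illustrative, not from the patch):

```python
import torch

copy_stream = torch.cuda.Stream()
cpu_tensor = torch.randn(1024, 1024).pin_memory()

with torch.cuda.stream(copy_stream):
    # `weight` is allocated on copy_stream; the caching allocator assumes
    # it is only ever used on that stream unless told otherwise.
    weight = cpu_tensor.to("cuda", non_blocking=True)

torch.cuda.current_stream().wait_stream(copy_stream)
out = weight @ weight  # compute on the default stream

# What the hook does when record_stream is False: drain the compute
# stream before dropping the device copy, so the allocator cannot hand
# the memory to a new allocation while the matmul may still be running.
torch.cuda.current_stream().synchronize()
del weight

# The alternative (record_stream=True) would instead call
# weight.record_stream(torch.cuda.current_stream()) right after use.
```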
@@ -429,8 +440,10 @@ def apply_group_offloading(
     if use_stream:
         if torch.cuda.is_available():
             stream = torch.cuda.Stream()
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            stream = torch.Stream()
         else:
-            raise ValueError("Using streams for data transfer requires a CUDA device.")
+            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

     _raise_error_if_accelerate_model_or_sequential_hook_present(module)

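Side note on the new branch: `torch.Stream` is PyTorch's backend-agnostic stream class, so on an XPU machine the plain constructor yields an XPU-backed stream without referencing `torch.xpu.Stream` directly. A minimal sketch, assuming a recent PyTorch with an XPU device present:

```python
import torch

# With XPU as the current accelerator, the generic constructor produces
# an XPU stream; the hooks then drive it through the backend helpers
# resolved in onload_/offload_ above.
stream = torch.Stream()
stream.synchronize()  # generic streams expose synchronize(), query(), ...
```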
@@ -468,7 +481,7 @@ def _apply_group_offloading_block_level(
     offload_device: torch.device,
     onload_device: torch.device,
     non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -486,7 +499,7 @@ def _apply_group_offloading_block_level(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
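
For completeness, here is what a caller-facing invocation looks like once the patch is in (a hedged sketch: `my_model` and the specific argument values are illustrative, and the keyword names follow the signatures documented in this file):

```python
import torch
from diffusers.hooks import apply_group_offloading

# Prefer XPU when available; otherwise fall back to CUDA.
onload_device = torch.device(
    "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"
)

apply_group_offloading(
    my_model,  # any torch.nn.Module, e.g. a diffusers transformer/UNet
    onload_device=onload_device,
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,  # now creates torch.Stream() on XPU instead of raising
)
```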
@@ -572,7 +585,7 @@ def _apply_group_offloading_leaf_level(
     offload_device: torch.device,
     onload_device: torch.device,
     non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -592,7 +605,7 @@ def _apply_group_offloading_leaf_level(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
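
The `record_stream` behaviour referenced in both docstrings is built on the core `Tensor.record_stream` primitive; in isolation it looks like this (hedged sketch, CUDA shown; see PyTorch's `Tensor.record_stream` documentation for caveats):

```python
import torch

side_stream = torch.cuda.Stream()
x = torch.randn(8, 8, device="cuda")  # allocated on the default stream

side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    y = x.to("cpu", non_blocking=True)

# Tell the caching allocator that `x` is in use on side_stream, so its
# memory is not recycled until that stream's pending work completes.
# This is the per-tensor bookkeeping the hooks perform when
# record_stream=True, trading a hard synchronize for (possibly) holding
# memory a little longer.
x.record_stream(side_stream)
```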