@@ -30,7 +30,7 @@
 from text_generation_server.utils.dist import MEMORY_FRACTION

 tracer = trace.get_tracer(__name__)
-
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM

 @dataclass
 class FlashCausalLMBatch(Batch):
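For context, the `IS_*_SYSTEM` flags imported above come from `text_generation_server/utils/import_utils.py`. A minimal sketch of how such flags can be derived from the installed torch build follows; the exact detection logic is an assumption, not shown in this diff:

    import torch

    # Assumption: infer the backend from the torch build and available device modules.
    IS_ROCM_SYSTEM = torch.version.hip is not None   # ROCm builds set torch.version.hip
    IS_CUDA_SYSTEM = torch.version.cuda is not None  # CUDA builds set torch.version.cuda
    # intel-extension-for-pytorch exposes a torch.xpu module on XPU systems
    IS_XPU_SYSTEM = hasattr(torch, "xpu") and torch.xpu.is_available()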
@@ -679,7 +679,10 @@ def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch

     def warmup(self, batch: FlashCausalLMBatch):
-        torch.cuda.empty_cache()
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.empty_cache()
+        elif IS_XPU_SYSTEM:
+            torch.xpu.empty_cache()
         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
@@ -697,20 +700,29 @@ def warmup(self, batch: FlashCausalLMBatch):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e

-        torch.cuda.synchronize(self.device)
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.synchronize(self.device)
+        elif IS_XPU_SYSTEM:
+            torch.xpu.synchronize(self.device)

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

-        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory

-        free_memory = max(
-            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
-        )
+            free_memory = max(
+                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
+            )
+        elif IS_XPU_SYSTEM:
+            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
+            free_memory = int(total_gpu_memory * 0.5)
+        else:
+            raise NotImplementedError("FlashModel is only available on GPU")

         num_blocks = (
             int(free_memory // total_cache_size)
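The block-count math in the last hunk is the same on every backend; only the free-memory estimate differs. XPU has no equivalent of `torch.cuda.mem_get_info` here, so the diff conservatively assumes half of total device memory is free. A minimal standalone sketch of the calculation, using illustrative Llama-7B-like values (the concrete numbers below are assumptions, not taken from the diff):

    # Hypothetical example values; BLOCK_SIZE is the paged-attention block size.
    BLOCK_SIZE = 16
    num_layers, num_kv_heads, head_size = 32, 32, 128
    dtype_size = 2  # bytes per element for float16

    # Bytes per KV-cache block: key + value (hence the factor of 2), across all layers.
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # 8 MiB per block

    free_memory = 20 * 1024**3  # pretend 20 GiB remain after loading weights
    num_blocks = int(free_memory // total_cache_size)
    print(num_blocks)  # 2560 blocks of 16 tokens each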