@@ -373,18 +373,43 @@ def _forward_prefill_cache_hit(
         batch_size = attn_metadata.query_lens.shape[0]
         block_table = attn_metadata.block_tables[:batch_size, :]
 
-        torch_npu._npu_flash_attention_qlens(
-            query=query,
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            block_table=block_table,
-            mask=compress_mask,
-            seq_len=attn_metadata.query_lens,
-            context_lens=attn_metadata.seq_lens,
-            num_kv_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale_value=self.scale,
-            out=output)
+        if torch.version.cann.startswith("8.3"):
+            # TODO: The npu_fused_infer_attention_score op is planned to be
+            # used more widely in upcoming versions.
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=compress_mask,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+            )
+        else:
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                block_table=block_table,
+                mask=compress_mask,
+                seq_len=attn_metadata.query_lens,
+                context_lens=attn_metadata.seq_lens,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                out=output)
         return output
 
     def _forward_decode_only(
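
For context on the CANN 8.3 branch above: the paged KV cache is laid out as `(num_blocks, block_size, num_kv_heads, head_dim)`, and the new path flattens the last two dimensions so `npu_fused_infer_attention_score` receives `(num_blocks, block_size, num_kv_heads * head_dim)` tensors. A minimal sketch of that reshape with plain `torch` and illustrative sizes (the concrete dimensions here are assumptions, not taken from the kernel):

```python
import torch

# Illustrative cache dimensions only; real values come from the block
# allocator and the model config.
num_blocks, block_size, num_kv_heads, head_dim = 8, 128, 4, 64

# Paged KV cache in the assumed (blocks, block_size, heads, head_dim) layout.
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)
value_cache = torch.zeros_like(key_cache)

# Flatten the per-slot head dimensions into one vector per cache entry,
# matching the shape passed to the fused attention op in the diff above.
key = key_cache.view(num_blocks, block_size, -1)
value = value_cache.view(num_blocks, block_size, -1)

assert key.shape == (num_blocks, block_size, num_kv_heads * head_dim)
```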