Commit 6d9716a

support prefill cache mode using the fused infer attention (FIA) op
Signed-off-by: shiyuan680 <[email protected]>
1 parent dd7a250 commit 6d9716a
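
Both hunks below gate on the CANN toolkit version string. A minimal standalone sketch of that gate, assuming an Ascend environment where torch_npu populates torch.version.cann (which the diff relies on); the helper name use_fia_op is illustrative only:

import torch
import torch_npu  # noqa: F401  # importing torch_npu exposes the CANN version

def use_fia_op() -> bool:
    # The commit takes the npu_fused_infer_attention_score (FIA) path only on
    # CANN 8.3.x; older toolkits keep the legacy _npu_flash_attention_qlens path.
    cann_version = getattr(torch.version, "cann", None)
    return bool(cann_version) and cann_version.startswith("8.3")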

2 files changed: +43, -14 lines

vllm_ascend/attention/attention_v1.py

Lines changed: 37 additions & 12 deletions
@@ -373,18 +373,43 @@ def _forward_prefill_cache_hit(
         batch_size = attn_metadata.query_lens.shape[0]
         block_table = attn_metadata.block_tables[:batch_size, :]
 
-        torch_npu._npu_flash_attention_qlens(
-            query=query,
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            block_table=block_table,
-            mask=compress_mask,
-            seq_len=attn_metadata.query_lens,
-            context_lens=attn_metadata.seq_lens,
-            num_kv_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale_value=self.scale,
-            out=output)
+        if torch.version.cann.startswith("8.3"):
+            # TODO: The npu_fused_infer_attention_score op is planned to
+            # be utilized in a wider range in upcoming versions.
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=compress_mask,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+            )
+        else:
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                block_table=block_table,
+                mask=compress_mask,
+                seq_len=attn_metadata.query_lens,
+                context_lens=attn_metadata.seq_lens,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                out=output)
         return output
 
     def _forward_decode_only(
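
The new branch folds the paged KV cache's head and head-dim axes together before handing it to the FIA call with input_layout="TND". A minimal sketch of that reshape with dummy sizes, assuming the (num_block, block_size, num_kv_heads, head_dim) cache layout implied by the view() call above; all shapes and names here are illustrative:

import torch

num_block, block_size, num_kv_heads, head_dim = 4, 128, 8, 64
key_cache = torch.zeros(num_block, block_size, num_kv_heads, head_dim)

# Fold the last two axes so each block row holds num_kv_heads * head_dim values,
# mirroring self.key_cache.view(num_block, block_size, -1) in the diff.
key = key_cache.view(num_block, block_size, -1)
assert key.shape == (num_block, block_size, num_kv_heads * head_dim)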

vllm_ascend/worker/model_runner_v1.py

Lines changed: 6 additions & 2 deletions
@@ -907,8 +907,12 @@ def _make_attention_mask(self, seq_lens, position,
                 max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                128, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return ~torch.tril(
+                    torch.ones((2048, 2048), dtype=torch.bool, device=self.device))
+            else:
+                return self.attn_mask_builder.get_attn_mask(
+                    128, self.dtype, self.device)
         # Decode-only situation.
         else:
             return None
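
The mask returned on the CANN 8.3 path is the inverted lower-triangular 2048 x 2048 boolean matrix: True marks future positions to be masked out, False keeps current and past positions visible. It is the atten_mask that the new PrefillCacheHit branch in attention_v1.py passes alongside sparse_mode=3. A small CPU-only sketch of the same construction:

import torch

mask = ~torch.tril(torch.ones((2048, 2048), dtype=torch.bool))
assert mask.shape == (2048, 2048)
assert not mask[1, 0].item()  # a past position stays visible
assert mask[0, 1].item()      # a future position is masked out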
