2 changes: 2 additions & 0 deletions vllm_ascend/attention/attention_mask.py
@@ -68,6 +68,8 @@ def get_mask_scale_factor(dtype: torch.dtype = torch.float16):

     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
+        if max_seq_len == 2048 and torch.version.cann.startswith("8.3"):
+            return self.chunked_prefill_attn_mask.to(torch.bool)
         self._update_attn_cache(max_seq_len, dtype)
         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
         ).to(device, non_blocking=True)
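For context, a minimal sketch of the kind of fixed-size boolean mask the CANN 8.3 fast path above could return; the 2048 x 2048 shape and upper-triangular semantics are assumptions for illustration, not code from this PR:

import torch

# Assumed layout of chunked_prefill_attn_mask: a fixed 2048 x 2048
# upper-triangular mask where True marks key positions a query must not attend to.
chunked_prefill_attn_mask = torch.triu(
    torch.ones(2048, 2048, dtype=torch.bool), diagonal=1)
print(chunked_prefill_attn_mask.shape)  # torch.Size([2048, 2048])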
51 changes: 39 additions & 12 deletions vllm_ascend/attention/attention_v1.py
@@ -373,18 +373,45 @@ def _forward_prefill_cache_hit(
         batch_size = attn_metadata.query_lens.shape[0]
         block_table = attn_metadata.block_tables[:batch_size, :]

-        torch_npu._npu_flash_attention_qlens(
-            query=query,
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            block_table=block_table,
-            mask=compress_mask,
-            seq_len=attn_metadata.query_lens,
-            context_lens=attn_metadata.seq_lens,
-            num_kv_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale_value=self.scale,
-            out=output)
+        if torch.version.cann.startswith("8.3"):
+            # TODO: The npu_fused_infer_attention_score op is planned to
+            # be used in a wider range of cases in upcoming versions.
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+
+            assert block_size == 128, "only block_size 128 is supported"
+
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=compress_mask,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+            )
+        else:
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                block_table=block_table,
+                mask=compress_mask,
+                seq_len=attn_metadata.query_lens,
+                context_lens=attn_metadata.seq_lens,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                out=output)
         return output

     def _forward_decode_only(
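To make the CANN 8.3 branch above easier to follow, here is a small CPU-only sketch of the reshape it applies to the paged KV cache before calling npu_fused_infer_attention_score; the dimensions are made-up examples, not values from this PR:

import torch

# Hypothetical paged-KV-cache dimensions, for illustration only.
num_block, block_size, num_kv_heads, head_dim = 4, 128, 8, 128

key_cache = torch.randn(num_block, block_size, num_kv_heads, head_dim)
# Collapse the head and head-dim axes so each cache slot becomes one flat
# vector, matching the flattened layout the branch builds with .view().
key = key_cache.view(num_block, block_size, -1)
assert key.shape == (num_block, block_size, num_kv_heads * head_dim)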
8 changes: 6 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -907,8 +907,12 @@ def _make_attention_mask(self, seq_lens, position,
                 max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                128, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return self.attn_mask_builder.get_attn_mask(
+                    2048, self.dtype, self.device)
+            else:
+                return self.attn_mask_builder.get_attn_mask(
+                    128, self.dtype, self.device)

Collaborator: can use the chunked-prefill attn_mask func instead.
Contributor Author: done

         # Decode-only situation.
         else:
             return None
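The same torch.version.cann check gates all three files touched by this PR. A hedged sketch of how such a guard could be factored out (the helper name is hypothetical and not part of the PR), tolerating builds where torch_npu has not populated the attribute:

import torch

def is_cann_8_3() -> bool:
    # torch_npu sets torch.version.cann on Ascend builds; fall back to the
    # legacy path when the attribute is missing or from another release.
    cann_version = getattr(torch.version, "cann", None)
    return cann_version is not None and cann_version.startswith("8.3")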