
Conversation

shiyuan680 (Contributor) commented Oct 23, 2025

What this PR does / why we need it?

Support prefill cache mode using the FIA op for full graph mode.

Does this PR introduce any user-facing change?

How was this patch tested?

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description, to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for the Fused Infer Attention (FIA) operator for prefill with cache hits on CANN 8.3. The changes introduce a new code path in the attention implementation and adjust the attention mask creation accordingly. While the changes are generally in the right direction, I have identified a critical issue regarding the incorrect sequence lengths being passed to the new operator, which could lead to incorrect attention results. Additionally, I've noted the use of a hardcoded value for an attention mask dimension, which could cause issues if model configurations change, and I've suggested a more robust alternative.

input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,

critical

The actual_seq_lengths_kv parameter is incorrectly set to attn_metadata.actual_seq_lengths_q, which represents the query lengths. In prefill scenarios, the key/value sequence lengths should correspond to the total context lengths, not just the new query tokens. This will likely lead to incorrect attention calculations because the model will attend to the wrong token range in the KV cache. The correct value, representing the context lengths, appears to be available in attn_metadata.seq_lens_list, which is used in other calls to this same operator within this file.

Suggested change
actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
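
For context, a minimal sketch of what the corrected call might look like, assuming the FIA op is torch_npu.npu_fused_infer_attention_score and that the surrounding names (query, key_cache, value_cache, block_table, num_heads, num_kv_heads, scale, attn_mask) stand in for whatever the implementation actually uses; only input_layout, block_size and the two actual_seq_lengths* arguments are taken from this diff:

import torch_npu

def fia_prefill_cache_hit(query, key_cache, value_cache, attn_metadata,
                          block_table, block_size, num_heads, num_kv_heads,
                          scale, attn_mask):
    # actual_seq_lengths describes the new query tokens being prefilled,
    # while actual_seq_lengths_kv must cover the full context (cached +
    # new tokens), hence seq_lens_list rather than actual_seq_lengths_q.
    attn_output, _ = torch_npu.npu_fused_infer_attention_score(
        query, key_cache, value_cache,
        num_heads=num_heads,
        num_key_value_heads=num_kv_heads,
        scale=scale,
        atten_mask=attn_mask,
        input_layout="TND",
        block_table=block_table,
        block_size=block_size,
        actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
        actual_seq_lengths_kv=attn_metadata.seq_lens_list,
    )
    return attn_output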

Comment on lines 911 to 914
return ~torch.tril(torch.ones((2048, 2048),
                              dtype=torch.bool,
                              device=self.device)
                   )

high

The dimension 2048 for the attention mask is hardcoded. This can lead to runtime errors if the model's max_model_len is configured to be larger than 2048. To make the code more robust and maintainable, it's better to use a value derived from the model configuration, such as self.model_config.max_model_len.

If 2048 is a strict limitation of the underlying kernel for CANN 8.3, it should be defined as a named constant (e.g., _CANN_8_3_FIA_MAX_LEN = 2048) and an assertion should be added during initialization to ensure self.model_config.max_model_len does not exceed this limit.

Suggested change
return ~torch.tril(torch.ones((2048, 2048),
                              dtype=torch.bool,
                              device=self.device)
                   )
return ~torch.tril(torch.ones((self.model_config.max_model_len, self.model_config.max_model_len),
                              dtype=torch.bool,
                              device=self.device)
                   )
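
If 2048 really is a hard limit of the CANN 8.3 kernel, a minimal sketch of the named-constant-plus-assertion alternative described above might look like this (the enclosing class and the _make_prefill_mask name are assumptions for illustration; only model_config.max_model_len, self.device and the mask construction come from this thread):

import torch

_CANN_8_3_FIA_MAX_LEN = 2048  # assumed kernel limit for the CANN 8.3 FIA path


class AscendAttentionState:  # hypothetical enclosing class
    def __init__(self, model_config, device):
        self.model_config = model_config
        self.device = device
        # Fail fast at init time instead of inside the kernel at runtime.
        assert self.model_config.max_model_len <= _CANN_8_3_FIA_MAX_LEN, (
            f"max_model_len={self.model_config.max_model_len} exceeds the "
            f"{_CANN_8_3_FIA_MAX_LEN} limit assumed for the CANN 8.3 FIA mask")

    def _make_prefill_mask(self):
        # Causal mask: True marks positions that must be masked out.
        return ~torch.tril(
            torch.ones((_CANN_8_3_FIA_MAX_LEN, _CANN_8_3_FIA_MAX_LEN),
                       dtype=torch.bool,
                       device=self.device))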

shiyuan680 force-pushed the fia_replace branch 3 times, most recently from 6d9716a to e90f3f2 on October 23, 2025 08:03
return self.attn_mask_builder.get_attn_mask(
    128, self.dtype, self.device)
if torch.version.cann.startswith("8.3"):
    return self.attn_mask_builder.get_attn_mask(

Collaborator

Can use the chunked-prefill attn_mask func instead.
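
A minimal sketch of what this suggestion might look like, assuming the mask builder exposes a chunked-prefill helper; the name get_splitfuse_attn_mask, its signature, and the attn_metadata fields below are assumptions, not taken from this diff:

if torch.version.cann.startswith("8.3"):
    # Reuse the chunked-prefill mask path instead of building a separate
    # 2048 x 2048 boolean causal mask only for the CANN 8.3 FIA branch.
    return self.attn_mask_builder.get_splitfuse_attn_mask(
        attn_metadata.seq_lens,    # full context length per request (assumed)
        attn_metadata.query_lens,  # new query tokens per request (assumed)
        self.dtype, self.device)
return self.attn_mask_builder.get_attn_mask(128, self.dtype, self.device)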

Contributor Author

done

shiyuan680 closed this on Oct 24, 2025