support prefill cache mode use fia op #3652
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request adds support for the Fused Infer Attention (FIA) operator for prefill with cache hits on CANN 8.3. The changes introduce a new code path in the attention implementation and adjust the attention mask creation accordingly. While the changes are generally in the right direction, I have identified a critical issue regarding the incorrect sequence lengths being passed to the new operator, which could lead to incorrect attention results. Additionally, I've noted the use of a hardcoded value for an attention mask dimension, which could cause issues if model configurations change, and I've suggested a more robust alternative.
```
    input_layout="TND",
    block_size=block_size,
    actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
    actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
```
The actual_seq_lengths_kv parameter is incorrectly set to attn_metadata.actual_seq_lengths_q, which represents the query lengths. In prefill scenarios, the key/value sequence lengths should correspond to the total context lengths, not just the new query tokens. This will likely lead to incorrect attention calculations because the model will attend to the wrong token range in the KV cache. The correct value, representing the context lengths, appears to be available in attn_metadata.seq_lens_list, which is used in other calls to this same operator within this file.
Suggested change:
```diff
-    actual_seq_lengths_kv=attn_metadata.actual_seq_lengths_q,
+    actual_seq_lengths_kv=attn_metadata.seq_lens_list,
```
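For reference, here is a minimal PyTorch sketch of why the query-side and KV-side length lists differ when part of the prompt is already in the KV cache (the numbers and variable names are illustrative, not taken from the PR):

```python
import torch

# Toy example: two requests whose prompt prefixes are already in the KV cache.
cached_lens = torch.tensor([96, 512])   # tokens already cached per request
new_q_lens = torch.tensor([32, 64])     # new prompt tokens prefilled this step

q_lens = new_q_lens                     # what the query-side lengths describe
kv_lens = cached_lens + new_q_lens      # full context the kernel must attend over

# Passing the query lengths on the KV side would restrict attention to only the
# newly prefilled tokens and silently ignore the cached prefix.
assert torch.all(kv_lens >= q_lens)
print(q_lens.tolist(), kv_lens.tolist())  # [32, 64] [128, 576]
```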
```
return ~torch.tril(torch.ones((2048, 2048),
                              dtype=torch.bool,
                              device=self.device)
                   )
```
The dimension 2048 for the attention mask is hardcoded. This can lead to runtime errors if the model's max_model_len is configured to be larger than 2048. To make the code more robust and maintainable, it's better to use a value derived from the model configuration, such as self.model_config.max_model_len.
If 2048 is a strict limitation of the underlying kernel for CANN 8.3, it should be defined as a named constant (e.g., _CANN_8_3_FIA_MAX_LEN = 2048) and an assertion should be added during initialization to ensure self.model_config.max_model_len does not exceed this limit.
Suggested change:
```diff
-    return ~torch.tril(torch.ones((2048, 2048),
-                                  dtype=torch.bool,
-                                  device=self.device)
-                       )
+    return ~torch.tril(torch.ones((self.model_config.max_model_len, self.model_config.max_model_len),
+                                  dtype=torch.bool,
+                                  device=self.device)
+                       )
```
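If 2048 really is a hard kernel limit on CANN 8.3, the named-constant alternative described above might look roughly like this sketch (the constant, class, and method names are assumptions, not code from the PR):

```python
import torch

# Hypothetical constant capturing the assumed CANN 8.3 FIA mask limit.
_CANN_8_3_FIA_MAX_LEN = 2048


class AttentionMaskSketch:

    def __init__(self, max_model_len: int, device: torch.device):
        # Fail fast at init time instead of hitting a kernel error at runtime.
        assert max_model_len <= _CANN_8_3_FIA_MAX_LEN, (
            f"max_model_len={max_model_len} exceeds the assumed CANN 8.3 "
            f"FIA mask limit of {_CANN_8_3_FIA_MAX_LEN}")
        self.device = device

    def build_mask(self) -> torch.Tensor:
        # True entries above the diagonal mark positions to be masked out.
        n = _CANN_8_3_FIA_MAX_LEN
        return ~torch.tril(
            torch.ones((n, n), dtype=torch.bool, device=self.device))
```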
Signed-off-by: shiyuan680 <[email protected]>
```
return self.attn_mask_builder.get_attn_mask(
    128, self.dtype, self.device)
if torch.version.cann.startswith("8.3"):
    return self.attn_mask_builder.get_attn_mask(
```
The chunked-prefill attn_mask function can be reused here instead.
done
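For illustration, here is a rough sketch of the idea behind reusing a shared chunked-prefill-style mask rather than building a fixed 2048x2048 mask in this code path (the actual helper in vllm-ascend may differ; function and variable names below are assumptions):

```python
import torch


def build_causal_mask(max_len: int, device: torch.device) -> torch.Tensor:
    # Shared boolean mask: True marks positions that must not be attended to.
    return ~torch.tril(
        torch.ones((max_len, max_len), dtype=torch.bool, device=device))


# Hypothetical reuse for a prefill step with a cached prefix: slice the mask
# rows for the new query tokens against all context columns instead of
# rebuilding a mask for every call.
full_mask = build_causal_mask(2048, torch.device("cpu"))
context_len, num_new_tokens = 512, 32
step_mask = full_mask[context_len:context_len + num_new_tokens,
                      :context_len + num_new_tokens]
print(step_mask.shape)  # torch.Size([32, 544])
```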
What this PR does / why we need it?
Adds support for using the Fused Infer Attention (FIA) operator in the prefill cache-hit path under full graph mode.
Does this PR introduce any user-facing change?
How was this patch tested?