[NVIDIA] Support Flashinfer TRTLLM FP8-q/kv/out Attention Kernel #21716
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 …
Code Review
This pull request adds support for the Flashinfer TRT-LLM FP8-query/KV/output attention kernel. The changes span benchmarks, tests, and the core attention backend logic. The main changes update the Flashinfer API calls to accept an out parameter for in-place writes and add logic to handle FP8 quantization of queries and outputs. The PR also includes a significant refactoring of CUDA graph support for attention backends.
My review identifies two main issues. First, a critical bug in vllm/attention/layer.py where the query_scale for FP8 quantization is not correctly propagated to the attention implementation. Second, a high-severity issue in vllm/v1/attention/backends/flashinfer.py where the use of the TRT-LLM attention kernel is hardcoded, which limits flexibility.
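For context, here is a minimal PyTorch sketch of the pattern the review describes: the query is statically quantized to FP8 before the call, and the kernel writes its FP8 result into a preallocated buffer passed via an `out`-style argument. The `attn_kernel`, `q_scale`, and `out_scale` names are hypothetical placeholders for illustration, not the actual FlashInfer API.

```python
import torch

def quantize_to_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Static FP8 quantization: divide by the scale, clamp to the e4m3 range, cast."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

def fp8_attention_forward(query, kv_cache, q_scale, out_scale, attn_kernel):
    # Quantize the query so the kernel consumes FP8 q alongside the FP8 KV cache.
    q_fp8 = quantize_to_fp8(query, q_scale)

    # Preallocate the output and pass it via `out=` so the kernel writes in place,
    # mirroring the out-parameter change described above.
    out = torch.empty_like(query, dtype=torch.float8_e4m3fn)
    attn_kernel(q_fp8, kv_cache, out=out, o_scale=out_scale)
    return out
```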
This pull request has merge conflicts that must be resolved before it can be merged.
vllm/utils/flashinfer.py (Outdated)
not related to this PR, but I think has_nvidia_artifactory() can be removed because FlashInfer now supports downloading all cubins at once.
can do this in another PR
Got it.
This PR is looking really good! Thanks for all your hard work
This pull request has merge conflicts that must be resolved before it can be merged.
…m-project#21716) Signed-off-by: elvischenv <[email protected]> Co-authored-by: Michael Goin <[email protected]> Co-authored-by: Luka Govedič <[email protected]>
May I ask why TTFT increases when QKV FP8 is enabled? I assume FP8 Tensor Cores should accelerate the QK and PV matmuls when Q, K, and V are quantized to FP8.
Using an FP8 KV cache introduces an additional FP8 quantization kernel for the query tensor, so performance can drop slightly when the attention speedup is too small to offset that extra kernel. Ideally, that quantization should be fused with RoPE; that work is tracked in #24678.
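To make the trade-off concrete, a back-of-the-envelope sketch of the overhead the answer refers to: the standalone query-quant kernel performs one extra read and write of the query tensor per forward pass. The shapes below are illustrative assumptions, not measurements from this PR.

```python
# Extra memory traffic of the standalone query-quant kernel, per forward pass
# over N tokens with H query heads of head_dim D (assumed example shapes).
N, H, D = 8192, 40, 128
read_bytes = N * H * D * 2    # read the bf16/fp16 query
write_bytes = N * H * D * 1   # write the fp8 query
print(f"extra traffic: {(read_bytes + write_bytes) / 1e6:.1f} MB per forward pass")
```

When the attention kernel's FP8 speedup is smaller than this added launch and memory-traffic cost, TTFT can regress slightly, which is why fusing the quantization into RoPE is the longer-term fix.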
Essential Elements of an Effective PR Description Checklist
- (Optional) Documentation update, e.g. supported_models.md and examples for a new model.

Purpose
- Add `AttentionStaticQuantPattern` in the fusion_attn pass, which will fuse the attn+fp8_quant pattern for using the TRTLLM FP8-in FP8-out kernel (see the sketch after this description).

Test Plan && Test Result
Functional:
- tests/kernels/attention/test_flashinfer_trtllm_attention.py
- tests/compile/test_fusion_attn.py::test_attention_quant_pattern

E2E Performance:
- Model: nvidia/Llama-4-Scout-17B-16E-Instruct-FP8
- Compared: main vs. PR

(Optional) Documentation Update
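To illustrate the fusion described in the Purpose section (a simplified sketch, not the actual vLLM pass implementation): the pattern pass replaces an attention op followed by a static FP8 quantization of its output with a single call to an FP8-output kernel. `fp8_out_attn` and `o_scale` are placeholder names, not real API identifiers.

```python
import torch

# Unfused graph: attention runs in bf16/fp16, then a separate static FP8 quant
# of the output (two kernels, one extra pass over the attention output).
def attn_then_quant(q, k, v, out_scale):
    o = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (o / out_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

# Fused graph: the pattern pass rewrites attn + fp8_quant into a single call to
# an FP8-output attention kernel that takes the output scale directly.
def fused_attn_quant(q, k, v, out_scale, fp8_out_attn):
    return fp8_out_attn(q, k, v, o_scale=out_scale)
```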