Conversation

@hjjq (Contributor) commented Jul 16, 2025

E2E Benchmark:

VLLM_ATTENTION_BACKEND=FLASHINFER_MLA vllm serve deepseek-ai/DeepSeek-R1 --block-size 32 --tensor-parallel-size 8 --max-num-seqs 512

python benchmarks/benchmark_serving.py --model deepseek-ai/DeepSeek-R1 --dataset-name random --ignore-eos --num-prompts 1024 --max-concurrency 512 --random-input-len 1024 --random-output-len 2048

TRTLLM-gen (via Flashinfer) MLA:

============ Serving Benchmark Result ============
Successful requests:                     1024      
Maximum request concurrency:             512       
Benchmark duration (s):                  434.37    
Total input tokens:                      1045957   
Total generated tokens:                  2097152   
Request throughput (req/s):              2.36      
Output token throughput (tok/s):         4827.98   
Total Token throughput (tok/s):          7235.94   
---------------Time to First Token----------------
Mean TTFT (ms):                          13770.84  
Median TTFT (ms):                        18015.33  
P99 TTFT (ms):                           26645.67  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          87.76     
Median TPOT (ms):                        84.81     
P99 TPOT (ms):                           120.83    
---------------Inter-token Latency----------------
Mean ITL (ms):                           87.79     
Median ITL (ms):                         74.03     
P99 ITL (ms):                            413.67    
==================================================
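
The reported throughputs can be cross-checked from the raw counts in the block above; a quick sanity-check sketch (small deviations from the printed values are expected because the duration is rounded):

```python
# Cross-check of the trtllm-gen serving numbers above. All inputs are taken
# directly from the benchmark output; minor deviations from the printed
# throughputs come from the two-decimal rounding of the duration.
duration_s = 434.37
requests = 1024
input_tokens = 1_045_957
output_tokens = 2_097_152

req_per_s = requests / duration_s                              # reported 2.36
output_tok_per_s = output_tokens / duration_s                  # reported 4827.98
total_tok_per_s = (input_tokens + output_tokens) / duration_s  # reported 7235.94

print(f"req/s: {req_per_s:.2f}, output tok/s: {output_tok_per_s:.2f}, "
      f"total tok/s: {total_tok_per_s:.2f}")
```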

cutlass MLA (VLLM_ATTENTION_BACKEND=CUTLASS_MLA):

============ Serving Benchmark Result ============
Successful requests:                     1024      
Maximum request concurrency:             512       
Benchmark duration (s):                  455.36    
Total input tokens:                      1045957   
Total generated tokens:                  2097152   
Request throughput (req/s):              2.25      
Output token throughput (tok/s):         4605.44   
Total Token throughput (tok/s):          6902.40   
---------------Time to First Token----------------
Mean TTFT (ms):                          13788.19  
Median TTFT (ms):                        18140.64  
P99 TTFT (ms):                           26975.27  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          91.38     
Median TPOT (ms):                        88.12     
P99 TPOT (ms):                           126.08    
---------------Inter-token Latency----------------
Mean ITL (ms):                           91.41     
Median ITL (ms):                         77.04     
P99 ITL (ms):                            417.73    
==================================================

Kernel-only microbenchmark:

bs=512:
  trtllm-gen MLA kernel latency: 0.501 ms
  cutlass MLA kernel latency:    0.626 ms
bs=1:
  trtllm-gen MLA kernel latency: 0.023 ms
  cutlass MLA kernel latency:    0.025 ms

Under high concurrency, the trtllm-gen kernel is ~25% faster than the cutlass kernel. However, end-to-end performance is not bottlenecked by MLA, so the E2E speedup is more modest (~5% in total token throughput).
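
Both gaps can be derived directly from the numbers quoted in this thread; a quick arithmetic sketch:

```python
# Relative speedups implied by the measured kernel latencies (bs=512) and
# the total-token throughputs from the two serving benchmark runs above.
trtllm_ms, cutlass_ms = 0.501, 0.626           # kernel latency at bs=512
kernel_speedup = cutlass_ms / trtllm_ms - 1.0  # ~0.25, i.e. ~25% faster

trtllm_tps, cutlass_tps = 7235.94, 6902.40     # total token throughput (tok/s)
e2e_speedup = trtllm_tps / cutlass_tps - 1.0   # ~0.048, i.e. ~5% E2E gain

print(f"kernel speedup (bs=512): {kernel_speedup:.1%}")
print(f"E2E throughput gain:     {e2e_speedup:.1%}")
```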

Accuracy tests:

        Cutlass MLA   trtllm-gen MLA
gsm8k   0.95          0.97
gpqa    0.47          0.47
mmlu    0.85          0.85

commands:
gsm8k:

lm_eval --model local-chat-completions --tasks gsm8k --limit 100 --model_args model=deepseek-ai/DeepSeek-R1,base_url=http://0.0.0.0:8000/v1/chat/completions,max_gen_toks=16384,num_concurrent=64 --batch_size auto --fewshot_as_multiturn --apply_chat_template

gpqa:

lm_eval --model local-completions --tasks gpqa_extended_zeroshot --model_args model=deepseek-ai/DeepSeek-R1,base_url=http://0.0.0.0:8000/v1/completions,max_gen_toks=65535,num_concurrent=64,trust_remote_code=true,max_length=128000 --batch_size auto

mmlu:

lm_eval --model local-completions --tasks mmlu --model_args model=deepseek-ai/DeepSeek-R1,base_url=http://0.0.0.0:8000/v1/completions,max_gen_toks=65535,num_concurrent=64,trust_remote_code=true,max_length=128000 --batch_size auto

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jul 16, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new FlashInfer MLA (Multi-head Latent Attention) decode kernel for the vLLM V1 engine. There are critical inconsistencies between the backend implementation and its corresponding test file regarding the shape of the kv_cache tensor and the function signature of the FlashInfer kernel being called. These issues need to be addressed.

@mergify bot commented Jul 31, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hjjq.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: hjjq <[email protected]>
@hjjq hjjq marked this pull request as ready for review August 5, 2025 17:20
@hjjq hjjq changed the title [Draft][Kernel] Flashinfer MLA (trtllm-gen) decode kernel integration [Kernel] Flashinfer MLA (trtllm-gen) decode kernel integration Aug 5, 2025
@kushanam (Collaborator) commented Aug 7, 2025

@LucasWilkinson this is ready for review. Thanks.

@mgoin (Member) commented Aug 14, 2025

Moving to draft while investigating performance

@mgoin mgoin marked this pull request as draft August 14, 2025 15:08
@mergify bot commented Aug 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hjjq.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 14, 2025
@kushanam (Collaborator)

@farazkh80 for review.

@mergify mergify bot removed the needs-rebase label Sep 4, 2025
@MatthewBonanni (Contributor)

@hjjq You might need to implement a quick fix similar to #24453 because of the changes in #23734.

@kushanam (Collaborator) commented Sep 9, 2025

@LucasWilkinson could you please review?

@mergify bot commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hjjq.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 9, 2025
@mergify mergify bot removed the needs-rebase label Sep 9, 2025
hjjq added 2 commits September 9, 2025 15:38
Signed-off-by: hjjq <[email protected]>
Signed-off-by: hjjq <[email protected]>
@hjjq (Contributor, Author) commented Sep 9, 2025

Thanks @MatthewBonanni and @LucasWilkinson, I've made the changes and verified correctness with gsm8k. I've also reverted @mgoin's changes so that test_flex_attention and test_triton_flash_attention no longer run. The current failures appear to be unrelated and also occur on main.

@benchislett (Collaborator)

CI on main that confirms these test failures are not caused by this PR:
https://buildkite.com/vllm/ci/builds/30011/steps/canvas?sid=01992fa8-939a-4ba4-8585-1e5074bcd92d

@MatthewBonanni (Contributor)

@hjjq @LucasWilkinson the remaining CI failures all occur on main as well.

@mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hjjq.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
@mergify mergify bot removed the needs-rebase label Sep 10, 2025
@LucasWilkinson (Collaborator) left a comment

LGTM! Thanks!

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) September 10, 2025 20:48
@mgoin (Member) left a comment

Thanks!

@simon-mo simon-mo merged commit dcb28a3 into vllm-project:main Sep 10, 2025
70 of 73 checks passed
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
…project#21078)

Signed-off-by: hjjq <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…project#21078)

Signed-off-by: hjjq <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…project#21078)

Signed-off-by: hjjq <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…project#21078)

Signed-off-by: hjjq <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>

Labels

ci/build · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

10 participants