Commit 1106bff

Squash commit
Signed-off-by: hjjq <[email protected]>
1 parent c494f96 commit 1106bff

6 files changed: +246 additions, 0 deletions
Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from torch import Tensor

from vllm.platforms import current_platform

FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="FlashInfer MLA Requires compute capability of 10 or above.",
        allow_module_level=True)


def ref_mla(
        out: Tensor,  # (bs, num_heads, v_head_dim)
        query: Tensor,  # (bs, num_heads, head_dim)
        kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
        scale: float,
        block_tables: Tensor,  # (bs, max_num_blocks)
        seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[
            block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1,
                     head_dim)[:, :seq_lens[i]]  # (1, seq_len, head_dim)
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
        o = F.scaled_dot_product_attention(q,
                                           kv,
                                           v,
                                           scale=scale,
                                           enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("bs", [1, 2, 4, 16])
@pytest.mark.parametrize("block_size", [32, 64])
def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
    torch.set_default_device('cuda')
    torch.manual_seed(42)

    # Deepseek R1 config
    num_heads = 128
    kv_lora_rank = 512
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    qk_head_dim = kv_lora_rank + qk_rope_head_dim
    scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5

    MAX_SEQ_LEN = 1024

    seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)

    # Generate block tables with random but unique block IDs
    # From https://github.com/flashinfer-ai/flashinfer/pull/1222
    blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size
    max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4)
    total_blocks_needed = sum(blocks_per_seq)
    # Get random unique IDs for all blocks
    all_block_ids = torch.randperm(total_blocks_needed)

    block_id = 0
    block_tables = torch.zeros(
        (bs, max_num_blocks_per_seq),
        dtype=torch.int32,
    )

    # Populate block tables and track block assignments
    block_id = 0
    for i in range(bs):
        num_blocks_needed = blocks_per_seq[i]
        block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id +
                                                            num_blocks_needed]
        block_id += num_blocks_needed

    kv_cache = torch.randn(block_tables.numel(), block_size,
                           qk_head_dim).to(dtype)
    q = torch.randn(bs, num_heads, qk_head_dim).to(dtype)

    out_ref = q.new_zeros(bs, num_heads, kv_lora_rank)
    ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor)

    workspace_buffer = torch.empty(
        FLASHINFER_WORKSPACE_BUFFER_SIZE,
        dtype=torch.uint8,
        device=q.device,
    )
    # Flashinfer MLA expects the query to be of shape
    # (bs, q_len_per_request, num_heads, qk_head_dim),
    # where q_len_per_request is the MTP query length (=1 without MTP)
    q = q.unsqueeze(1)

    out_ans = trtllm_batch_decode_with_kv_cache_mla(
        query=q,
        kv_cache=kv_cache.unsqueeze(1),
        workspace_buffer=workspace_buffer,
        qk_nope_head_dim=qk_nope_head_dim,
        kv_lora_rank=kv_lora_rank,
        qk_rope_head_dim=qk_rope_head_dim,
        block_tables=block_tables,
        seq_lens=seq_lens_tensor,
        max_seq_len=max_seq_len,
        bmm1_scale=scale,
    )
    out_ans = out_ans.squeeze(1)
    torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2)
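
As a reading aid for the constants above, here is a minimal, self-contained sanity check of the head-dimension arithmetic the test relies on. All values are copied from the "Deepseek R1 config" block; the snippet itself is illustrative and not part of the commit.

# Illustrative restatement of the test's dimension bookkeeping.
kv_lora_rank = 512        # latent ("nope") width stored per token in the KV cache
qk_rope_head_dim = 64     # rotary width appended to the latent
qk_nope_head_dim = 128    # per-head no-rope query width

qk_head_dim = kv_lora_rank + qk_rope_head_dim           # 576: query / cache row width
scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5   # softmax scale uses 192, not 576
v_head_dim = kv_lora_rank                               # 512: width of ref_mla's output

assert qk_head_dim == 576
assert abs(scale - 192 ** -0.5) < 1e-12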

vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions

@@ -1465,6 +1465,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "FLASHMLA",
             "FLASHINFER",
             "FLASHINFER_VLLM_V1",
+            "FLASHINFER_MLA",
             "ROCM_AITER_MLA",
             "TORCH_SDPA_VLLM_V1",
             "FLEX_ATTENTION",

vllm/platforms/cuda.py

Lines changed: 13 additions & 0 deletions

@@ -227,6 +227,19 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
+            if selected_backend == _Backend.FLASHINFER_MLA:
+                if use_v1 and cls.has_device_capability(100):
+                    from vllm.v1.attention.backends.utils import (
+                        set_kv_cache_layout)
+                    set_kv_cache_layout("HND")
+                    logger.info_once(
+                        "Using FlashInfer MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "flashinfer_mla.FlashInferMLABackend")
+                else:
+                    logger.warning(
+                        "FlashInfer MLA backend is only supported on V1 engine"
+                        " and requires compute capability 10.0")
             if selected_backend == _Backend.CUTLASS_MLA:
                 if use_v1:
                     logger.info_once("Using Cutlass MLA backend on V1 engine.")

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ class _Backend(enum.Enum):
     TORCH_SDPA = enum.auto()
     FLASHINFER = enum.auto()
     FLASHINFER_VLLM_V1 = enum.auto()
+    FLASHINFER_MLA = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1
     TRITON_MLA_VLLM_V1 = enum.auto()
     FLASHMLA_VLLM_V1 = enum.auto()

vllm/v1/attention/backends/mla/common.py

Lines changed: 3 additions & 0 deletions

@@ -349,6 +349,7 @@ class MLACommonMetadata(Generic[D]):

     num_reqs: int
     max_query_len: int
+    max_seq_len: int

     num_actual_tokens: int  # Number of tokens excluding padding.
     query_start_loc: torch.Tensor
@@ -586,6 +587,7 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
+        max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()

         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
@@ -710,6 +712,7 @@ def build(self,
         attn_metadata = self.metadata_cls(
             num_reqs=common_attn_metadata.num_reqs,
             max_query_len=common_attn_metadata.max_query_len,
+            max_seq_len=max_seq_len,
             num_actual_tokens=num_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
vllm/v1/attention/backends/mla/flashinfer_mla.py

Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
                                                   MLACommonMetadata)

logger = init_logger(__name__)

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024


class FlashInferMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER_MLA"

    @staticmethod
    def get_impl_cls() -> type["FlashInferMLAImpl"]:
        return FlashInferMLAImpl


g_fi_workspace = torch.empty(
    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
    dtype=torch.uint8,
    device="cuda",
)


class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashInferMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashInferMLA V1 with FP8 KV cache not yet supported")

        self._workspace_buffer = g_fi_workspace

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        q = torch.cat([q_nope, q_pe], dim=-1)
        # trtllm API requires extra dimension q_len_per_request for MTP
        q = q.unsqueeze(1)

        o = trtllm_batch_decode_with_kv_cache_mla(
            query=q,
            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=attn_metadata.decode.block_table,
            seq_lens=attn_metadata.decode.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            bmm1_scale=self.scale,
        )

        return self._v_up_proj(o)
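
To make the reshapes in _forward_decode easier to follow, here is a shape-only walkthrough using the same DeepSeek R1 sizes as the test (kv_lora_rank=512, qk_rope_head_dim=64). It is illustrative only and does not call the FlashInfer kernel; num_tokens, num_blocks, and block_size are arbitrary values chosen for the example.

import torch

num_tokens, num_heads = 4, 128
q_nope = torch.randn(num_tokens, num_heads, 512)  # latent ("nope") queries
q_pe = torch.randn(num_tokens, num_heads, 64)     # rotary queries

q = torch.cat([q_nope, q_pe], dim=-1)  # (4, 128, 576)
q = q.unsqueeze(1)                     # (4, 1, 128, 576): q_len_per_request = 1

num_blocks, block_size = 8, 64
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size, 576)
kv_cache = kv_c_and_k_pe_cache.unsqueeze(1)  # (8, 1, 64, 576): one latent KV "head",
                                             # matching the HND layout set in cuda.py

# Per the test above, the kernel output keeps the q_len dimension:
# (num_tokens, 1, num_heads, kv_lora_rank); the impl hands that to
# self._v_up_proj to map the 512-wide latent back to the per-head value dim.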
