Skip to content

Commit 5d28ab4

Browse files
committed
Refactor spec decode to support efficient padded speculation
Signed-off-by: xuyexiong <[email protected]>
1 parent cba69e1 commit 5d28ab4

File tree

7 files changed

+1165
-442
lines changed

7 files changed

+1165
-442
lines changed

tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 30 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,15 @@
11
from __future__ import annotations
22

3+
import os
4+
35
import pytest
46
from vllm import SamplingParams
57
from vllm.config import CompilationConfig, CUDAGraphMode
68

79
from tests.e2e.conftest import VllmRunner
810

11+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
12+
913

1014
@pytest.fixture
1115
def sampling_config():
@@ -17,12 +21,11 @@ def model_name():
1721
return "wemaster/deepseek_mtp_main_random_bf16"
1822

1923

20-
def mtp_correctness(
21-
sampling_config: SamplingParams,
22-
model_name: str,
23-
num_speculative_tokens: int,
24-
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
25-
):
24+
def mtp_correctness(sampling_config: SamplingParams,
25+
model_name: str,
26+
num_speculative_tokens: int,
27+
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
28+
disable_padded_drafter_batch=True):
2629
example_prompts = [
2730
"Hello, my name is",
2831
"The president of the United States is",
@@ -54,6 +57,7 @@ def mtp_correctness(
5457
speculative_config={
5558
"method": "deepseek_mtp",
5659
"num_speculative_tokens": num_speculative_tokens,
60+
"disable_padded_drafter_batch": disable_padded_drafter_batch,
5761
},
5862
enforce_eager=False,
5963
max_model_len=2000,
@@ -110,3 +114,23 @@ def test_mtp2_correctness_full_graph(
110114
model_name: str,
111115
):
112116
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
117+
118+
119+
def test_mtp1_correctness_piecewise_graph_with_pad(
120+
sampling_config: SamplingParams,
121+
model_name: str,
122+
):
123+
mtp_correctness(sampling_config,
124+
model_name,
125+
1,
126+
disable_padded_drafter_batch=False)
127+
128+
129+
def test_mtp2_correctness_piecewise_graph_with_pad(
130+
sampling_config: SamplingParams,
131+
model_name: str,
132+
):
133+
mtp_correctness(sampling_config,
134+
model_name,
135+
2,
136+
disable_padded_drafter_batch=False)

vllm_ascend/spec_decode/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,14 +19,21 @@
1919
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
2020
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
2121
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
22+
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
2223

2324

24-
def get_spec_decode_method(method, vllm_config, device, runner):
25+
def get_spec_decode_method(method,
26+
vllm_config,
27+
device,
28+
runner,
29+
is_torchair_graph=False):
2530
if method == "ngram":
2631
return NgramProposer(vllm_config, device, runner)
2732
elif method in ["eagle", "eagle3"]:
2833
return EagleProposer(vllm_config, device, runner)
2934
elif method == 'deepseek_mtp':
35+
if is_torchair_graph:
36+
return TorchairMtpProposer(vllm_config, device, runner)
3037
return MtpProposer(vllm_config, device, runner)
3138
else:
3239
raise ValueError("Unknown speculative decoding method: "

0 commit comments

Comments
 (0)