11from __future__ import annotations
22
3+ import os
4+
35import pytest
46from vllm import SamplingParams
57from vllm .config import CompilationConfig , CUDAGraphMode
68
79from tests .e2e .conftest import VllmRunner
810
11+ os .environ ["VLLM_WORKER_MULTIPROC_METHOD" ] = "spawn"
12+
913
1014@pytest .fixture
1115def sampling_config ():
@@ -17,12 +21,11 @@ def model_name():
1721 return "wemaster/deepseek_mtp_main_random_bf16"
1822
1923
20- def mtp_correctness (
21- sampling_config : SamplingParams ,
22- model_name : str ,
23- num_speculative_tokens : int ,
24- graph_mode : CUDAGraphMode = CUDAGraphMode .PIECEWISE ,
25- ):
24+ def mtp_correctness (sampling_config : SamplingParams ,
25+ model_name : str ,
26+ num_speculative_tokens : int ,
27+ graph_mode : CUDAGraphMode = CUDAGraphMode .PIECEWISE ,
28+ disable_padded_drafter_batch = True ):
2629 example_prompts = [
2730 "Hello, my name is" ,
2831 "The president of the United States is" ,
@@ -54,6 +57,7 @@ def mtp_correctness(
5457 speculative_config = {
5558 "method" : "deepseek_mtp" ,
5659 "num_speculative_tokens" : num_speculative_tokens ,
60+ "disable_padded_drafter_batch" : disable_padded_drafter_batch ,
5761 },
5862 enforce_eager = False ,
5963 max_model_len = 2000 ,
@@ -110,3 +114,23 @@ def test_mtp2_correctness_full_graph(
110114 model_name : str ,
111115):
112116 mtp_correctness (sampling_config , model_name , 2 , CUDAGraphMode .FULL )
117+
118+
119+ def test_mtp1_correctness_piecewise_graph_with_pad (
120+ sampling_config : SamplingParams ,
121+ model_name : str ,
122+ ):
123+ mtp_correctness (sampling_config ,
124+ model_name ,
125+ 1 ,
126+ disable_padded_drafter_batch = False )
127+
128+
129+ def test_mtp2_correctness_piecewise_graph_with_pad (
130+ sampling_config : SamplingParams ,
131+ model_name : str ,
132+ ):
133+ mtp_correctness (sampling_config ,
134+ model_name ,
135+ 2 ,
136+ disable_padded_drafter_batch = False )
0 commit comments