
Commit da33ab3

benchislett authored and charlifu committed
[Spec Decode] Efficient padded speculation (vllm-project#24539)
Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: charlifu <[email protected]>
1 parent acfc54a commit da33ab3

File tree

5 files changed: +507 -104 lines changed


tests/v1/spec_decode/test_eagle.py

Lines changed: 174 additions & 5 deletions
@@ -19,6 +19,8 @@
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.platforms import current_platform
 from vllm.v1.spec_decode.eagle import EagleProposer
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
 eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
@@ -64,6 +66,86 @@ def _create_proposer(
         device=current_platform.device_type)
 
 
+def test_prepare_next_token_ids():
+    """
+    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
+    Each will produce a device tensor of next_token_ids, taking as input
+    either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
+    or the CPU python list[list[int]] with the rejected tokens removed.
+    """
+    device = torch.device(current_platform.device_type)
+
+    num_requests = 4
+    num_speculative_tokens = 4
+    batch_spec = BatchSpec(
+        seq_lens=[num_speculative_tokens + 1] * num_requests,
+        query_lens=[num_speculative_tokens + 1] * num_requests,
+    )
+
+    req_ids = [f"req_{i+1}" for i in range(num_requests)]
+    mock_input_batch = mock.MagicMock(spec=InputBatch)
+    mock_input_batch.req_ids = req_ids
+    mock_input_batch.num_reqs = num_requests
+    mock_input_batch.vocab_size = 100
+
+    mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
+    mock_requests = {}
+    for req_id in req_ids:
+        mock_request = mock.MagicMock(spec=CachedRequestState)
+        # Each request will have a backup next token id of 10, 20, 30, 40
+        mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
+        mock_request.num_computed_tokens = 0
+        mock_requests[req_id] = mock_request
+
+    sampled_token_ids = [
+        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
+        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
+        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
+        [-1, -1, -1, -1, -1]  # this request will be discarded
+    ]
+    sampled_token_ids_tensor = torch.tensor(sampled_token_ids,
+                                            dtype=torch.int32,
+                                            device=device)
+    sampled_token_ids_cpu = [[i for i in seq if i != -1]
+                             for seq in sampled_token_ids]
+
+    expected_next_token_ids_cpu = [1, 4, 30, 40]
+    expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu,
+                                                  dtype=torch.int32,
+                                                  device=device)
+
+    proposer = _create_proposer("eagle", num_speculative_tokens)
+
+    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
+        sampled_token_ids_cpu, mock_requests, mock_input_batch,
+        mock_num_scheduled_tokens)
+
+    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
+
+    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
+    num_discarded_reqs = 1
+
+    expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0],
+                                                       dtype=torch.int32,
+                                                       device=device)
+
+    next_token_ids_from_padded, valid_sampled_tokens_count = \
+        proposer.prepare_next_token_ids_padded(
+            common_attn_metadata, sampled_token_ids_tensor, mock_requests,
+            mock_input_batch, discarded_req_indices, num_discarded_reqs)
+
+    assert torch.equal(next_token_ids_from_padded,
+                       expected_next_token_ids_tensor)
+    assert torch.equal(valid_sampled_tokens_count,
+                       expected_valid_sampled_tokens_count)
+
+
 def test_prepare_inputs():
     """
     cu_target_query_lens: [0, a, a + b, a + b + c]
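
The docstring above spells out the contract of prepare_next_token_ids_cpu and prepare_next_token_ids_padded: for each request, take the last accepted (or bonus) token, and fall back to the request's backup token when every position was rejected. The standalone sketch below reproduces that selection rule in plain PyTorch so the expected values in the test are easy to check by hand; select_next_token_ids and backup_token_ids are illustrative names, not the EagleProposer API added by this commit.

# Illustrative sketch only, not the vLLM implementation.
import torch

def select_next_token_ids(sampled_token_ids: torch.Tensor,
                          backup_token_ids: torch.Tensor) -> torch.Tensor:
    """Pick each row's last non-rejected token, else the backup token."""
    # Assumes rejected positions (-1) form a suffix, as in rejection sampling.
    valid = sampled_token_ids != -1        # rejected positions are -1
    counts = valid.sum(dim=1)              # accepted + bonus tokens per request
    last_idx = (counts - 1).clamp(min=0)   # index of the last valid position
    picked = sampled_token_ids.gather(1, last_idx.unsqueeze(1)).squeeze(1)
    return torch.where(counts > 0, picked, backup_token_ids)

sampled = torch.tensor([[0, 1, -1, -1, -1],
                        [0, 1, 2, 3, 4],
                        [-1, -1, -1, -1, -1],
                        [-1, -1, -1, -1, -1]])
backup = torch.tensor([10, 20, 30, 40])
assert select_next_token_ids(sampled, backup).tolist() == [1, 4, 30, 40]
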
@@ -90,10 +172,24 @@ def test_prepare_inputs():
         device=device,
     )
 
-    # Rejected tokens per request: [1, 3, 2]
-    num_rejected_tokens = torch.tensor([1, 3, 2],
-                                       dtype=torch.int32,
-                                       device=device)
+    # If there are `k` sampled tokens, then `k-1` tokens are draft tokens
+    # from the previous iteration, and the last token is the bonus token sampled
+    # from the base model.
+    num_draft_tokens = [3, 6, 4]  # one less than query_lens
+    # num rejected tokens is [1, 3, 2]
+    ACCEPT_TOKEN = 0
+    BONUS_TOKEN = 1
+    REJECT_TOKEN = -1
+    sampled_token_ids = [
+        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
+        [
+            ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
+            REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
+        ],
+        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
+    ]
+    sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
+                         for seq in sampled_token_ids]
 
     # Expected calculations:
     # query_len_per_req = [4, 7, 5]
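
As the comment above notes, each request contributes num_draft + 1 sampled positions (the drafts from the previous iteration plus one bonus token), and rejected positions have already been filtered out of the Python lists before prepare_inputs sees them. A quick standalone check of that arithmetic, using the values from this test (illustrative only, not part of the diff):

num_draft_tokens = [3, 6, 4]
accepted_lens = [3, 4, 3]  # lengths of sampled_token_ids after dropping REJECT_TOKEN
num_rejected_tokens = [d + 1 - n for d, n in zip(num_draft_tokens, accepted_lens)]
assert num_rejected_tokens == [1, 3, 2]  # matches the comment above
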
@@ -125,14 +221,85 @@ def test_prepare_inputs():
     proposer = _create_proposer("eagle", 1)
 
     updated_metadata, token_indices = proposer.prepare_inputs(
-        common_attn_metadata, num_rejected_tokens.cpu())
+        common_attn_metadata, sampled_token_ids, num_draft_tokens)
 
     assert torch.equal(updated_metadata.query_start_loc,
                        expected_cu_num_tokens)
     assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
     assert torch.equal(token_indices, expected_token_indices)
 
 
+def test_prepare_inputs_padded():
+    """
+    Input scenario is 3 requests with num_speculative_tokens == 2 and:
+    - Request 1: query_len = 3, rejected = 1
+    - Request 2: query_len = 3, rejected = 0
+    - Request 3: query_len = 3, rejected = 2
+
+    Expected outputs:
+    token_indices: [0, 1, 2,
+                    3, 4, 5,
+                    6, 7, 8]
+    Reason: Deferred computation should not disturb the original indices.
+
+    token_indices_to_sample: [1, 5, 6]
+    Reason: After accounting for rejections, these are the valid token positions
+    from the original indices to sample from.
+    """
+
+    device = torch.device(current_platform.device_type)
+
+    expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
+                                          dtype=torch.int32,
+                                          device=device)
+    expected_token_indices_to_sample = torch.tensor([1, 5, 6],
+                                                    dtype=torch.int32,
+                                                    device=device)
+
+    num_speculative_tokens = 2
+    batch_spec = BatchSpec(
+        seq_lens=[3, 3, 3],
+        query_lens=[3, 3, 3],
+    )
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
+
+    # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
+    expected_query_start_loc = torch.tensor([0, 3, 6, 9],
+                                            dtype=torch.int32,
+                                            device=device)
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        draft_token_ids=[[0] * num_speculative_tokens] * 3,
+        device=device,
+    )
+
+    # num_rejected_tokens = [1, 0, 2]
+    # num_draft_tokens = [2, 2, 2]
+    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
+    valid_sampled_tokens_count = torch.tensor([2, 3, 1],
+                                              dtype=torch.int32,
+                                              device=device)
+
+    proposer = _create_proposer("eagle", num_speculative_tokens)
+
+    output_metadata, token_indices, token_indices_to_sample = \
+        proposer.prepare_inputs_padded(
+            common_attn_metadata,
+            spec_decode_metadata,
+            valid_sampled_tokens_count)
+
+    assert output_metadata.max_query_len == 3
+    assert torch.equal(output_metadata.query_start_loc,
+                       expected_query_start_loc)
+    assert torch.equal(token_indices, expected_token_indices)
+    assert torch.equal(token_indices_to_sample,
+                       expected_token_indices_to_sample)
+
+
 @pytest.mark.parametrize("method", ["eagle", "eagle3"])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
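
test_prepare_inputs_padded captures the key property of the padded path: token_indices stay dense and untouched, and only the position to sample from shifts back by each request's number of rejected tokens. The plain-Python arithmetic below (illustrative only, not vLLM code) reproduces the expected [1, 5, 6]:

query_start_loc = [0, 3, 6, 9]   # cumulative padded query lengths
num_rejected_tokens = [1, 0, 2]
# Sample from the last non-rejected position of each padded query.
token_indices_to_sample = [
    query_start_loc[i + 1] - 1 - num_rejected_tokens[i]
    for i in range(len(num_rejected_tokens))
]
assert token_indices_to_sample == [1, 5, 6]
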
@@ -373,6 +540,7 @@ def create_deterministic_logits(token_ids):
         target_positions=target_positions,
         target_hidden_states=target_hidden_states,
         next_token_ids=next_token_ids,
+        last_token_indices=None,
         common_attn_metadata=common_attn_metadata,
         sampling_metadata=sampling_metadata)

@@ -526,6 +694,7 @@ def create_deterministic_logits(token_ids, k: int):
         target_positions=target_positions,
         target_hidden_states=target_hidden_states,
         next_token_ids=next_token_ids,
+        last_token_indices=None,
         common_attn_metadata=common_attn_metadata,
         sampling_metadata=sampling_metadata)
     assert result.shape == (batch_size, num_speculative_tokens)

vllm/config/speculative.py

Lines changed: 5 additions & 0 deletions
@@ -83,6 +83,11 @@ class SpeculativeConfig:
     disable_by_batch_size: Optional[int] = None
     """Disable speculative decoding for new incoming requests when the number
     of enqueued requests is larger than this value, if provided."""
+    disable_padded_drafter_batch: bool = False
+    """Disable input padding for speculative decoding. If set to True,
+    speculative input batches can contain sequences of different lengths,
+    which may only be supported by certain attention backends. This currently
+    only affects the EAGLE method of speculation."""
 
     # Ngram proposer configuration
     prompt_lookup_max: Optional[int] = None
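
A hedged usage sketch for the new flag: recent vLLM versions accept the speculative settings as a speculative_config dict when constructing the engine, so opting out of padded drafter batches would look roughly like the snippet below. The model names are the ones used in the test file above; the other keys and their values are assumptions about a typical EAGLE setup, so check the accepted fields for your vLLM version.

# Hypothetical usage sketch; field names other than disable_padded_drafter_batch
# are assumptions about a typical EAGLE configuration.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 4,
        # Added in this commit: fall back to unpadded (ragged) drafter batches,
        # which only some attention backends support.
        "disable_padded_drafter_batch": True,
    },
)
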
