 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.platforms import current_platform
 from vllm.v1.spec_decode.eagle import EagleProposer
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
 eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
@@ -64,6 +66,86 @@ def _create_proposer(
                          device=current_platform.device_type)
 
 
+def test_prepare_next_token_ids():
+    """
+    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
+    Each produces a device tensor of next_token_ids: the padded variant takes
+    the GPU tensor of sampled_token_ids with -1 for rejected tokens, while the
+    CPU variant takes the python list[list[int]] with rejected tokens removed.
+    """
+    device = torch.device(current_platform.device_type)
+
+    num_requests = 4
+    num_speculative_tokens = 4
+    batch_spec = BatchSpec(
+        seq_lens=[num_speculative_tokens + 1] * num_requests,
+        query_lens=[num_speculative_tokens + 1] * num_requests,
+    )
+
+    req_ids = [f"req_{i + 1}" for i in range(num_requests)]
+    mock_input_batch = mock.MagicMock(spec=InputBatch)
+    mock_input_batch.req_ids = req_ids
+    mock_input_batch.num_reqs = num_requests
+    mock_input_batch.vocab_size = 100
+
+    mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
+    mock_requests = {}
+    for req_id in req_ids:
+        mock_request = mock.MagicMock(spec=CachedRequestState)
+        # Each request will have a backup next token id of 10, 20, 30, 40
+        mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
+        mock_request.num_computed_tokens = 0
+        mock_requests[req_id] = mock_request
+
+    sampled_token_ids = [
+        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
+        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
+        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
+        [-1, -1, -1, -1, -1]  # this request will be discarded
+    ]
+    sampled_token_ids_tensor = torch.tensor(sampled_token_ids,
+                                            dtype=torch.int32,
+                                            device=device)
+    sampled_token_ids_cpu = [[i for i in seq if i != -1]
+                             for seq in sampled_token_ids]
+
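+    # Expected per request: the last accepted token ("1", "4"), the backup
+    # token "30" where sampling was skipped, and "40" for the discarded
+    # request (assumed to still be filled from its backup in the padded path).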
+    expected_next_token_ids_cpu = [1, 4, 30, 40]
+    expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu,
+                                                  dtype=torch.int32,
+                                                  device=device)
+
+    proposer = _create_proposer("eagle", num_speculative_tokens)
+
+    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
+        sampled_token_ids_cpu, mock_requests, mock_input_batch,
+        mock_num_scheduled_tokens)
+
+    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
+
+    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
+    num_discarded_reqs = 1
+
+    expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0],
+                                                       dtype=torch.int32,
+                                                       device=device)
+
+    next_token_ids_from_padded, valid_sampled_tokens_count = \
+        proposer.prepare_next_token_ids_padded(
+            common_attn_metadata, sampled_token_ids_tensor, mock_requests,
+            mock_input_batch, discarded_req_indices, num_discarded_reqs)
+
+    assert torch.equal(next_token_ids_from_padded,
+                       expected_next_token_ids_tensor)
+    assert torch.equal(valid_sampled_tokens_count,
+                       expected_valid_sampled_tokens_count)
+
+
 def test_prepare_inputs():
     """
     cu_target_query_lens: [0, a, a + b, a + b + c]
@@ -90,10 +172,24 @@ def test_prepare_inputs():
         device=device,
     )
 
-    # Rejected tokens per request: [1, 3, 2]
-    num_rejected_tokens = torch.tensor([1, 3, 2],
-                                       dtype=torch.int32,
-                                       device=device)
+    # If there are `k` sampled tokens, then `k - 1` are draft tokens from the
+    # previous iteration, and the last one is the bonus token sampled from
+    # the base model.
+    num_draft_tokens = [3, 6, 4]  # one less than query_lens
+    # Rejected tokens per request: [1, 3, 2]
+    ACCEPT_TOKEN = 0
+    BONUS_TOKEN = 1
+    REJECT_TOKEN = -1
+    sampled_token_ids = [
+        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
+        [
+            ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
+            REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
+        ],
+        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
+    ]
+    sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
+                         for seq in sampled_token_ids]
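+    # After filtering, the per-request lengths are [3, 4, 3], so
+    # prepare_inputs can recover num_rejected_tokens = [1, 3, 2] as
+    # query_len_per_req - sampled lengths = [4, 7, 5] - [3, 4, 3].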
 
     # Expected calculations:
     # query_len_per_req = [4, 7, 5]
@@ -125,14 +221,85 @@ def test_prepare_inputs():
     proposer = _create_proposer("eagle", 1)
 
     updated_metadata, token_indices = proposer.prepare_inputs(
-        common_attn_metadata, num_rejected_tokens.cpu())
+        common_attn_metadata, sampled_token_ids, num_draft_tokens)
 
     assert torch.equal(updated_metadata.query_start_loc,
                        expected_cu_num_tokens)
     assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
     assert torch.equal(token_indices, expected_token_indices)
 
 
+def test_prepare_inputs_padded():
+    """
+    Input scenario: 3 requests with num_speculative_tokens == 2 and:
+    - Request 1: query_len = 3, rejected = 1
+    - Request 2: query_len = 3, rejected = 0
+    - Request 3: query_len = 3, rejected = 2
+
+    Expected outputs:
+    token_indices: [0, 1, 2,
+                    3, 4, 5,
+                    6, 7, 8]
+    Reason: the deferred computation should not disturb the original indices.
+
+    token_indices_to_sample: [1, 5, 6]
+    Reason: after accounting for rejections, these are the valid token
+    positions from the original indices to sample from.
+    """
+
+    device = torch.device(current_platform.device_type)
+
+    expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
+                                          dtype=torch.int32,
+                                          device=device)
+    expected_token_indices_to_sample = torch.tensor([1, 5, 6],
+                                                    dtype=torch.int32,
+                                                    device=device)
+
+    num_speculative_tokens = 2
+    batch_spec = BatchSpec(
+        seq_lens=[3, 3, 3],
+        query_lens=[3, 3, 3],
+    )
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
+
+    # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
+    expected_query_start_loc = torch.tensor([0, 3, 6, 9],
+                                            dtype=torch.int32,
+                                            device=device)
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        draft_token_ids=[[0] * num_speculative_tokens] * 3,
+        device=device,
+    )
+
+    # num_rejected_tokens = [1, 0, 2]
+    # num_draft_tokens = [2, 2, 2]
+    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
+    valid_sampled_tokens_count = torch.tensor([2, 3, 1],
+                                              dtype=torch.int32,
+                                              device=device)
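+    # The expected sample positions then follow as (assumed formula)
+    # query_start_loc[:-1] + valid_sampled_tokens_count - 1
+    # = [0, 3, 6] + [2, 3, 1] - 1 = [1, 5, 6].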
+
+    proposer = _create_proposer("eagle", num_speculative_tokens)
+
+    output_metadata, token_indices, token_indices_to_sample = \
+        proposer.prepare_inputs_padded(
+            common_attn_metadata,
+            spec_decode_metadata,
+            valid_sampled_tokens_count)
+
+    assert output_metadata.max_query_len == 3
+    assert torch.equal(output_metadata.query_start_loc,
+                       expected_query_start_loc)
+    assert torch.equal(token_indices, expected_token_indices)
+    assert torch.equal(token_indices_to_sample,
+                       expected_token_indices_to_sample)
+
+
 @pytest.mark.parametrize("method", ["eagle", "eagle3"])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
@@ -373,6 +540,7 @@ def create_deterministic_logits(token_ids):
         target_positions=target_positions,
         target_hidden_states=target_hidden_states,
         next_token_ids=next_token_ids,
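+        # None is assumed to make propose() default to sampling at the last
+        # token position of each request.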
+        last_token_indices=None,
         common_attn_metadata=common_attn_metadata,
         sampling_metadata=sampling_metadata)
 
@@ -526,6 +694,7 @@ def create_deterministic_logits(token_ids, k: int):
         target_positions=target_positions,
         target_hidden_states=target_hidden_states,
         next_token_ids=next_token_ids,
+        last_token_indices=None,
         common_attn_metadata=common_attn_metadata,
         sampling_metadata=sampling_metadata)
     assert result.shape == (batch_size, num_speculative_tokens)