Commit cfe3361

Fix per comments
Signed-off-by: Chendi.Xue <[email protected]>
1 parent fa2ef49 commit cfe3361

File tree

3 files changed: +41 -30 lines changed

tests/full_tests/spec_decode.py

Lines changed: 12 additions & 7 deletions
@@ -20,13 +20,14 @@
 def time_generation(llm: LLM,
                     prompts: list[str],
                     sampling_params: SamplingParams,
-                    num_spec_tokens=5):
+                    num_spec_tokens=5,
+                    num_warmups=1):
     # Generate texts from the prompts. The output is a list of RequestOutput
     # objects that contain the prompt, generated text, and other information.
     # Warmup first
     logging.info("Warming up the model...")
-    llm.generate(prompts, sampling_params)
-    llm.generate(prompts, sampling_params)
+    for _ in range(num_warmups):
+        llm.generate(prompts, sampling_params)
     logging.info("Starting generation...")
     start = time.time()
     outputs = llm.generate(prompts, sampling_params)
@@ -103,7 +104,7 @@ def test_ngram(is_enable, args, prompts, sampling_params, task_key,
     )
 
     result_dict = time_generation(llm, prompts, sampling_params,
-                                  args.num_spec_tokens)
+                                  args.num_spec_tokens, args.num_warmups)
 
     result_queue.put((task_key, result_dict))
 
@@ -128,7 +129,7 @@ def test_eagle_model(is_enable, args, prompts, sampling_params, task_key,
     )
 
     result_dict = time_generation(llm, prompts, sampling_params,
-                                  args.num_spec_tokens)
+                                  args.num_spec_tokens, args.num_warmups)
     result_queue.put((task_key, result_dict))
 
 
@@ -152,7 +153,7 @@ def test_medusa_model(is_enable, args, prompts, sampling_params, task_key,
     )
 
     result_dict = time_generation(llm, prompts, sampling_params,
-                                  args.num_spec_tokens)
+                                  args.num_spec_tokens, args.num_warmups)
     result_queue.put((task_key, result_dict))
 
 
@@ -175,7 +176,7 @@ def test_mtp_model(is_enable, args, prompts, sampling_params, task_key,
     )
 
     result_dict = time_generation(llm, prompts, sampling_params,
-                                  args.num_spec_tokens)
+                                  args.num_spec_tokens, args.num_warmups)
     result_queue.put((task_key, result_dict))
 
 
@@ -199,6 +200,10 @@ def test_mtp_model(is_enable, args, prompts, sampling_params, task_key,
     parser.add_argument("--enforce_eager",
                         action="store_true",
                         help="Enforce eager execution for Eagle model.")
+    parser.add_argument("--num_warmups",
+                        type=int,
+                        default=1,
+                        help="Number of warmup runs before timing.")
 
     # 'ngram', 'eagle', 'eagle3', 'medusa', 'mlp_speculator',
     # 'draft_model' or 'deepseek_mtp
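
The change above replaces the hard-coded pair of warmup calls with a loop driven by the new --num_warmups flag. A minimal, self-contained sketch of that warmup-then-measure pattern follows; timed_generation and fake_generate are illustrative stand-ins, not names from the test file, and a real run would pass a vLLM LLM instance and SamplingParams as the diff does.

import time


def timed_generation(generate, prompts, num_warmups=1):
    # Warm-up runs absorb one-time costs (graph compilation, cache and
    # allocator growth) so they do not pollute the timed run below.
    for _ in range(num_warmups):
        generate(prompts)
    start = time.time()
    outputs = generate(prompts)
    latency = time.time() - start
    return outputs, latency


def fake_generate(prompts):
    # Stand-in workload; the actual test calls llm.generate(prompts, sampling_params).
    return [p.upper() for p in prompts]


if __name__ == "__main__":
    outs, latency = timed_generation(fake_generate, ["hello", "world"], num_warmups=2)
    print(f"generated {len(outs)} outputs in {latency:.6f} s")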

vllm_gaudi/v1/sample/hpu_rejection_sampler.py

Lines changed: 6 additions & 8 deletions
@@ -1,3 +1,5 @@
+# SPDX-License-Identifier: Apache-2.0
+
 from vllm.v1.sample import rejection_sampler
 import torch
 from typing import Optional
@@ -68,7 +70,9 @@ def rejection_greedy_sample_pytorch(
         # This loop is a direct translation of the Triton kernel's core logic.
         rejected = False
         for pos in range(num_draft_tokens):
-            if not rejected:
+            if rejected:
+                break
+            else:
                 draft_token = draft_token_ids[start_idx + pos]
                 target_token = target_argmax[start_idx + pos]
 
@@ -79,11 +83,6 @@ def rejection_greedy_sample_pytorch(
                 # all subsequent tokens.
                 if draft_token != target_token:
                     rejected = True
-            else:
-                # This `break` is a Pythonic optimization. The original Triton
-                # kernel continues the loop but the `if not rejected` check
-                # prevents further operations. Breaking is more efficient here.
-                break
 
         # If the entire draft sequence was accepted without any rejection,
         # append the bonus token.
@@ -148,8 +147,7 @@ def rejection_sample(
         bonus_token_ids,
         is_greedy,
     )
-    if sampling_metadata.all_greedy:
-        return output_token_ids
+    return output_token_ids
 
 
 rejection_sampler.rejection_sample = rejection_sample
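
For reference, the greedy path shown above accepts draft tokens only while they match the target model's argmax, emits the target token at the first mismatch, and appends the bonus token only when the whole draft was accepted. Below is a minimal sketch of that rule on plain Python lists; the function name and list-based interface are illustrative, not the module's API.

def greedy_rejection_sketch(draft_token_ids, target_argmax, bonus_token_id):
    # Walk the draft sequence; the target's argmax is always emitted for the
    # current position, and the first mismatch rejects everything after it.
    output = []
    for draft_token, target_token in zip(draft_token_ids, target_argmax):
        output.append(target_token)
        if draft_token != target_token:
            return output
    # Entire draft accepted: the bonus token is appended as well.
    output.append(bonus_token_id)
    return output


assert greedy_rejection_sketch([5, 7, 9], [5, 7, 9], 11) == [5, 7, 9, 11]
assert greedy_rejection_sketch([5, 8, 9], [5, 7, 9], 11) == [5, 7]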

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 23 additions & 15 deletions
@@ -1629,14 +1629,18 @@ def _prepare_decode_inputs(self,
         # self.input_batch.num_computed_tokens_cpu[req_indices]
         positions = torch.zeros((padded_batch_size, num_tokens),
                                 dtype=torch.int32)
-        # per request using universal self.positions_cpu then pad
-        position_split_tensors = torch.split(
-            self.positions_cpu[:total_num_scheduled_tokens],
-            num_tokens_per_req)
-        positions[:num_decodes] = \
-            pad_sequence(list(position_split_tensors),
-                         batch_first=True,
-                         padding_value=0)[:num_decodes]
+        if num_tokens == 1:
+            positions[:num_decodes] = self.positions_cpu[:num_decodes].reshape(
+                -1, 1)
+        else:
+            # per request using universal self.positions_cpu then pad
+            position_split_tensors = torch.split(
+                self.positions_cpu[:total_num_scheduled_tokens],
+                num_tokens_per_req)
+            positions[:num_decodes] = \
+                pad_sequence(list(position_split_tensors),
+                             batch_first=True,
+                             padding_value=0)[:num_decodes]
 
         padded_index = torch.zeros((padded_batch_size, num_tokens),
                                    dtype=torch.int64)
@@ -1680,13 +1684,17 @@ def _prepare_decode_inputs(self,
         # self.input_batch.token_ids_cpu[:total_num_scheduled_tokens]
         token_ids = torch.zeros((padded_batch_size, num_tokens),
                                 dtype=torch.int32)
-        token_ids_split_tensors = torch.split(
-            self.input_ids_cpu[:total_num_scheduled_tokens],
-            num_tokens_per_req)
-        token_ids[:num_decodes] = \
-            pad_sequence(list(token_ids_split_tensors),
-                         batch_first=True,
-                         padding_value=0)[:num_decodes]
+        if num_tokens == 1:
+            token_ids[:num_decodes] = self.input_ids_cpu[:num_decodes].reshape(
+                -1, 1)
+        else:
+            token_ids_split_tensors = torch.split(
+                self.input_ids_cpu[:total_num_scheduled_tokens],
+                num_tokens_per_req)
+            token_ids[:num_decodes] = \
+                pad_sequence(list(token_ids_split_tensors),
+                             batch_first=True,
+                             padding_value=0)[:num_decodes]
 
         ###################################
         # SLOT_MAPPING [batch, 1]
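
The change above adds a fast path for the common single-token decode case, where the flat CPU buffer can simply be reshaped to [num_decodes, 1], and keeps the split-plus-pad path for variable per-request token counts (e.g. with speculative tokens). A small standalone sketch of the two layouts, using made-up sizes and values rather than the runner's real buffers:

import torch
from torch.nn.utils.rnn import pad_sequence

padded_batch_size, num_decodes = 4, 3

# Fast path: one token per request, so a reshape of the flat buffer suffices.
positions_cpu = torch.tensor([17, 5, 42], dtype=torch.int32)
positions = torch.zeros((padded_batch_size, 1), dtype=torch.int32)
positions[:num_decodes] = positions_cpu[:num_decodes].reshape(-1, 1)

# General path: variable tokens per request, so split the flat buffer per
# request and right-pad each piece to the widest request.
num_tokens_per_req = [2, 3, 1]
num_tokens = max(num_tokens_per_req)
flat = torch.arange(sum(num_tokens_per_req), dtype=torch.int32)
padded = torch.zeros((padded_batch_size, num_tokens), dtype=torch.int32)
splits = torch.split(flat, num_tokens_per_req)
padded[:num_decodes] = pad_sequence(list(splits), batch_first=True,
                                    padding_value=0)[:num_decodes]

print(positions)  # shape [4, 1], rows 0-2 filled
print(padded)     # shape [4, 3], rows 0-2 right-padded with zeros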
