Skip to content

Commit 9d80e60

Browse files
committed
add spec decode to hpu_model_runner
Status: 1. eagle and ngram are working. TODO: add prefill to the draft model to improve performance. Signed-off-by: Chendi.Xue <[email protected]>
1 parent efdf1d7 commit 9d80e60

File tree

8 files changed

+970
-71
lines changed

8 files changed

+970
-71
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ numpy==1.26.4
55
tabulate
66
setuptools>=77.0.3,<80.0.0
77
setuptools-scm>=8
8+
numba

tests/full_tests/spec_decode.py

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
import argparse
import logging
import multiprocessing
import os
import queue
import time

from vllm import LLM, SamplingParams
from vllm.v1.metrics.reader import Counter, Vector
9+
10+
# Tag every log record with the process name so interleaved output from the
# spawned worker processes stays readable.
logging.basicConfig(
    format="[%(levelname)s][%(processName)s][%(asctime)s] %(message)s",
    level=logging.INFO,
)

# Backend knobs set before any engine is created (presumably read at engine
# init — confirm against the vLLM Gaudi plugin docs).
for _name, _value in (
    ("VLLM_SKIP_WARMUP", "true"),
    ("VLLM_CONTIGUOUS_PA", "false"),
    ("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
):
    os.environ[_name] = _value
18+
19+
20+
def time_generation(llm: LLM,
                    prompts: list[str],
                    sampling_params: SamplingParams,
                    num_spec_tokens=5):
    """Run warmup plus one timed generation pass and collect spec-decode stats.

    Args:
        llm: An initialized vLLM engine.
        prompts: Prompts to generate from.
        sampling_params: Sampling parameters shared by all prompts.
        num_spec_tokens: Number of speculative tokens per draft; sizes the
            per-position acceptance histogram (one extra slot for the bonus
            token).

    Returns:
        Dict with keys 'ret_spec' (generated texts), 'latency' (wall-clock
        seconds for the timed pass), 'acc_counts' (accepted tokens per draft
        position), 'acc_rate' (accepted / drafted tokens), and the raw
        'num_draft_tokens' / 'num_drafts' counters. If engine metrics are
        unavailable, the metric fields fall back to zeros.
    """
    # Two warmup passes so one-time compilation/caching does not skew timing.
    logging.info("Warming up the model...")
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)

    logging.info("Starting generation...")
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    latency = time.time() - start
    logging.info("Generation completed in %.2f seconds.", latency)

    # First completion per request; one slot per draft position (+ bonus).
    ret = [output.outputs[0].text for output in outputs]
    acceptance_counts = [0] * (num_spec_tokens + 1)

    try:
        metrics = llm.llm_engine.get_metrics()
    except Exception as e:
        # Best effort: metrics can be unavailable (e.g. stats disabled), so
        # return the texts with zeroed counters instead of crashing.
        logging.error("Error getting metrics: %s", e)
        return {
            'ret_spec': ret,
            'latency': latency,
            'acc_counts': acceptance_counts,
            'acc_rate': 0.0,
            'num_draft_tokens': 0,
            'num_drafts': 0,
        }

    num_drafts = 0
    num_draft_tokens = 0
    num_accepted_tokens = 0
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            assert isinstance(metric, Counter)
            num_draft_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
            assert isinstance(metric, Counter)
            num_accepted_tokens += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos, value in enumerate(metric.values):
                acceptance_counts[pos] += value

    accept_rate = (num_accepted_tokens / num_draft_tokens
                   if num_draft_tokens > 0 else 0.0)
    return {
        'ret_spec': ret,
        'latency': latency,
        'acc_counts': acceptance_counts,
        'acc_rate': accept_rate,
        'num_draft_tokens': num_draft_tokens,
        'num_drafts': num_drafts,
    }
85+
86+
87+
def test_ngram(is_enable, args, prompts, sampling_params, task_key,
               result_queue):
    """Benchmark Qwen3-4B, optionally with ngram speculative decoding.

    Builds the engine, times generation via time_generation, and reports
    (task_key, result_dict) back to the parent through result_queue.
    """
    llm_kwargs = {
        "model": "Qwen/Qwen3-4B",
        "disable_log_stats": False,
    }
    if is_enable:
        llm_kwargs["speculative_config"] = {
            "method": "ngram",
            "prompt_lookup_max": 3,
            "num_speculative_tokens": args.num_spec_tokens,
        }
    llm = LLM(**llm_kwargs)

    result_dict = time_generation(llm, prompts, sampling_params,
                                  args.num_spec_tokens)
    result_queue.put((task_key, result_dict))
109+
110+
111+
def test_eagle_model(is_enable, args, prompts, sampling_params, task_key,
                     result_queue):
    """Benchmark Llama-3-8B-Instruct, optionally with an EAGLE draft model.

    Builds the engine, times generation via time_generation, and reports
    (task_key, result_dict) back to the parent through result_queue.
    """
    llm_kwargs = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "disable_log_stats": False,
        "enforce_eager": args.enforce_eager,
    }
    if is_enable:
        llm_kwargs["speculative_config"] = {
            "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
            "num_speculative_tokens": args.num_spec_tokens,
        }
    llm = LLM(**llm_kwargs)

    result_dict = time_generation(llm, prompts, sampling_params,
                                  args.num_spec_tokens)
    result_queue.put((task_key, result_dict))
133+
134+
135+
def test_medusa_model(is_enable, args, prompts, sampling_params, task_key,
                      result_queue):
    """Benchmark llama-68m, optionally with a Medusa draft model.

    Builds the engine, times generation via time_generation, and reports
    (task_key, result_dict) back to the parent through result_queue.
    """
    llm_kwargs = {
        "model": "JackFram/llama-68m",
        "disable_log_stats": False,
        "enforce_eager": args.enforce_eager,
    }
    if is_enable:
        llm_kwargs["speculative_config"] = {
            "model": "abhigoyal/vllm-medusa-llama-68m-random",
            "num_speculative_tokens": args.num_spec_tokens,
        }
    llm = LLM(**llm_kwargs)

    result_dict = time_generation(llm, prompts, sampling_params,
                                  args.num_spec_tokens)
    result_queue.put((task_key, result_dict))
157+
158+
159+
def test_mtp_model(is_enable, args, prompts, sampling_params, task_key,
                   result_queue):
    """Benchmark Qwen3-4B, optionally with a deepseek_mtp draft model.

    Builds the engine, times generation via time_generation, and reports
    (task_key, result_dict) back to the parent through result_queue.
    """
    llm_kwargs = {
        "model": "Qwen/Qwen3-4B",
        "disable_log_stats": False,
    }
    if is_enable:
        llm_kwargs["speculative_config"] = {
            "method": "deepseek_mtp",
            "model": "Qwen/Qwen3-0.6B",
            "num_speculative_tokens": args.num_spec_tokens,
        }
    llm = LLM(**llm_kwargs)

    result_dict = time_generation(llm, prompts, sampling_params,
                                  args.num_spec_tokens)
    result_queue.put((task_key, result_dict))
180+
181+
182+
if __name__ == "__main__":
    # Workers must be spawned (not forked): each one initializes its own
    # engine, and the module itself forces VLLM_WORKER_MULTIPROC_METHOD=spawn.
    multiprocessing.set_start_method("spawn", force=True)
    parser = argparse.ArgumentParser(description="Test spec decode.")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--osl", type=int, default=50)
    parser.add_argument("--num_spec_tokens",
                        type=int,
                        default=1,
                        help="Number of speculative tokens to generate.")
    parser.add_argument("--task",
                        type=str,
                        default="eagle",
                        help="Tasks to run the evaluation on.")
    parser.add_argument(
        "--run_base",
        action="store_true",
        help="Run the baseline tasks without speculative decoding.")
    parser.add_argument("--enforce_eager",
                        action="store_true",
                        help="Enforce eager execution for Eagle model.")

    # 'ngram', 'eagle', 'eagle3', 'medusa', 'mlp_speculator',
    # 'draft_model' or 'deepseek_mtp
    # V1 does not support draft_model yet.
    # MLP speculator => https://github.com/vllm-project/vllm/pull/21276
    args = parser.parse_args()

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]
    # Trim or tile the sample prompts so exactly batch_size prompts are used.
    if args.batch_size < len(prompts):
        prompts = prompts[:args.batch_size]
    else:
        prompts = prompts * (args.batch_size // len(prompts)
                             ) + prompts[:args.batch_size % len(prompts)]

    sampling_params = SamplingParams(temperature=0,
                                     max_tokens=args.osl,
                                     ignore_eos=True)

    # task_queue maps task key -> {'proc': Process, 'result': dict (added
    # after the worker reports back)}.
    task_queue = {}
    result_queue = multiprocessing.Queue()
    task = args.task
    if task == "ngram":
        if args.run_base:
            task_queue['baseline_ngram'] = {
                'proc':
                multiprocessing.Process(target=test_ngram,
                                        args=(False, args, prompts,
                                              sampling_params,
                                              'baseline_ngram', result_queue))
            }
        task_queue['spec_ngram'] = {
            'proc':
            multiprocessing.Process(target=test_ngram,
                                    args=(True, args, prompts, sampling_params,
                                          'spec_ngram', result_queue))
        }
    elif task == "deepseek_mtp":
        if args.run_base:
            task_queue['baseline_mtp'] = {
                'proc':
                multiprocessing.Process(target=test_mtp_model,
                                        args=(False, args, prompts,
                                              sampling_params, 'baseline_mtp',
                                              result_queue))
            }
        task_queue['spec_mtp'] = {
            'proc':
            multiprocessing.Process(target=test_mtp_model,
                                    args=(True, args, prompts, sampling_params,
                                          'spec_mtp', result_queue))
        }
    elif task == "eagle":
        if args.run_base:
            task_queue['baseline_eagle'] = {
                'proc':
                multiprocessing.Process(target=test_eagle_model,
                                        args=(False, args, prompts,
                                              sampling_params,
                                              'baseline_eagle', result_queue))
            }
        task_queue['spec_eagle'] = {
            'proc':
            multiprocessing.Process(target=test_eagle_model,
                                    args=(True, args, prompts, sampling_params,
                                          'spec_eagle', result_queue))
        }
    elif task == "medusa":
        if args.run_base:
            # Bug fix: the dict key must match the task_key passed to the
            # worker ('baseline_medusa'); it was 'baseline_eagle', which made
            # result collection raise KeyError for this task.
            task_queue['baseline_medusa'] = {
                'proc':
                multiprocessing.Process(target=test_medusa_model,
                                        args=(False, args, prompts,
                                              sampling_params,
                                              'baseline_medusa', result_queue))
            }
        task_queue['spec_medusa'] = {
            'proc':
            multiprocessing.Process(target=test_medusa_model,
                                    args=(True, args, prompts, sampling_params,
                                          'spec_medusa', result_queue))
        }

    try:
        # Run the tasks sequentially so the workers never share the device.
        for key, info in task_queue.items():
            logging.info(
                "=============== Starting task: %s ====================", key)
            info['proc'].start()
            info['proc'].join()
            logging.info(
                "=============== Task %s completed. ====================", key)
        for _ in range(len(task_queue)):
            # A worker that crashed never put its result; bound the wait so
            # the parent cannot block forever on an empty queue.
            try:
                key, result_data = result_queue.get(timeout=60)
            except queue.Empty:
                logging.error("Timed out waiting for worker results; "
                              "some tasks produced no result.")
                break
            task_queue[key]['result'] = result_data
    except KeyboardInterrupt:
        logging.info("Interrupted by user, terminating processes...")
    finally:
        for key, info in task_queue.items():
            print(f"================= {key} =================")
            # Guard: an interrupted or crashed task has no 'result' entry;
            # previously this raised KeyError and hid the other reports.
            result = info.get('result')
            if result is None:
                print("No result collected for this task.")
            else:
                print(f"latency: {result['latency']}")
                print(f"acc_counts: {result['acc_counts']}")
                print(f"acc_rate: {result['acc_rate']}")
                print(f"num_draft_tokens: {result['num_draft_tokens']}")
                print(f"num_drafts: {result['num_drafts']}")
                for prompt, text in zip(prompts, result['ret_spec']):
                    print("---")
                    print(f"Prompt: {prompt}")
                    print(f"Generated text: {text[:200]}'...'")
            print("=========================================")
            if info['proc'].is_alive():
                info['proc'].terminate()
                info['proc'].join(timeout=2)
    logging.info("Benchmark finished.")

vllm_gaudi/ops/hpu_rotary_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def forward_oot(
6060
if hasattr(self, "scaling_factors") or hasattr(
6161
self, "scaling_factor") or self.sin is None:
6262
self.prepare_cos_sin(positions, offsets)
63-
num_tokens = positions.shape[0] * positions.shape[1]
63+
num_tokens = positions.numel()
6464
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
6565
# to query hidden dimension, so the original tensors need to be
6666
# expanded

vllm_gaudi/platform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
3939
has_sink: bool) -> str:
4040
if use_v1 and not use_mla:
4141
logger.info("Using HPUAttentionV1 backend.")
42-
return "vllm_gaudi.attention.backends.hpu_attn.HPUAttentionBackend"
42+
return "vllm_gaudi.v1.attention.backends.hpu_attn.HPUAttentionBackendV1"
4343
if use_v1 and use_mla:
4444
logger.info("Using HPUAttentionMLA backend.")
4545
return ("vllm_gaudi.attention.backends.hpu_attn."
@@ -108,6 +108,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
108108
# Activate custom ops for v1.
109109
compilation_config.custom_ops = ["all"]
110110
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
111+
compilation_config.cudagraph_capture_sizes = []
111112

112113
if compilation_config.level != CompilationLevel.NO_COMPILATION:
113114
logger.info("[HPU] Forcing CompilationLevel.NO_COMPILATION "

0 commit comments

Comments
 (0)