
Commit c9ff559

add new e2e tests for aclgraph
Signed-off-by: lilinsiman <[email protected]>
1 parent e4acb2d commit c9ff559

File tree

3 files changed: +277 −0 lines


.github/workflows/_e2e_test.yaml

Lines changed: 2 additions & 0 deletions

@@ -89,6 +89,7 @@ jobs:
           # the test separately.

           pytest -sv tests/e2e/singlecard/test_aclgraph.py
+          pytest -sv tests/e2e/singlecard/test_aclgraph_mem.py
           pytest -sv tests/e2e/singlecard/test_ascend_scheduler.py
           pytest -sv tests/e2e/singlecard/test_camem.py
           pytest -sv tests/e2e/singlecard/test_chunked.py
@@ -178,6 +179,7 @@ jobs:
           # pytest -sv tests/e2e/multicard/test_external_launcher.py
           pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py
           pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
+          pytest -sv tests/e2e/multicard/test_aclgraph_replay_capture.py

           # To avoid oom, we need to run the test in a single process.
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
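For local debugging, the same suites can be invoked from Python with pytest's public API. This is only a sketch assuming the vllm-ascend repository root is the working directory and the required NPU devices are visible; it simply mirrors the pytest invocations added to the workflow above.

# Run the two new suites the same way the workflow step does, via pytest.main.
import pytest

pytest.main(["-sv", "tests/e2e/singlecard/test_aclgraph_mem.py"])
pytest.main(["-sv", "tests/e2e/multicard/test_aclgraph_replay_capture.py"])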
tests/e2e/multicard/test_aclgraph_replay_capture.py

Lines changed: 176 additions & 0 deletions

@@ -0,0 +1,176 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import contextlib
import gc
import math
import multiprocessing
import os
import sys
from time import sleep
from unittest.mock import patch

import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import (  # noqa E402
    destroy_distributed_environment, destroy_model_parallel)

MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [4])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
def test_aclgraph_capture_replay_dp2(
    model: str,
    max_tokens: int,
) -> None:
    # HCCL_OP_EXPANSION_MODE determines how max_num_batch_sizes is computed.
    if 'VLLM_WORKER_MULTIPROC_METHOD' in os.environ:
        del os.environ["VLLM_WORKER_MULTIPROC_METHOD"]
    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
        del os.environ['HCCL_OP_EXPANSION_MODE']
    dp_size = 2
    tp_size = 1
    replay_counter = multiprocessing.Value("i", 0)
    capture_counter = multiprocessing.Value("i", 0)
    num_hidden_layers_shared = multiprocessing.Value("i", -1)
    num_execute_model_shared = multiprocessing.Value("i", 0)
    dp_master_ip = "127.0.0.1"
    dp_master_port = 11011

    def dp_rank_main(global_dp_rank: int, local_dp_rank: int):
        os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
        os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
        os.environ["VLLM_DP_SIZE"] = str(dp_size)
        os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
        os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

        original_replay = torch.npu.NPUGraph.replay

        def replay_wrapper(self):
            with replay_counter.get_lock():
                replay_counter.value += 1
            return original_replay(self)

        original_init = torch.npu.NPUGraph.__init__

        def init_wrapper(self, *args, **kwargs):
            with capture_counter.get_lock():
                capture_counter.value += 1
            return original_init(self, *args, **kwargs)

        with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper), \
                patch.object(torch.npu.NPUGraph, "__init__", init_wrapper):
            prompts = [
                "Hello, my name is", "The president of the United States is",
                "The capital of France is", "The future of AI is"
            ]
            chunk_size = len(prompts) // dp_size
            start = global_dp_rank * chunk_size
            end = start + chunk_size if global_dp_rank < dp_size - 1 else len(
                prompts)
            my_prompts = prompts[start:end]
            sampling_params = SamplingParams(max_tokens=max_tokens,
                                             temperature=0.0)

            def trace_calls(frame, event, arg):
                if event == 'call':
                    code = frame.f_code
                    func_name = code.co_name
                    file_name = code.co_filename
                    if func_name == 'execute_dummy_batch' and 'worker_v1.py' in file_name:
                        with num_execute_model_shared.get_lock():
                            num_execute_model_shared.value += 1
                return trace_calls

            sys.settrace(trace_calls)
            if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
                llm = LLM(
                    model=model,
                    quantization="ascend",
                    tensor_parallel_size=tp_size,
                    trust_remote_code=True,
                )
            else:
                llm = LLM(
                    model=model,
                    tensor_parallel_size=tp_size,
                    trust_remote_code=True,
                )
            num_hidden_layers_shared.value = llm.llm_engine.model_config.hf_config.num_hidden_layers
            _ = llm.generate(my_prompts, sampling_params)
            sys.settrace(None)

        # Give engines time to pause their processing loops before exiting.
        sleep(5)
        del llm
        cleanup_env_and_memory()

    processes = []
    for local_dp_rank in range(dp_size):
        global_dp_rank = local_dp_rank
        p = multiprocessing.Process(target=dp_rank_main,
                                    args=(global_dp_rank, local_dp_rank))
        p.start()
        processes.append(p)

    for p in processes:
        p.join(timeout=900)
        if p.exitcode != 0:
            if p.exitcode is None:
                p.kill()
                raise RuntimeError(f"Process {p.pid} timed out")
            else:
                raise RuntimeError(
                    f"Process failed with exit code {p.exitcode}")

    actual_capture = capture_counter.value
    actual_replay = replay_counter.value
    num_hidden_layers = num_hidden_layers_shared.value
    num_execute_model = num_execute_model_shared.value

    num_acl_graphs = num_hidden_layers + 1
    num_comm_groups = sum(size > 1 for size in [
        dp_size,
        tp_size,
    ])
    max_num_batch_sizes = math.floor(
        (1800 - num_comm_groups * 40) / num_acl_graphs /
        (1 + num_comm_groups * 2))
    expected_total_capture = max_num_batch_sizes * num_acl_graphs * dp_size
    assert actual_capture == expected_total_capture, (
        f"capture count mismatch. Expected: {expected_total_capture}, Got: {actual_capture}"
    )

    num_inference_steps = max_tokens + 1  # first token + max_tokens
    expected_total_replay = num_acl_graphs * num_inference_steps * dp_size + num_execute_model * num_acl_graphs
    assert actual_replay == expected_total_replay, (
        f"Replay count mismatch. Expected: {expected_total_replay}, Got: {actual_replay}"
    )
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn'


def cleanup_env_and_memory():
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
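For readers checking the assertions above, the capture/replay arithmetic can be evaluated standalone. The sketch below re-derives the expected counts under illustrative assumptions (a 28-layer model and zero execute_dummy_batch calls); the test itself reads num_hidden_layers from the loaded model config and counts dummy batches at runtime, so these concrete numbers are not hard-coded anywhere in it.

# Standalone sketch of the capture/replay bookkeeping asserted above.
# Illustrative assumptions (not taken from the test): 28 hidden layers,
# dp_size=2, tp_size=1, max_tokens=4, zero execute_dummy_batch calls.
import math

num_hidden_layers = 28      # assumed for illustration
dp_size, tp_size = 2, 1
max_tokens = 4
num_execute_model = 0       # dummy-batch executions observed at runtime

num_acl_graphs = num_hidden_layers + 1
num_comm_groups = sum(size > 1 for size in [dp_size, tp_size])

# Same sizing formula the test uses (the constants 1800 and 40 come from it).
max_num_batch_sizes = math.floor(
    (1800 - num_comm_groups * 40) / num_acl_graphs / (1 + num_comm_groups * 2))

expected_capture = max_num_batch_sizes * num_acl_graphs * dp_size
num_inference_steps = max_tokens + 1  # first token + max_tokens
expected_replay = (num_acl_graphs * num_inference_steps * dp_size
                   + num_execute_model * num_acl_graphs)

print(max_num_batch_sizes, expected_capture, expected_replay)
# With these assumptions: 20 batch sizes, 1160 captures, 290 replays.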
tests/e2e/singlecard/test_aclgraph_mem.py

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import multiprocessing
import os
from unittest.mock import patch

import pytest
import torch
from modelscope import snapshot_download  # type: ignore
from vllm import LLM, SamplingParams

from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"]


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                    reason="aclgraph only support on v1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [4])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
def test_aclgraph_mem_use(model: str, max_tokens: int) -> None:
    del os.environ["VLLM_WORKER_MULTIPROC_METHOD"]
    capture_called = multiprocessing.Value("i", 0)  # int, 0 or 1
    capture_mem_before = multiprocessing.Value("q", -1)  # long long (64-bit)
    capture_mem_after = multiprocessing.Value("q", -1)  # long long

    def capture_model_wrapper(original_method):

        def wrapped(self):
            mem_before = torch.npu.mem_get_info()[0]  # free memory
            result = original_method(self)
            mem_after = torch.npu.mem_get_info()[0]
            with capture_called.get_lock():
                capture_called.value = 1
                capture_mem_before.value = mem_before
                capture_mem_after.value = mem_after
            return result

        return wrapped

    original_capture = NPUModelRunner._capture_model

    with patch.object(NPUModelRunner,
                      '_capture_model',
                      new=capture_model_wrapper(original_capture)):
        prompts = [
            "Hello, my name is", "The president of the United States is",
            "The capital of France is", "The future of AI is"
        ]
        sampling_params = SamplingParams(max_tokens=max_tokens,
                                         temperature=0.0)
        if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
            vllm_model = LLM(snapshot_download(model),
                             max_model_len=1024,
                             quantization="ascend")
        else:
            vllm_model = LLM(snapshot_download(model))
        _ = vllm_model.generate(prompts, sampling_params)

    assert capture_called.value == 1, "_capture_model was not called during test"
    assert capture_mem_before.value != -1, "capture_mem_before not set"
    assert capture_mem_after.value != -1, "capture_mem_after not set"

    print("capture_mem_before =", capture_mem_before.value)
    print("capture_mem_after =", capture_mem_after.value)

    mem_used_by_capture = capture_mem_before.value - capture_mem_after.value
    # Empirical observation: capturing ACL graphs for Qwen3-0.6B uses ~0.01 GiB
    # of NPU memory and DeepSeek-V2-Lite-W8A8 uses ~0.57 GiB. The Qwen baseline
    # below is set more loosely at 0.20 GiB, and a 1.3x tolerance is applied on
    # top of each baseline to account for runtime variance.
    if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
        baseline_capture_mem = 0.57
        capture_mem_tolerance = 1.3
    else:
        baseline_capture_mem = 0.20
        capture_mem_tolerance = 1.3
    max_capture_mem_gib = baseline_capture_mem * capture_mem_tolerance
    max_mem_expected = max_capture_mem_gib * (1024**3)
    assert mem_used_by_capture < max_mem_expected, (
        f"_capture_model used more memory than expected. "
        f"Used: {mem_used_by_capture / (1024**3):.2f} GiB, "
        f"Expected: < {max_capture_mem_gib:.2f} GiB")
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn'
