#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import contextlib
import gc
import math
import multiprocessing
import os
import sys
from time import sleep
from unittest.mock import patch

import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import (  # noqa: E402
    destroy_distributed_environment, destroy_model_parallel)

MODELS = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [4])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
def test_aclgraph_capture_replay_dp2(
    model: str,
    max_tokens: int,
) -> None:
    # VLLM_WORKER_MULTIPROC_METHOD controls how vLLM starts its worker
    # processes and HCCL_OP_EXPANSION_MODE determines how max_num_batch_sizes
    # is computed, so clear both to get the default behaviour for this test.
    if 'VLLM_WORKER_MULTIPROC_METHOD' in os.environ:
        del os.environ["VLLM_WORKER_MULTIPROC_METHOD"]
    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
        del os.environ['HCCL_OP_EXPANSION_MODE']
    dp_size = 2
    tp_size = 1
    replay_counter = multiprocessing.Value("i", 0)
    capture_counter = multiprocessing.Value("i", 0)
    num_hidden_layers_shared = multiprocessing.Value("i", -1)
    num_execute_model_shared = multiprocessing.Value("i", 0)
    dp_master_ip = "127.0.0.1"
    dp_master_port = 11011

    def dp_rank_main(global_dp_rank: int, local_dp_rank: int):
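        # Each process acts as one external data-parallel rank: export the DP
        # rank, world size and master address so both engines can rendezvous.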
        os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
        os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
        os.environ["VLLM_DP_SIZE"] = str(dp_size)
        os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
        os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

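        # Wrap NPUGraph so that every construction is counted as a graph
        # capture and every replay() call as a graph replay; the shared
        # multiprocessing.Value counters accumulate totals across both ranks.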
        original_replay = torch.npu.NPUGraph.replay

        def replay_wrapper(self):
            with replay_counter.get_lock():
                replay_counter.value += 1
            return original_replay(self)

        original_init = torch.npu.NPUGraph.__init__

        def init_wrapper(self, *args, **kwargs):
            with capture_counter.get_lock():
                capture_counter.value += 1
            return original_init(self, *args, **kwargs)

        with patch.object(torch.npu.NPUGraph, "replay", replay_wrapper), \
             patch.object(torch.npu.NPUGraph, "__init__", init_wrapper):
            prompts = [
                "Hello, my name is", "The president of the United States is",
                "The capital of France is", "The future of AI is"
            ]
            chunk_size = len(prompts) // dp_size
            start = global_dp_rank * chunk_size
            end = start + chunk_size if global_dp_rank < dp_size - 1 else len(
                prompts)
            my_prompts = prompts[start:end]
            sampling_params = SamplingParams(max_tokens=max_tokens,
                                             temperature=0.0)

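            # Count calls to execute_dummy_batch (worker_v1.py) via a trace
            # hook: idle DP ranks run dummy batches to stay in sync, and the
            # replay expectation below accounts for those extra graph replays.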
            def trace_calls(frame, event, arg):
                if event == 'call':
                    code = frame.f_code
                    func_name = code.co_name
                    file_name = code.co_filename
                    if func_name == 'execute_dummy_batch' and 'worker_v1.py' in file_name:
                        with num_execute_model_shared.get_lock():
                            num_execute_model_shared.value += 1
                return trace_calls

            sys.settrace(trace_calls)
            if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
                llm = LLM(
                    model=model,
                    quantization="ascend",
                    tensor_parallel_size=tp_size,
                    trust_remote_code=True,
                )
            else:
                llm = LLM(
                    model=model,
                    tensor_parallel_size=tp_size,
                    trust_remote_code=True,
                )
            num_hidden_layers_shared.value = llm.llm_engine.model_config.hf_config.num_hidden_layers
            _ = llm.generate(my_prompts, sampling_params)
            sys.settrace(None)

            # Give engines time to pause their processing loops before exiting.
            sleep(5)
            del llm
            cleanup_env_and_memory()

    processes = []
    for local_dp_rank in range(dp_size):
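        # Single-node test, so the global DP rank equals the local one.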
        global_dp_rank = local_dp_rank
        p = multiprocessing.Process(target=dp_rank_main,
                                    args=(global_dp_rank, local_dp_rank))
        p.start()
        processes.append(p)

    for p in processes:
        p.join(timeout=900)
        if p.exitcode != 0:
            if p.exitcode is None:
                p.kill()
                raise RuntimeError(f"Process {p.pid} timed out")
            else:
                raise RuntimeError(
                    f"Process failed with exit code {p.exitcode}")

    actual_capture = capture_counter.value
    actual_replay = replay_counter.value
    num_hidden_layers = num_hidden_layers_shared.value
    num_execute_model = num_execute_model_shared.value

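    # The test expects one ACL graph per hidden layer plus one extra graph,
    # each captured for max_num_batch_sizes batch sizes on every DP rank.
    # The formula below mirrors vllm-ascend's capture-budget sizing when
    # HCCL_OP_EXPANSION_MODE is unset.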
    num_acl_graphs = num_hidden_layers + 1
    num_comm_groups = sum(size > 1 for size in [
        dp_size,
        tp_size,
    ])
    max_num_batch_sizes = math.floor(
        (1800 - num_comm_groups * 40) / num_acl_graphs /
        (1 + num_comm_groups * 2))
    expected_total_capture = max_num_batch_sizes * num_acl_graphs * dp_size
    assert actual_capture == expected_total_capture, (
        f"Capture count mismatch. Expected: {expected_total_capture}, Got: {actual_capture}"
    )

    num_inference_steps = max_tokens + 1  # first token + max_tokens
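    # Each inference step replays every ACL graph once per DP rank, and each
    # dummy batch execution replays every graph once more.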
    expected_total_replay = num_acl_graphs * num_inference_steps * dp_size + num_execute_model * num_acl_graphs
    assert actual_replay == expected_total_replay, (
        f"Replay count mismatch. Expected: {expected_total_replay}, Got: {actual_replay}"
    )
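    # Reinstate the spawn start method for vLLM workers in subsequent tests.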
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = 'spawn'


def cleanup_env_and_memory():
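    """Tear down vLLM's distributed state and release NPU memory so each DP
    rank process can exit cleanly."""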
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()