From bd6e38d81950aae311212628af4881bf12856c0b Mon Sep 17 00:00:00 2001
From: wangxiaoxin-sherie
Date: Thu, 23 Oct 2025 14:32:43 +0800
Subject: [PATCH] support FULL_AND_PIECEWISE graph mode

Signed-off-by: wangxiaoxin-sherie
---
 tests/e2e/multicard/test_full_graph_mode.py   | 46 ++++++++++++++++++-
 vllm_ascend/compilation/acl_graph.py          |  2 +-
 vllm_ascend/platform.py                       | 12 ++---
 vllm_ascend/torchair/torchair_model_runner.py |  4 +-
 vllm_ascend/worker/model_runner_v1.py         | 11 ++++-
 5 files changed, 64 insertions(+), 11 deletions(-)

diff --git a/tests/e2e/multicard/test_full_graph_mode.py b/tests/e2e/multicard/test_full_graph_mode.py
index 3b9f293230..6222ead064 100644
--- a/tests/e2e/multicard/test_full_graph_mode.py
+++ b/tests/e2e/multicard/test_full_graph_mode.py
@@ -29,7 +29,7 @@ from tests.e2e.model_utils import check_outputs_equal
-def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
     if 'HCCL_OP_EXPANSION_MODE' in os.environ:
         del os.environ['HCCL_OP_EXPANSION_MODE']
     prompts = [
         "Hello, my name is", "The president of the United States is",
         "The capital of France is", "The future of AI is"
     ]
@@ -70,3 +70,47 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
         name_0="vllm_eager_outputs",
         name_1="vllm_fullgraph_outputs",
     )
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_AND_PIECEWISE():
+    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+    prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+    model = "Qwen/Qwen3-30B-A3B"
+    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            tensor_parallel_size=2,
+            enforce_eager=False,
+            compilation_config={"cudagraph_mode":
+                                "FULL_AND_PIECEWISE"}) as runner:
+        vllm_fullgraph_outputs = runner.model.generate(prompts,
+                                                       sampling_params)
+
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            enforce_eager=True,
+    ) as runner:
+        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
+
+    vllm_fullgraph_outputs_list = []
+    for output in vllm_fullgraph_outputs:
+        vllm_fullgraph_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    vllm_eager_outputs_list = []
+    for output in vllm_eager_outputs:
+        vllm_eager_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=vllm_fullgraph_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_fullgraph_outputs",
+    )
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 2ba6b253cd..326756806e 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -332,7 +332,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
 def update_graph_params_workspaces(num_tokens: int, workspace: int):
     global _graph_params
     if _graph_params is not None:
-        _graph_params.workspaces[num_tokens] = workspace
+        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
 
 
 def get_graph_params():
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index d8cf5251ea..3ca05cf7ae 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -223,16 +223,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             vllm_config.compilation_config.init_with_cudagraph_sizes(
                 sp_aclgraph_sizes)
 
-        # TODO: Full graph is fully supported later, and the default value will be set to full graph.
-        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
-            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-
         if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             compilation_config.level = CompilationLevel.NO_COMPILATION
-        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+        elif compilation_config.cudagraph_mode in (
+                CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL_AND_PIECEWISE):
             logger.info(
-                "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
-                "using only ACL Graph mode")
+                f"{compilation_config.cudagraph_mode} compilation enabled "
+                "on NPU. use_inductor not supported - "
+                "using only ACL Graph mode")
             assert compilation_config.level == CompilationLevel.PIECEWISE, \
                 "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
         compilation_config.set_splitting_ops_for_v1()
diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index 51c9508e5c..be32fdec48 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -21,6 +21,7 @@
 import types
 from typing import Any, Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -153,6 +154,7 @@ def _build_dummy_attn_metadata(
         num_reqs: int,
         num_tokens: int,
         max_query_len: int,
+        num_scheduled_tokens: np.ndarray[Any, Any],
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
     ) -> Optional[dict[str, Any]]:
@@ -161,7 +163,7 @@ def _build_dummy_attn_metadata(
         if with_prefill or self.enable_shared_expert_dp:
             attn_metadata = super()._build_dummy_attn_metadata(
                 with_prefill, num_reqs, num_tokens, max_query_len,
-                aclgraph_runtime_mode, force_attention)
+                num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
         else:
             common_attn_metadata = TorchairCommonAttentionMetadata(
                 num_reqs=num_reqs,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 0bfe0f847c..c2c887836a 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -377,7 +377,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                     device=self.device)
 
         if self.vllm_config.model_config.use_mla and \
-            self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            self.compilation_config.cudagraph_mode in (
+                CUDAGraphMode.FULL_DECODE_ONLY, CUDAGraphMode.FULL_AND_PIECEWISE):
             rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
             self.cos = torch.ones(self.max_num_reqs *
                                   self.decode_token_per_req,
@@ -2260,6 +2261,7 @@ def _build_dummy_attn_metadata(
         num_reqs: int,
         num_tokens: int,
         max_query_len: int,
+        num_scheduled_tokens: np.ndarray[Any, Any],
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
     ) -> Optional[dict[str, Any]]:
@@ -2275,6 +2277,12 @@ def _build_dummy_attn_metadata(
         self.seq_lens_np[:num_reqs] = seq_lens
         self.seq_lens_np[num_reqs:] = 0
 
+        if num_scheduled_tokens is not None:
+            cu_num_tokens, arange = self._get_cumsum_and_arange(
+                num_scheduled_tokens)
+            self.query_start_loc_cpu[1:num_reqs +
+                                     1] = torch.from_numpy(cu_num_tokens)
+
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
 
@@ -2474,6 +2482,7 @@ def _dummy_run(
             max_query_len=max_query_len,
             aclgraph_runtime_mode=aclgraph_runtime_mode,
             force_attention=force_attention,
+            num_scheduled_tokens=num_scheduled_tokens,
         )
 
         need_dummy_logits = (not self.in_profile_run
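
--
Usage sketch (illustrative, not part of the diff): with this patch applied on
top of vllm-ascend, FULL_AND_PIECEWISE can be requested through vLLM's public
API the same way the new e2e test does. Per vLLM's CUDAGraphMode semantics,
decode-only batches replay a full captured graph while batches containing
prefill fall back to piecewise graphs. The model name and sampling settings
below mirror the test above; the 2-NPU deployment is an assumption.

    # minimal_full_and_piecewise.py - assumes vllm plus this patched
    # vllm-ascend on a host with at least 2 NPUs.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen3-30B-A3B",
        max_model_len=1024,
        tensor_parallel_size=2,
        enforce_eager=False,
        # Mode wired up by this patch: full ACL graph for decode-only
        # batches, piecewise ACL graph whenever prefill tokens are present.
        compilation_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32, temperature=0.0))
    print(outputs[0].outputs[0].text)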