46 changes: 45 additions & 1 deletion tests/e2e/multicard/test_full_graph_mode.py
@@ -29,7 +29,7 @@
from tests.e2e.model_utils import check_outputs_equal


def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [
@@ -70,3 +70,47 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
name_0="vllm_eager_outputs",
name_1="vllm_fullgraph_outputs",
)


def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_AND_PIECEWISE():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
model = "Qwen/Qwen3-30B-A3B"
sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
with VllmRunner(
model,
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=False,
compilation_config={"cudagraph_mode":
"FULL_AND_PIECEIWSE"}) as runner:
vllm_fullgraph_outputs = runner.model.generate(prompts,
sampling_params)

with VllmRunner(
model,
max_model_len=1024,
enforce_eager=True,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)

vllm_fullgraph_outputs_list = []
for output in vllm_fullgraph_outputs:
vllm_fullgraph_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))

vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))

check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_fullgraph_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_fullgraph_outputs",
)
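
For context outside the test harness, the new test corresponds to the following end-user invocation. This is a minimal sketch built from the test parameters above (model name, sequence length, TP size, sampling settings); it assumes vLLM's `LLM` accepts `compilation_config` as a dict, as the `VllmRunner` wrapper above does.

```python
from vllm import LLM, SamplingParams

# Minimal sketch mirroring the test above: enable the combined
# FULL_AND_PIECEWISE ACL graph mode on a 2-NPU Ascend setup. Model name,
# sizes and flags are taken from the test, not from separate documentation.
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    max_model_len=1024,
    tensor_parallel_size=2,
    enforce_eager=False,
    compilation_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
)

outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    SamplingParams(max_tokens=32, temperature=0.0),
)
for out in outputs:
    print(out.outputs[0].text)
```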
2 changes: 1 addition & 1 deletion vllm_ascend/compilation/acl_graph.py
@@ -332,7 +332,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
def update_graph_params_workspaces(num_tokens: int, workspace: int):
global _graph_params
if _graph_params is not None:
_graph_params.workspaces[num_tokens] = workspace
_graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)


def get_graph_params():
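
The one-line change above stores a weak reference to the workspace tensor instead of the tensor itself, so the global graph-parameter cache does not keep workspace memory alive on its own. The sketch below illustrates that lifetime difference with plain `weakref`; it is a conceptual stand-in, not the `weak_ref_tensors` helper used in the diff.

```python
import weakref

import torch

# Conceptual sketch only: why the workspace cache should hold a weak
# reference rather than the tensor itself.
cache: dict[int, weakref.ref] = {}


def register_workspace(num_tokens: int, workspace: torch.Tensor) -> None:
    # A strong reference here would pin the workspace buffer for as long
    # as the global cache lives; a weak reference leaves ownership with
    # the graph runner that allocated it.
    cache[num_tokens] = weakref.ref(workspace)


ws = torch.empty(1024)
register_workspace(1024, ws)
assert cache[1024]() is ws       # alive while the owner still holds it
del ws
assert cache[1024]() is None     # the cache did not extend its lifetime
```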
12 changes: 5 additions & 7 deletions vllm_ascend/platform.py
@@ -223,16 +223,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
vllm_config.compilation_config.init_with_cudagraph_sizes(
sp_aclgraph_sizes)

# TODO: Full graph is fully supported later, and the default value will be set to full graph.
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or \
compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
f"{compilation_config.cudagraph_mode} compilation enabled on NPU. use_inductor not supported - "
f"using only ACL Graph mode"
f"{compilation_config.cudagraph_mode}")
assert compilation_config.level == CompilationLevel.PIECEWISE, \
"When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode is CUDAGraphMode.PIECEWISE or CUDAGraphMode.FULL_AND_PIECEWISE"
compilation_config.set_splitting_ops_for_v1()
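
With this change, FULL_AND_PIECEWISE is no longer silently downgraded to PIECEWISE; it now takes the same path as PIECEWISE. A simplified sketch of the resulting mode-to-level handling follows, with assumed import paths and without the splitting-op setup and logging the real method performs.

```python
from vllm.config import CompilationLevel, CUDAGraphMode


def resolve_compilation_level(mode: CUDAGraphMode,
                              level: CompilationLevel) -> CompilationLevel:
    # Mirrors only the control flow visible in the diff above.
    if mode == CUDAGraphMode.NONE:
        return CompilationLevel.NO_COMPILATION
    if mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL_AND_PIECEWISE):
        # FULL_AND_PIECEWISE is no longer downgraded to PIECEWISE; both
        # modes require piecewise compilation to already be selected.
        assert level == CompilationLevel.PIECEWISE
        return level
    return level
```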
4 changes: 3 additions & 1 deletion vllm_ascend/torchair/torchair_model_runner.py
@@ -21,6 +21,7 @@
import types
from typing import Any, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -153,6 +154,7 @@ def _build_dummy_attn_metadata(
num_reqs: int,
num_tokens: int,
max_query_len: int,
num_scheduled_tokens: np.ndarray[Any, Any],
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
@@ -161,7 +163,7 @@
if with_prefill or self.enable_shared_expert_dp:
attn_metadata = super()._build_dummy_attn_metadata(
with_prefill, num_reqs, num_tokens, max_query_len,
aclgraph_runtime_mode, force_attention)
num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
else:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,
11 changes: 10 additions & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -377,7 +377,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
device=self.device)

if self.vllm_config.model_config.use_mla and \
self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
(self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or \
self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE):
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos = torch.ones(self.max_num_reqs *
self.decode_token_per_req,
@@ -2260,6 +2261,7 @@ def _build_dummy_attn_metadata(
num_reqs: int,
num_tokens: int,
max_query_len: int,
num_scheduled_tokens: np.ndarray[Any, Any],
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
@@ -2275,6 +2277,12 @@
self.seq_lens_np[:num_reqs] = seq_lens
self.seq_lens_np[num_reqs:] = 0

if num_scheduled_tokens is not None:
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
self.query_start_loc_cpu[1:num_reqs +
1] = torch.from_numpy(cu_num_tokens)

num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])

Expand Down Expand Up @@ -2474,6 +2482,7 @@ def _dummy_run(
max_query_len=max_query_len,
aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention,
num_scheduled_tokens=num_scheduled_tokens,
)

need_dummy_logits = (not self.in_profile_run
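
The `num_scheduled_tokens` array threaded through `_build_dummy_attn_metadata` is used to fill `query_start_loc_cpu` with realistic per-request offsets. Below is a small sketch of that bookkeeping, assuming `_get_cumsum_and_arange` returns the cumulative per-request token counts (its arange output is omitted); the concrete numbers are made up for illustration.

```python
import numpy as np
import torch

num_scheduled_tokens = np.array([3, 1, 4], dtype=np.int64)  # tokens per request
num_reqs = len(num_scheduled_tokens)

# What _get_cumsum_and_arange is assumed to provide: cumulative token counts.
cu_num_tokens = np.cumsum(num_scheduled_tokens)  # [3, 4, 8]

query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.int64)
query_start_loc[1:num_reqs + 1] = torch.from_numpy(cu_num_tokens)

# query_start_loc is now [0, 3, 4, 8]: request i owns the token span
# query_start_loc[i] : query_start_loc[i + 1] in the flattened batch.
print(query_start_loc)
```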