
Conversation

@JC-ut0
Contributor

@JC-ut0 JC-ut0 commented Oct 18, 2025

What this PR does / why we need it?

  1. Refactor mtp_proposer.py, splitting the torchair-related code out into mtp_torchair_proposer.py.
  2. Following [Spec Decode] Efficient padded speculation (vllm#24539), implement padded speculative decoding as described in [Performance]: Padded Speculative Decoding (vllm#21984).

Does this PR introduce any user-facing change?

Users can set disable_padded_drafter_batch to enable or disable padded speculation.
Offline example:

speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False}
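
For reference, a minimal offline sketch of passing this config through the vLLM LLM API; the model name, parallel size, prompt, and sampling settings below are illustrative and not taken from this PR:

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-R1",   # illustrative model
    tensor_parallel_size=16,           # mirrors the tp16 setup benchmarked below
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
        # False keeps padded speculation enabled; True falls back to the unpadded path.
        "disable_padded_drafter_batch": False,
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)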

How was this patch tested?

  • eager with pad/unpad
  • aclgraph with pad/unpad
  • torchair with pad/unpad

Performance test of deepseek-r1 with tp16, dp1:
aclgraph with pad ITL: 168ms
aclgraph with unpad ITL: 169ms
original: 178ms

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description, to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant refactoring to support a "padded spec" mechanism for speculative decoding. The changes are extensive, especially in mtp_proposer.py, which has been aligned with a new, more general speculative decoding framework. A new MtpTorchairProposer class has been introduced to encapsulate torchair-specific logic.

My review has identified a critical issue in MtpTorchairProposer where its method signatures are out of sync with its base class, which will cause runtime errors. I've also found a couple of high-severity issues in mtp_proposer.py, including a misleading log message and dead code left over from the refactoring. These issues should be addressed to ensure correctness and maintainability.
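
To illustrate the kind of failure meant here, a toy sketch with hypothetical names (not the PR's actual classes): when a subclass override drifts out of sync with the base-class signature, every caller written against the base interface breaks at call time.

class Proposer:
    def propose(self, token_ids, positions, attn_metadata):
        ...

class TorchairProposer(Proposer):
    # Override drifted out of sync with the base signature: a parameter is missing.
    def propose(self, token_ids, positions):
        ...

# Callers written against the base interface now fail at runtime:
#   TypeError: propose() takes 3 positional arguments but 4 were given
TorchairProposer().propose([1], [0], None)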

Comment on lines 508 to 512
logger.warning(
f"Currently the eagle proposer only supports cudagraph_mode "
"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
"to CUDAGraphMode.PIECEWISE"
)

high

The warning message here seems to be a copy-paste error from another file. It mentions "the eagle proposer" while being in mtp_proposer.py. This can be misleading for developers who are debugging performance issues or graph compilation behavior. It should be updated to refer to the MTP proposer.

Suggested change
logger.warning(
f"Currently the eagle proposer only supports cudagraph_mode "
"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
"to CUDAGraphMode.PIECEWISE"
)
logger.warning(
"Currently the MTP proposer only supports cudagraph_mode "
f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
"to CUDAGraphMode.PIECEWISE"
)

@github-actions

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@JC-ut0 changed the title from "[Draft] Padded spec" to "[Draft] Refactor spec decode to support efficient padded speculation" on Oct 20, 2025
@JC-ut0 changed the title from "[Draft] Refactor spec decode to support efficient padded speculation" to "[FEAT] Refactor spec decode to support efficient padded speculation" on Oct 20, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces padded speculative decoding and refactors the MtpProposer by separating the torchair-related logic. The changes are substantial and improve the architecture for speculative decoding. My review focuses on performance considerations in the new implementation and a bug in the tests. I've identified several areas where CPU-bound operations and CPU-GPU data transfers in performance-critical paths could be optimized by moving computations to the GPU using torch operations. Additionally, there's a critical issue in the test suite where a function is called with an incorrect argument type, which should be addressed.

Comment on lines 267 to +382
def _prepare_inputs(
self,
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
token_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
slot_mapping: torch.Tensor,
is_torchair_graph: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens
if is_torchair_graph:
cu_num_tokens = cu_target_query_lens
relative_index = query_len_per_req - num_rejected_tokens - 1
token_indices = cu_num_tokens[:-1] + relative_index
# the seq len of each batch is padded to 1 + num_speculative_tokens, thus the input is the same as the main model's
target_token_ids = token_ids
target_positions = positions
target_hidden_states = hidden_states
target_slot_mapping = slot_mapping
else:
cu_num_tokens = torch.empty_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0

# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.zeros(
num_tokens,
dtype=torch.int32,
device=cu_num_tokens.device,
)

BLOCK_SIZE = 1024
self._prepare_input_kernel(
token_indices,
cu_target_query_lens,
cu_num_tokens,
block_size=BLOCK_SIZE,
)
target_token_ids = token_ids[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = slot_mapping[token_indices]
return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping

def _prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
# E.g.
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1, q1 + q2, q1 + q2 + q3]
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
# num_rejected_tokens: [n1, n2, n3]
# This function computes the intermediate values:
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
# And returns:
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# common_attn_metadata.seq_lens{_cpu}:
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# token_indices: [0, 1, ..., q1 - n1 - 1,
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(num_rejected_tokens,
dtype=torch.int32)

device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = query_start_loc_cpu[
1:] - query_start_loc_cpu[:-1]
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

# [q1 - n1, q2 - n2, q3 - n3] ->
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
new_query_start_loc_cpu = torch.zeros(
query_start_loc_cpu.shape,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
)
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

total_num_tokens = new_query_start_loc_np[-1]
# Example assuming num_tokens_per_req_np = [2, 4, 3]
# this implies that `new_query_start_locs` is:
# [0, 2, 6, 9] ->
# [0, 0, 2, 2, 2, 2, 6, 6, 6]
# _r1_ ____r2____ ___r3__
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
new_num_tokens_per_req_np)
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
# _r1_ ____r2____ ___r3__
token_offests = (self.token_arange_np[:total_num_tokens] -
new_query_start_locs_expanded)

# Expand starting positions to match token pattern
# [0, q1, q1 + q2] ->
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
# _r1_ _____r2_______ ___________r3____________
old_query_start_locs_expanded = np.repeat(
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
# Final token indices are:
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
token_indices_np = token_offests + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)

spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=new_query_start_loc_cpu.to(device,
non_blocking=True),
query_start_loc_cpu=new_query_start_loc_cpu,
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
seq_lens_cpu=new_seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
positions=common_attn_metadata.positions[token_indices],
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
graph_pad_size=self.runner.graph_pad_size,
decode_token_per_req=self.runner.decode_token_per_req,
)
return spec_common_attn_metadata, token_indices

high

The _prepare_inputs function performs several operations on the CPU using numpy, followed by data transfers to the GPU. This can introduce significant performance overhead due to CPU-GPU synchronization. To improve performance, these operations should be performed directly on the GPU using torch equivalents. For example, np.cumsum can be replaced with torch.cumsum, and np.repeat with torch.repeat_interleave. This will avoid unnecessary data movement between CPU and GPU.
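
For illustration, a hedged sketch of how that index math could stay on device with torch ops (a hypothetical helper, not the PR's code; one host sync still remains to size the arange output):

import torch

def gpu_token_indices(query_start_loc: torch.Tensor,       # [num_reqs + 1], on device
                      num_rejected_tokens: torch.Tensor):  # [num_reqs], on device
    query_len_per_req = query_start_loc[1:] - query_start_loc[:-1]
    new_num_tokens_per_req = (query_len_per_req - num_rejected_tokens).to(torch.long)

    # [q1 - n1, ...] -> [0, q1 - n1, q1 + q2 - n1 - n2, ...]
    new_query_start_loc = torch.zeros_like(query_start_loc)
    new_query_start_loc[1:] = torch.cumsum(new_num_tokens_per_req, dim=0)

    # One device->host sync remains to size the arange; the gather itself stays on device.
    total_num_tokens = int(new_query_start_loc[-1].item())

    new_starts = torch.repeat_interleave(new_query_start_loc[:-1], new_num_tokens_per_req)
    old_starts = torch.repeat_interleave(query_start_loc[:-1], new_num_tokens_per_req)
    token_offsets = torch.arange(total_num_tokens, device=query_start_loc.device) - new_starts
    token_indices = old_starts + token_offsets
    return new_query_start_loc, token_indices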

Comment on lines +659 to +674
self.backup_next_token_ids.np[:num_reqs] = np.array([
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item())
for i in range(num_reqs)
])
self.backup_next_token_ids.copy_to_gpu(num_reqs)

high

The prepare_next_token_ids_padded function's docstring states that it "must use device functions to operate on the inputs, and should not introduce any blocking CPU-GPU synchronization." However, the implementation includes a CPU-side loop to precompute backup_next_token_ids and then copies the data to the GPU. This introduces the synchronization that the function aims to avoid and can be a performance bottleneck. This logic should be moved to the GPU to maintain performance.
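
As a hedged illustration of what a device-side version could look like, assuming the per-request token table is already resident on device as a [num_reqs, max_model_len] tensor (token_ids_gpu and seq_lens_gpu are assumed names, not the PR's):

import torch

def backup_next_token_ids_gpu(token_ids_gpu: torch.Tensor,  # [num_reqs, max_model_len]
                              seq_lens_gpu: torch.Tensor    # [num_reqs]
                              ) -> torch.Tensor:
    # One gather replaces the Python loop + host-to-device copy: for each request,
    # pick the token at position seq_len (the backup "next" token already known on device).
    num_reqs = seq_lens_gpu.shape[0]
    row_idx = torch.arange(num_reqs, device=token_ids_gpu.device)
    return token_ids_gpu[row_idx, seq_lens_gpu.to(torch.long)]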

Comment on lines 1185 to 1347
if self.speculative_config:
num_tokens = [
self.requests[r].num_tokens for r in self.input_batch.req_ids
]
num_tokens_np = np.array(num_tokens, dtype=np.int32)

# Record the index of requests that should not be sampled,
# so that we could clear the sampled tokens before returning
num_reqs = self.input_batch.num_reqs
discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
discard_request_indices = np.nonzero(discard_requests_mask)[0]
self.num_discarded_requests = len(discard_request_indices)
self.discard_request_indices.np[:self.num_discarded_requests] = (
discard_request_indices)

self.discard_request_indices.copy_to_gpu(
self.num_discarded_requests)

high

The logic to compute discard_request_indices is performed on the CPU using a Python list comprehension and numpy operations, followed by a copy to the GPU. This happens in _prepare_input_ids, which is on a performance-critical path. To avoid the overhead of CPU-GPU data transfer, this computation should be performed directly on the GPU using torch operations.
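
A hedged sketch of a device-side equivalent, assuming the sequence lengths and per-request token counts are kept as device tensors (seq_lens_gpu and num_tokens_gpu are assumed names):

import torch

def discard_indices_gpu(seq_lens_gpu: torch.Tensor,    # [num_reqs]
                        num_tokens_gpu: torch.Tensor   # [num_reqs]
                        ) -> torch.Tensor:
    # Requests whose scheduled length is still below their total token count
    # should not be sampled; keep the mask and indices on device.
    discard_mask = seq_lens_gpu < num_tokens_gpu
    # Note: nonzero() still syncs to size its output; an alternative is to keep
    # the boolean mask and apply it directly when clearing sampled tokens.
    return torch.nonzero(discard_mask, as_tuple=False).squeeze(-1)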


Signed-off-by: xuyexiong <[email protected]>
Labels

module:tests · ready (read for review) · ready-for-test (start test by label for PR)
