[FEAT] Refactor spec decode to support efficient padded speculation #3528
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces significant refactoring to support a "padded spec" mechanism for speculative decoding. The changes are extensive, especially in mtp_proposer.py, which has been aligned with a new, more general speculative decoding framework. A new MtpTorchairProposer class has been introduced to encapsulate torchair-specific logic.
My review has identified a critical issue in MtpTorchairProposer where its method signatures are out of sync with its base class, which will cause runtime errors. I've also found a couple of high-severity issues in mtp_proposer.py, including a misleading log message and dead code left over from the refactoring. These issues should be addressed to ensure correctness and maintainability.
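The affected classes are not quoted in this thread, so the snippet below is only a generic, hypothetical illustration (the class and method names are made up, not the actual MtpTorchairProposer API) of why an override whose signature drifts from its base class surfaces as a runtime error rather than at definition time:

```python
class BaseProposer:
    def propose(self, token_ids, positions, hidden_states):
        raise NotImplementedError


class PaddedProposer(BaseProposer):
    # The override silently drops `hidden_states`, so Python accepts the
    # class definition, but callers written against BaseProposer break.
    def propose(self, token_ids, positions):
        return token_ids


def run(proposer: BaseProposer):
    # Fails only when executed:
    # TypeError: propose() takes 3 positional arguments but 4 were given
    return proposer.propose([1, 2], [0, 1], None)


run(PaddedProposer())
```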
```python
logger.warning(
    f"Currently the eagle proposer only supports cudagraph_mode "
    "PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
    "to CUDAGraphMode.PIECEWISE"
)
```
The warning message here seems to be a copy-paste error from another file. It mentions "the eagle proposer" while being in mtp_proposer.py. This can be misleading for developers who are debugging performance issues or graph compilation behavior. It should be updated to refer to the MTP proposer.
Suggested change:

```python
logger.warning(
    f"Currently the MTP proposer only supports cudagraph_mode "
    "PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
    "to CUDAGraphMode.PIECEWISE"
)
```
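Incidentally, in both the original and the suggested text the `{aclgraph_runtime_mode}` placeholder sits in a literal without an `f` prefix; with implicit string concatenation only the first literal is an f-string, so the placeholder is logged verbatim instead of being interpolated. A corrected variant, offered here only as a sketch of that additional fix and not as part of the bot's suggestion, would move the prefix:

```python
logger.warning(
    "Currently the MTP proposer only supports cudagraph_mode "
    f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
    "to CUDAGraphMode.PIECEWISE"
)
```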
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Code Review
This pull request introduces padded speculative decoding and refactors the MtpProposer by separating the torchair-related logic. The changes are substantial and improve the architecture for speculative decoding. My review focuses on performance considerations in the new implementation and a bug in the tests. I've identified several areas where CPU-bound operations and CPU-GPU data transfers in performance-critical paths could be optimized by moving computations to the GPU using torch operations. Additionally, there's a critical issue in the test suite where a function is called with an incorrect argument type, which should be addressed.
```python
def _prepare_inputs(
    self,
    # [batch_size + 1]
    cu_target_query_lens: torch.Tensor,
    # [batch_size]
    num_rejected_tokens: torch.Tensor,
    token_ids: torch.Tensor,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    slot_mapping: torch.Tensor,
    is_torchair_graph: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor, torch.Tensor]:
    # cu_target_query_lens: [0, a, a + b, a + b + c]
    # num_rejected_tokens: [n1, n2, n3]
    # num_tokens_per_req: [a - n1, b - n2, c - n3]
    # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    # token_indices: [0, 1, ..., a - n1 - 1,
    #                 a, a + 1, ..., a + b - n2 - 1,
    #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_len_per_req = (cu_target_query_lens[1:] -
                         cu_target_query_lens[:-1])
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens
    if is_torchair_graph:
        cu_num_tokens = cu_target_query_lens
        relative_index = query_len_per_req - num_rejected_tokens - 1
        token_indices = cu_num_tokens[:-1] + relative_index
        # The seq len of each batch is padded to 1 + num_speculative_tokens,
        # thus the input is the same as the main model.
        target_token_ids = token_ids
        target_positions = positions
        target_hidden_states = hidden_states
        target_slot_mapping = slot_mapping
    else:
        cu_num_tokens = torch.empty_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        cu_num_tokens[0] = 0

        # FIXME(woosuk): Avoid synchronization.
        num_tokens = cu_num_tokens[-1].item()
        token_indices = torch.zeros(
            num_tokens,
            dtype=torch.int32,
            device=cu_num_tokens.device,
        )

        BLOCK_SIZE = 1024
        self._prepare_input_kernel(
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            block_size=BLOCK_SIZE,
        )
        target_token_ids = token_ids[token_indices]
        target_positions = positions[token_indices]
        target_hidden_states = hidden_states[token_indices]
        target_slot_mapping = slot_mapping[token_indices]
    return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
```
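For the torchair branch above, the index arithmetic is compact enough that a tiny worked example helps; this is a self-contained illustration of the same formula, not code from the PR:

```python
import torch

# Toy batch: per-request padded query lengths a=3, b=3, c=3 and
# rejection counts n1=1, n2=0, n3=2.
cu_target_query_lens = torch.tensor([0, 3, 6, 9])
num_rejected_tokens = torch.tensor([1, 0, 2])

query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]  # [3, 3, 3]
# Index of the last accepted token within each request's padded window.
relative_index = query_len_per_req - num_rejected_tokens - 1              # [1, 2, 0]
token_indices = cu_target_query_lens[:-1] + relative_index                # one index per request
print(token_indices)  # tensor([1, 5, 6])
```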
```python
def _prepare_inputs(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    sampled_token_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
    """
    This function is used to prepare the inputs for speculative decoding.
    It updates the common_attn_metadata to account for the rejected
    tokens (and newly sampled tokens). It also returns the token indices
    of the tokens that should be fed to the speculator.
    """
    # E.g.
    # common_attn_metadata.query_start_loc{_cpu}:
    #     [0, q1, q1 + q2, q1 + q2 + q3]
    # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
    # num_rejected_tokens: [n1, n2, n3]
    # This function computes the intermediate values:
    # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
    # And returns:
    # common_attn_metadata.query_start_loc{_cpu}:
    #     [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    # common_attn_metadata.seq_lens{_cpu}:
    #     [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
    # token_indices: [0, 1, ..., q1 - n1 - 1,
    #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
    #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

    num_rejected_tokens = [
        n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
        for i, n in enumerate(num_draft_tokens)
    ]
    num_rejected_tokens = torch.tensor(num_rejected_tokens,
                                       dtype=torch.int32)

    device = common_attn_metadata.query_start_loc.device
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

    # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
    new_query_len_per_req = query_start_loc_cpu[
        1:] - query_start_loc_cpu[:-1]
    # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
    new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
    new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

    # [q1 - n1, q2 - n2, q3 - n3] ->
    # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    new_query_start_loc_cpu = torch.zeros(
        query_start_loc_cpu.shape,
        dtype=torch.int32,
        pin_memory=is_pin_memory_available(),
    )
    new_query_start_loc_np = new_query_start_loc_cpu.numpy()
    np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

    total_num_tokens = new_query_start_loc_np[-1]
    # Example assuming num_tokens_per_req_np = [2, 4, 3]
    # this implies that `new_query_start_locs` is:
    # [0, 2, 6, 9] ->
    # [0, 0, 2, 2, 2, 2, 6, 6, 6]
    #  _r1_ ____r2____ ___r3__
    new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                              new_num_tokens_per_req_np)
    # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
    # [0, 1, 0, 1, 2, 3, 0, 1, 2]
    #  _r1_ ____r2____ ___r3__
    token_offests = (self.token_arange_np[:total_num_tokens] -
                     new_query_start_locs_expanded)

    # Expand starting positions to match token pattern
    # [0, q1, q1 + q2] ->
    # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
    #  _r1_ _____r2_______ ___________r3____________
    old_query_start_locs_expanded = np.repeat(
        query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
    # Final token indices are:
    # [0, 1,                                  // req 1
    #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,        // req 2
    #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
    token_indices_np = token_offests + old_query_start_locs_expanded
    token_indices = torch.from_numpy(token_indices_np).to(
        device, non_blocking=True)

    spec_common_attn_metadata = AscendCommonAttentionMetadata(
        query_start_loc=new_query_start_loc_cpu.to(device,
                                                   non_blocking=True),
        query_start_loc_cpu=new_query_start_loc_cpu,
        seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
        seq_lens_cpu=new_seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        max_query_len=new_query_len_per_req.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping[token_indices],
        actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
        positions=common_attn_metadata.positions[token_indices],
        attn_mask=self.runner.attn_mask,
        spec_attn_mask=self.runner.spec_attn_mask,
        attn_state=self.runner.attn_state,
        graph_pad_size=self.runner.graph_pad_size,
        decode_token_per_req=self.runner.decode_token_per_req,
    )
    return spec_common_attn_metadata, token_indices
```
The _prepare_inputs function performs several operations on the CPU using numpy, followed by data transfers to the GPU. This can introduce significant performance overhead due to CPU-GPU synchronization. To improve performance, these operations should be performed directly on the GPU using torch equivalents. For example, np.cumsum can be replaced with torch.cumsum, and np.repeat with torch.repeat_interleave. This will avoid unnecessary data movement between CPU and GPU.
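A minimal sketch of what a device-side version of the index computation could look like, assuming `query_start_loc` and `num_rejected_tokens` are already int32 tensors on the NPU/GPU (the function and variable names below are illustrative, not the PR's API):

```python
import torch

def build_token_indices(query_start_loc: torch.Tensor,
                        num_rejected_tokens: torch.Tensor) -> torch.Tensor:
    """Compute gather indices for the accepted tokens entirely on device."""
    # [0, q1, q1+q2, ...] -> [q1, q2, ...]
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    new_num_tokens = query_lens - num_rejected_tokens

    # New cumulative starts: [0, q1-n1, q1+q2-n1-n2, ...]
    new_query_start_loc = torch.zeros_like(query_start_loc)
    torch.cumsum(new_num_tokens, dim=0, out=new_query_start_loc[1:])

    # Reading the total still syncs once; avoid it if a padded upper bound is known.
    total = int(new_query_start_loc[-1])

    # Per-token offset within each request: [0, 1, ..., k_i - 1] for request i.
    starts_expanded = torch.repeat_interleave(new_query_start_loc[:-1],
                                              new_num_tokens)
    offsets = torch.arange(total, device=query_start_loc.device,
                           dtype=torch.int32) - starts_expanded

    # Shift the offsets back into the original (unpruned) token layout.
    old_starts_expanded = torch.repeat_interleave(query_start_loc[:-1],
                                                  new_num_tokens)
    return offsets + old_starts_expanded
```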
```python
self.backup_next_token_ids.np[:num_reqs] = np.array([
    requests[gpu_input_batch.req_ids[i]].get_token_id(
        common_attn_metadata.seq_lens_cpu[i].item())
    for i in range(num_reqs)
])
self.backup_next_token_ids.copy_to_gpu(num_reqs)
```
The prepare_next_token_ids_padded function's docstring states that it "must use device functions to operate on the inputs, and should not introduce any blocking CPU-GPU synchronization." However, the implementation includes a CPU-side loop to precompute backup_next_token_ids and then copies the data to the GPU. This introduces the synchronization that the function aims to avoid and can be a performance bottleneck. This logic should be moved to the GPU to maintain performance.
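As a rough illustration of the direction this comment suggests, and assuming the batch already keeps an on-device [num_reqs, max_model_len] tensor of token ids (called `token_ids_gpu` below; that tensor and its name are assumptions, not part of the PR), the backup ids could be gathered without a Python loop:

```python
import torch

def gather_backup_next_token_ids(token_ids_gpu: torch.Tensor,
                                 seq_lens_gpu: torch.Tensor) -> torch.Tensor:
    """Pick token_ids_gpu[i, seq_lens_gpu[i]] for every request, on device."""
    # seq_lens_gpu: [num_reqs] int tensor; gather requires int64 indices.
    index = seq_lens_gpu.to(torch.int64).unsqueeze(1)        # [num_reqs, 1]
    return torch.gather(token_ids_gpu, dim=1, index=index).squeeze(1)
```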
```python
if self.speculative_config:
    num_tokens = [
        self.requests[r].num_tokens for r in self.input_batch.req_ids
    ]
    num_tokens_np = np.array(num_tokens, dtype=np.int32)

    # Record the index of requests that should not be sampled,
    # so that we could clear the sampled tokens before returning
    num_reqs = self.input_batch.num_reqs
    discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
    discard_request_indices = np.nonzero(discard_requests_mask)[0]
    self.num_discarded_requests = len(discard_request_indices)
    self.discard_request_indices.np[:self.num_discarded_requests] = (
        discard_request_indices)

    self.discard_request_indices.copy_to_gpu(
        self.num_discarded_requests)
```
The logic to compute discard_request_indices is performed on the CPU using a Python list comprehension and numpy operations, followed by a copy to the GPU. This happens in _prepare_input_ids, which is on a performance-critical path. To avoid the overhead of CPU-GPU data transfer, this computation should be performed directly on the GPU using torch operations.
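A hedged sketch of the torch-based alternative, assuming device-resident copies of the per-request sequence lengths and total token counts are available (`seq_lens_gpu` and `num_tokens_gpu` are illustrative names, not attributes from the PR):

```python
import torch

def compute_discard_request_indices(seq_lens_gpu: torch.Tensor,
                                    num_tokens_gpu: torch.Tensor) -> torch.Tensor:
    """Indices of requests whose scheduled length falls short of their full
    token count, computed without leaving the device."""
    discard_mask = seq_lens_gpu < num_tokens_gpu            # [num_reqs] bool
    # Note: taking len() of the result on the host would still sync;
    # keep the count on device if the caller allows it.
    return torch.nonzero(discard_mask, as_tuple=False).squeeze(1)
```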
Signed-off-by: xuyexiong <[email protected]>
What this PR does / why we need it?
Refactors speculative decoding to support efficient padded speculation: reworks mtp_proposer.py and splits the torchair-related code into mtp_torchair_proposer.py.
Does this PR introduce any user-facing change?
Users can set disable_padded_drafter_batch to disable/enable padded speculation.
Offline example:
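The example itself did not survive extraction; the following is only a sketch of how such a configuration might look with vLLM's offline LLM API, assuming the option is passed through speculative_config (the model path, method, and other field values are placeholders; only disable_padded_drafter_batch comes from this PR):

```python
from vllm import LLM, SamplingParams

# Hypothetical offline configuration: every value except
# disable_padded_drafter_batch is a placeholder for illustration.
llm = LLM(
    model="path/to/deepseek-r1",              # placeholder model path
    tensor_parallel_size=16,
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
        "disable_padded_drafter_batch": True,  # set False to keep padded speculation
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```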
How was this patch tested?
Performance test of deepseek-r1 with tp16, dp1:
aclgraph with pad ITL: 168 ms
aclgraph with unpad ITL: 169 ms
original: 178 ms