9 changes: 9 additions & 0 deletions vllm/v1/spec_decode/eagle.py
@@ -703,6 +703,15 @@ def propose_tree(
# Copy inputs to buffer for cudagraph.
num_tokens = attn_metadata.num_actual_tokens
input_ids = tree_input_ids.view(-1)

# Handle -1 sentinel values from padded speculation for MTP models
# which call embed_tokens() and can't handle invalid indices
if self.method == "mtp":
# Filter out -1 sentinel values that mark discarded/invalid
# tokens
vocab_size = self.model.model.embed_tokens.weight.size(0)
input_ids = torch.clamp(input_ids, min=0, max=vocab_size - 1)
Comment on lines +709 to +713
Contributor
critical

While this logic correctly handles sentinel values for MTP models within propose_tree, the fix is incomplete. The propose method (around line 210) also constructs input_ids from target_token_ids and next_token_ids, which can contain -1 from padded speculation or rejection sampling. This will lead to the same embedding lookup error that this PR aims to fix.

To fully resolve the issue, a similar clamping mechanism should be implemented in the propose method as well. You can add the following code block after self.input_ids[last_token_indices] = next_token_ids:

if self.method == "mtp":
    # Handle -1 sentinel values from padded speculation for MTP models
    # which call embed_tokens() and can't handle invalid indices.
    vocab_size = self.model.model.embed_tokens.weight.size(0)
    clamped_input_ids = torch.clamp(self.input_ids[:num_tokens], min=0, max=vocab_size - 1)
    self.input_ids[:num_tokens] = clamped_input_ids
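
For reference, a minimal standalone sketch (toy sizes, not vLLM code) of why the clamp is needed: PyTorch's embedding lookup rejects out-of-range ids, so a -1 sentinel raises an IndexError on CPU (or a device-side assert on CUDA), while clamping maps it onto a valid id whose embedding is later discarded.

import torch
import torch.nn as nn

# Toy sizes for illustration only; in vLLM the real table is embed_tokens.
vocab_size, hidden_size = 32, 8
embed_tokens = nn.Embedding(vocab_size, hidden_size)

input_ids = torch.tensor([5, -1, 7, -1])  # -1 marks discarded/invalid tokens
# embed_tokens(input_ids)  # IndexError: index out of range in self

# Clamp sentinels onto a valid (arbitrary) id; the embeddings produced at
# those positions are never consumed, so the substitute value is irrelevant.
safe_ids = torch.clamp(input_ids, min=0, max=vocab_size - 1)
print(embed_tokens(safe_ids).shape)  # torch.Size([4, 8])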


self.input_ids[:num_tokens] = input_ids
self.positions[:num_tokens] = tree_positions.view(-1)
self.hidden_states[:num_tokens] = tree_hidden_states.view(