[Bugfix] Fix MTP bug with padded speculation #26198
Conversation
Code Review
This pull request addresses a crash in MTP models when using padded speculation by clamping invalid token IDs. The change in `propose_tree` is a good step toward fixing the issue, but the fix is incomplete: the `propose` method is not patched and remains vulnerable to the same crash. I've also noticed another critical issue in `propose_tree`, where the model's return value is not handled correctly for MTP models, which will lead to a `TypeError`. While that second issue is outside the scope of the current diff, addressing both is crucial for a robust solution. I have added a comment on the diff detailing the incomplete fix.
```python
if self.method == "mtp":
    # Filter out -1 sentinel values that mark discarded/invalid
    # tokens
    vocab_size = self.model.model.embed_tokens.weight.size(0)
    input_ids = torch.clamp(input_ids, min=0, max=vocab_size - 1)
```
While this logic correctly handles sentinel values for MTP models within `propose_tree`, the fix is incomplete. The `propose` method (around line 210) also constructs `input_ids` from `target_token_ids` and `next_token_ids`, which can contain `-1` from padded speculation or rejection sampling. This will lead to the same embedding lookup error that this PR aims to fix.
To fully resolve the issue, a similar clamping mechanism should be implemented in the `propose` method as well. You can add the following code block after `self.input_ids[last_token_indices] = next_token_ids`:
```python
if self.method == "mtp":
    # Handle -1 sentinel values from padded speculation for MTP models,
    # which call embed_tokens() and can't handle invalid indices.
    vocab_size = self.model.model.embed_tokens.weight.size(0)
    clamped_input_ids = torch.clamp(
        self.input_ids[:num_tokens], min=0, max=vocab_size - 1
    )
    self.input_ids[:num_tokens] = clamped_input_ids
```
Same question as the bot: why only add the fix in `propose_tree`?
Hello~ Could you also check out my fix #26231 to determine if we encountered the same issue?
@seven-mile This does appear to be the same issue, and your fix addresses my issue as well. I believe your fix is better, so I'll close my PR. Thanks! @benchislett Could you review #26231?
Purpose
MTP models crash with embedding lookup errors when using padded speculation (PR #24539) in large batches with variable-length inputs.
Root Cause: Eagle's padded speculation uses `-1` as a sentinel value to mark discarded/invalid tokens in batched requests. MTP models call `embed_tokens()` directly and can't handle these invalid indices; other speculators don't have this issue because they don't perform embedding lookups.
Solution: Filter out `-1` sentinel values in the Eagle proposer before they reach MTP models by clamping `input_ids` to the valid vocabulary range `[0, vocab_size - 1]`. This preserves the efficiency benefits of padded speculation while ensuring MTP models receive valid token indices.
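To illustrate why the clamp is needed, here is a minimal, self-contained PyTorch sketch (not the vLLM code; the vocabulary size, embedding dimension, and tensor values are made up for the example). An `nn.Embedding` lookup fails on out-of-range indices such as `-1`, while clamping maps the sentinel slots to a valid token ID whose output can be discarded downstream:

```python
import torch

vocab_size = 8
embed_tokens = torch.nn.Embedding(vocab_size, 4)

# -1 marks padded/discarded slots, as in padded speculation
input_ids = torch.tensor([3, 7, -1, 2, -1])

# embed_tokens(input_ids)  # would error out on the -1 entries

# Clamp sentinels into the valid range [0, vocab_size - 1]
clamped = torch.clamp(input_ids, min=0, max=vocab_size - 1)
hidden = embed_tokens(clamped)  # safe: every index is now a valid token ID
```

The clamped positions produce embeddings for token 0, which is harmless as long as downstream logic ignores those slots; that is the premise of the fix.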
Test Plan
Basic functionality
Correctness
Test Result
Basic functionality
Doesn't crash
Correctness