Skip to content

Conversation

@RunningLeon
Copy link
Collaborator

@RunningLeon RunningLeon commented Sep 9, 2025

Motivation

Support speculative decoding

Examples

pipeline

from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.messages import SpeculativeConfig


if __name__ == '__main__':

    model_path = 'meta-llama/Llama-3.1-8B-Instruct'
    spec_cfg = SpeculativeConfig(method='eagle3', 
                                    num_speculative_tokens=3,
                                    model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B',
                                    )
    pipe = pipeline(model_path, 
                    backend_config=PytorchEngineConfig(max_batch_size=128),
                    speculative_config=spec_cfg)
    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
    print(response)

serving

HF_HOME=/nvme1/shared/huggingface_hub \
TRANSFORMERS_OFFLINE=1 \
CUDA_VISIBLE_DEVICES=7 \
lmdeploy serve api_server \
meta-llama/Llama-3.1-8B-Instruct \
--backend pytorch \
--server-port 24545 \
--speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \
--speculative-algorithm eagle3 \
--speculative-num-draft-tokens 3 \
--max-batch-size 128 \
--enable-metrics

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@RunningLeon RunningLeon requested a review from grimoire September 9, 2025 06:26
@RunningLeon RunningLeon changed the title [WIP]: Support speculative decoding [Feature]: Support speculative decoding Oct 23, 2025
@RunningLeon RunningLeon marked this pull request as ready for review October 23, 2025 09:04
self.speculative_config = speculative_config

if speculative_config is not None:
engine_config.prefill_interval = 16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change was used for debugging and can be removed

device_type=engine_config.device_type,
distributed_executor_backend=engine_config.distributed_executor_backend,
dtype=engine_config.dtype,
speculative_config=speculative_config,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Speculative_config is define outside pytorch engine. It is not a good design to use it in any module besides Engine.

# input ids
token_ids = [msg.token_ids for msg in messages]

# spec decode
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the comment.


def _debug_spec_stats(self, batched_outputs: BatchedOutputs, is_decoding: bool = False):
"""Debugging spec stats."""
is_debugging = True
Copy link
Collaborator

@grimoire grimoire Oct 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this after release the feature?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just for debugging.

logger.warning(f'Overriding HF config with {hf_overrides}')
override_hf_config(model_config.hf_config, hf_overrides)

# for serialization of transformers modules
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might not work

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works with tp case on one node, but not teste on dp case on multiple nodes

inputs: ModelInputs,
cache_engine: CacheEngine,
stream: torch.cuda.Stream = None,
output_position_ids: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

output position_ids is cheap, we can always output it.

f'batch_size={inputs.seq_length.size(0)} '
f'num_tokens={inputs.input_ids.size(-1)} '
f'is_decoding={inputs.is_decoding}')
logger.info(f'<ForwardTask> rank[{rank}]: '
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

info this would be too verbose

def __init__(self,
model_path: str,
engine_config: PytorchEngineConfig = None,
speculative_config: SpeculativeConfig = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does RayEngine support speculative_config?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes.

return k_states, v_states


@triton.testing.perf_report(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove bench

input_buffers['position_ids'] = torch.zeros((1, max_tokens), dtype=torch.int64, device=device)
if getattr(self.config, 'use_flash_mla', False) is True:
import flash_mla
seqlens_dtype = torch.int64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when would we need int64?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the default is int64. while mla, fa3 needs int32.

"""Get max tokens."""
num_tokens = input_ids.size(1)
orig_batch = q_seqlens.size(0)
if num_tokens == orig_batch:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think sending tensor here is a good idea.

self.scheduler_config = scheduler_config
self.cache_config = cache_config

self.num_spec_tokens = num_spec_tokens
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has this value been used?

@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .deepseek_mtp import DeepseekMTP # noqa F401
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want model w/o spec decoding loading these modules.

def get_logits(self, hidden_states: torch.Tensor):
"""Get logits of model output."""
draft_model = self.model
if not isinstance(draft_model, torch.nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

graph_runner has expose get_logits of model.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. but eagle do not have get_logits while eagle3 has . Base on graph_runner's get_logits method, we cannot differ these two. That's why here check if original model has get_logits

expected_output_token_ids = torch.tensor([[0, 1, 2], [0, -1, -1], [1, -1, -1]], dtype=torch.long).cuda()

draft_probs = None
target_probs, draft_token_ids, bonus_token_ids, max_spec_len = torch.load('tmp.pt')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test requires tmp.pt, do not place it here.
Could we add it in unit test?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants