@@ -32,16 +32,28 @@ def __init__(self,
3232                 device : torch .device ,
3333                 runner = None ):
3434        self .name  =  SpecDcodeType .EAGLE  if  vllm_config .speculative_config .method  ==  "eagle"  else  SpecDcodeType .EAGLE3 
35-         self .vllm_config  =  vllm_config 
3635        self .device  =  device 
36+         self .vllm_config  =  vllm_config 
37+         self .speculative_config  =  vllm_config .speculative_config 
38+         self .draft_model_config  =  self .speculative_config .draft_model_config 
39+         self .method  =  self .speculative_config .method 
40+ 
3741        self .runner  =  runner 
42+         self .dtype  =  vllm_config .model_config .dtype 
43+         self .max_model_len  =  vllm_config .model_config .max_model_len 
44+         self .block_size  =  vllm_config .cache_config .block_size 
45+         self .num_speculative_tokens  =  (
46+             self .speculative_config .num_speculative_tokens )
47+         self .max_num_tokens  =  (
48+             vllm_config .scheduler_config .max_num_batched_tokens )
49+         self .token_arange_np  =  np .arange (self .max_num_tokens )
3850
3951        self .block_size  =  vllm_config .cache_config .block_size 
4052        # We need to get the hidden size from the draft model config because 
4153        # the draft model's hidden size can be different from the target model's 
4254        # hidden size (e.g., Llama 3.3 70B). 
43-         self .hidden_size  =  vllm_config . speculative_config . draft_model_config .get_hidden_size (
44-         ) 
55+         self .hidden_size  =  self . draft_model_config .get_hidden_size () 
56+ 
4557
4658        self .use_cuda_graph  =  (self .vllm_config .compilation_config .level 
4759                               ==  CompilationLevel .PIECEWISE  and 
@@ -52,17 +64,16 @@ def __init__(self,
5264
5365        # persistent buffers for cuda graph 
5466        self .input_ids  =  torch .zeros (
55-             self .vllm_config . scheduler_config . max_num_batched_tokens ,
67+             self .max_num_tokens ,
5668            dtype = torch .int32 ,
5769            device = device )
5870        self .positions  =  torch .zeros (
59-             self .vllm_config . scheduler_config . max_num_batched_tokens ,
71+             self .max_num_tokens ,
6072            dtype = torch .int64 ,
6173            device = device )
6274        self .hidden_states  =  torch .zeros (
63-             (self .vllm_config .scheduler_config .max_num_batched_tokens ,
64-              self .hidden_size ),
65-             dtype = self .vllm_config .model_config .dtype ,
75+             (self .max_num_tokens , self .hidden_size ),
76+             dtype = self .dtype ,
6677            device = device )
6778        # We need +1 here because the arange is used to set query_start_loc, 
6879        # which has one more element than batch_size. 
@@ -398,14 +409,18 @@ def _propose(
398409        # [batch_size, max_num_blocks_per_req] 
399410        block_table : torch .Tensor ,
400411        sampling_metadata : SamplingMetadata ,
412+         last_token_indices : Optional [torch .Tensor ],
413+ 
401414    ) ->  torch .Tensor :
402415        device  =  cu_num_tokens .device 
403416        cu_num_tokens  =  cu_num_tokens .cpu ()
404417        block_table  =  block_table .cpu ()
405418        num_tokens  =  target_token_ids .shape [0 ]
406419        batch_size  =  next_token_ids .shape [0 ]
407-         last_token_indices  =  cu_num_tokens [1 :] -  1 
420+         if  last_token_indices  is  None :
421+             last_token_indices  =  common_attn_metadata .query_start_loc [1 :] -  1 
408422        target_positions  =  target_positions .cpu ()
423+         
409424        if  self .name  ==  SpecDcodeType .EAGLE3 :
410425            assert  isinstance (self .model , Eagle3LlamaForCausalLM )
411426            target_hidden_states  =  self .model .combine_hidden_states (
0 commit comments