@@ -171,7 +171,7 @@ def propose(
171171 for layer_name in self .attn_layer_names :
172172 per_layer_attn_metadata [layer_name ] = attn_metadata
173173 if self .use_cuda_graph and \
174- num_tokens <= self .cudagraph_batch_sizes [- 1 ]:
174+ num_tokens <= self .cudagraph_batch_sizes [- 1 ]:
175175 num_input_tokens = self .vllm_config .pad_for_cudagraph (num_tokens )
176176 else :
177177 num_input_tokens = num_tokens
@@ -253,7 +253,7 @@ def propose(
253253 draft_token_ids_list = [draft_token_ids ]
254254
255255 if self .use_cuda_graph and \
256- batch_size <= self .cudagraph_batch_sizes [- 1 ]:
256+ batch_size <= self .cudagraph_batch_sizes [- 1 ]:
257257 input_batch_size = self .vllm_config .pad_for_cudagraph (batch_size )
258258 else :
259259 input_batch_size = batch_size
@@ -474,7 +474,7 @@ def propose_tree(
474474 num_tokens , - 1 )
475475
476476 if self .use_cuda_graph and \
477- num_tokens <= self .cudagraph_batch_sizes [- 1 ]:
477+ num_tokens <= self .cudagraph_batch_sizes [- 1 ]:
478478 num_input_tokens = self .vllm_config .pad_for_cudagraph (
479479 num_tokens )
480480 else :
@@ -644,17 +644,15 @@ def load_model(self, target_model: nn.Module) -> None:
644644 and self .model .model .embed_tokens .weight .shape \
645645 == target_language_model .model .embed_tokens .weight .shape :
646646 logger .info (
647- "Assuming the EAGLE head shares the same vocab embedding" \
648- " with the target model."
649- )
647+ "Assuming the EAGLE head shares the same vocab embedding"
648+ " with the target model." )
650649 del self .model .model .embed_tokens
651650 self .model .model .embed_tokens = (
652651 target_language_model .model .embed_tokens )
653652 else :
654653 logger .info (
655- "The EAGLE head's vocab embedding will be loaded separately" \
656- " from the target model."
657- )
654+ "The EAGLE head's vocab embedding will be loaded separately"
655+ " from the target model." )
658656
659657 # share lm_head with the target model if needed
660658 # some model definition do not define lm_head explicitly
@@ -745,4 +743,4 @@ def compute_probs_and_sample_next_token(
745743 greedy_token_ids ,
746744 next_token_ids ,
747745 )
748- return next_token_ids , probs
746+ return next_token_ids , probs
0 commit comments