 from verifiers.parsers.parser import Parser
 from verifiers.rubrics.rubric import Rubric
 from verifiers.types import (
+    Completion,
     ChatCompletion,
     ChatCompletionToolParam,
     ChatMessage,
@@ -656,6 +657,21 @@ def parse_chat_completion_logprobs(
         ]
         return logprobs
 
+    def parse_completion_logprobs(
+        self, completion: Completion
+    ) -> list[float]:
+        """Parses the completion logprobs from a vLLM completion."""
+        assert len(completion.choices) == 1, (
+            "Response should always have one choice"
+        )
+        assert completion.choices[0].logprobs is not None, (
+            "Logprobs should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/completions"
+        )
+        assert completion.choices[0].logprobs.token_logprobs is not None, (
+            "Logprob token_logprobs should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/completions"
+        )
+        return completion.choices[0].logprobs.token_logprobs
+
     def parse_chat_completion_tokens(
         self, chat_completion: ChatCompletion
     ) -> list[int]:
@@ -670,11 +686,32 @@ def parse_chat_completion_tokens(
             "Logprob content should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/chat/completions"
         )
         tokens = [
+            # tokens are token_id:<int> because we request `return_tokens_as_token_ids` from vllm in GRPOTrainer
             int(token.token.split(":")[-1])
             for token in chat_completion.choices[0].logprobs.content
         ]
         return tokens
 
+    def parse_completion_tokens(
+        self, completion: Completion
+    ) -> list[int]:
+        """Parses the output token ids from a completion returned by the vLLM OAI server."""
+        assert len(completion.choices) == 1, (
+            "Response should always have one choice"
+        )
+        assert completion.choices[0].logprobs is not None, (
+            "Logprobs should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/completions"
+        )
+        assert completion.choices[0].logprobs.tokens is not None, (
+            "Logprob tokens should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/completions"
+        )
+        tokens = [
+            # tokens are token_id:<int> because we request `return_tokens_as_token_ids` from vllm in GRPOTrainer
+            int(token.split(":")[-1])
+            for token in completion.choices[0].logprobs.tokens
+        ]
+        return tokens
+
     def process_chat_format_vllm(
         self,
         prompt: list[ChatMessage],
@@ -759,6 +796,77 @@ def process_chat_format_vllm(
             completion_logprobs,
         )
 
+    def process_completion_format_vllm(
+        self,
+        prompt: str,
+        completion: str,
+        state: State,
+        processing_class: "PreTrainedTokenizerBase",
+        mask_env_responses: bool = False,
+    ) -> tuple[list[int], list[int], list[int], list[int], list[float]]:
+        """
+        Process completion format conversations using incremental prefixes.
+        """
+        responses: list[Completion] = state["responses"]
+        responses_start_idx: list[int] = state["responses_start_idx"]
+        assert len(responses) == len(responses_start_idx), "Should have an index for each completion response"
+
+        idx = 0
+        zipped: list[tuple[str, Completion | None]] = []
+        for response, response_start_idx in zip(responses, responses_start_idx):
+            if response_start_idx > idx:
+                # non-model-generated section
+                zipped.append((completion[idx:response_start_idx], None))
+            response_text = response.choices[0].text or ""
+            zipped.append((response_text, response))
+            idx = response_start_idx + len(response_text)
+        assert idx == len(completion), "Completion not fully consumed"
+
+        prompt_ids: list[int] = processing_class.encode(prompt)
+        rollout_consumed = prompt
+        prompt_mask: list[int] = [0] * len(prompt_ids)
+        completion_ids: list[int] = []
+        completion_mask: list[int] = []
+        completion_logprobs: list[float] = []
+        i = 0
+        while i < len(zipped):
+            text, response = zipped[i]
+            # model-generated case -- use response
+            if response is not None:
+                completion_turn_ids = self.parse_completion_tokens(response)
+                completion_turn_mask = [1] * len(completion_turn_ids)
+                completion_turn_logprobs = self.parse_completion_logprobs(response)
+                completion_ids.extend(completion_turn_ids)
+                completion_mask.extend(completion_turn_mask)
+                completion_logprobs.extend(completion_turn_logprobs)
+                rollout_consumed += text
+                i += 1
+            # non-model-generated (user/tool case) -- use text
+            else:
+                token_prefix: list[int] = processing_class.encode(rollout_consumed)
+                token_prefix_with_turn: list[int] = processing_class.encode(rollout_consumed + text)
+                assert token_prefix_with_turn[: len(token_prefix)] == token_prefix, (
+                    f"Token prefix mismatch. Token prefix: {token_prefix}, token prefix with turn: {token_prefix_with_turn}"
+                )
+                completion_turn_ids = token_prefix_with_turn[len(token_prefix) :]
+                if mask_env_responses:
+                    completion_turn_mask = [0] * len(completion_turn_ids)
+                else:
+                    completion_turn_mask = [1] * len(completion_turn_ids)
+                completion_turn_logprobs = [0.0] * len(completion_turn_ids)
+                completion_ids.extend(completion_turn_ids)
+                completion_mask.extend(completion_turn_mask)
+                completion_logprobs.extend(completion_turn_logprobs)
+                rollout_consumed += text
+                i += 1
+        return (
+            prompt_ids,
+            prompt_mask,
+            completion_ids,
+            completion_mask,
+            completion_logprobs,
+        )
+
     def process_env_results_vllm(
         self,
         prompts: list[Messages],
@@ -775,10 +883,8 @@ def process_env_results_vllm(
         Process results with vLLM tokens/logprobs.
         """
         # Determine format from first prompt
+        # TODO: why not from self.message_type?
         is_chat_format = isinstance(prompts[0], list)
-        assert is_chat_format, (
-            "vLLM output parsing is not yet supported for completion format"
-        )
 
         all_prompt_ids = []
         all_prompt_masks = []
@@ -803,10 +909,15 @@ def process_env_results_vllm(
                 )
             else:
                 assert isinstance(prompt, str) and isinstance(completion, str)
-                prompt_ids, prompt_mask, completion_ids, completion_mask = (
-                    self.process_completion_format(prompt, completion, processing_class)
+                (
+                    prompt_ids,
+                    prompt_mask,
+                    completion_ids,
+                    completion_mask,
+                    completion_logprobs,
+                ) = self.process_completion_format_vllm(
+                    prompt, completion, state, processing_class, mask_env_responses
                 )
-                completion_logprobs = [0] * len(completion_ids)
             is_truncated = False
             if max_seq_len > 0 and len(prompt_ids) + len(completion_ids) > max_seq_len:
                 if len(prompt_ids) > max_seq_len:
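For readers less familiar with the vLLM completions API, here is a minimal, self-contained sketch (not part of the diff) of the parsing that the new `parse_completion_tokens` and `parse_completion_logprobs` helpers perform on a single-choice `/v1/completions` response requested with `logprobs=True` and `return_tokens_as_token_ids=True`. The stand-in object below only mimics the fields the diff accesses (`choices[0].logprobs.tokens` and `.token_logprobs`); the concrete token ids and logprob values are invented for illustration.

```python
# Sketch only: a fake /v1/completions response shaped like what the diff expects,
# with return_tokens_as_token_ids=True so each token is rendered as "token_id:<int>".
from types import SimpleNamespace

completion = SimpleNamespace(
    choices=[
        SimpleNamespace(
            text=" 42",
            logprobs=SimpleNamespace(
                tokens=["token_id:220", "token_id:3682"],   # invented token ids
                token_logprobs=[-0.11, -0.53],              # invented logprobs
            ),
        )
    ]
)

# Equivalent of parse_completion_tokens: strip the "token_id:" prefix from each token.
token_ids = [int(t.split(":")[-1]) for t in completion.choices[0].logprobs.tokens]
# Equivalent of parse_completion_logprobs: take the per-token logprobs as-is.
logprobs = completion.choices[0].logprobs.token_logprobs

assert token_ids == [220, 3682]
assert logprobs == [-0.11, -0.53]
```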