Commit 0c1ff29

Add support for base model RL / message_type="completions" (#201)

* add continuation quality environment using qwen 2.5 base model and gpt 4.1 mini judge
* delete duplicated message_type parameters
* implement completion methods for vllm
* ruff check

1 parent 79a2191 commit 0c1ff29

7 files changed: +314 -16 lines changed
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# vf-continuation-quality

### Overview
- **Environment ID**: `vf-continuation-quality`
- **Short description**: Single-turn quality grades on base model continuations using a judge model.
- **Tags**: single-turn, completions, base-model

### Datasets
- **Primary dataset(s)**: `agentlans/wikipedia-paragraphs` mapped to prefix/ground-truth continuation
- **Source links**: Hugging Face Datasets
- **Split sizes**: Train split filtered to adequately-long paragraphs

### Task
- **Type**: single-turn
- **Parser**: custom
- **Rubric overview**: Judge model letter grade (gpt-4.1-mini-based by default)

### Quickstart
Run an evaluation with default settings:

```bash
uv run vf-eval vf-continuation-quality
```

Configure model and sampling:

```bash
uv run vf-eval vf-continuation-quality -m gpt-4.1-mini -n 20 -r 3 -t 1024 -T 0.7 -a '{"key": "value"}' # env-specific args as JSON
```

Notes:
- Use `-a` / `--env-args` to pass environment-specific configuration as a JSON object.
- Reports are written under `./environments/vf_continuation_quality/reports/` and auto-embedded below.

### Environment Arguments
Document any supported environment arguments and their meaning. Example:

| Arg | Type | Default | Description |
| --- | ---- | ------- | ----------- |
| `dataset_name` | str | `"agentlans/wikipedia-paragraphs"` | Training dataset |
| `dataset_split` | str | `"train"` | Training dataset split |
| `dataset_key` | str | `"text"` | Column in dataset with training text |
| `judge_model` | str | `"gpt-4.1-mini"` | Model to judge continuations with |
| `judge_base_url` | str | `"https://api.openai.com/v1"` | API base URL for judge model |
| `judge_api_key_var` | str | `"OPENAI_API_KEY"` | Environment variable containing the judge model API key |

### Metrics
Summarize key metrics your rubric emits and how they're interpreted.

| Metric | Meaning |
| ------ | ------- |
| `reward` | Main scalar reward (weighted sum of criteria) |

## Evaluation Reports

<!-- Do not edit below this line. Content is auto-generated. -->
<!-- vf:begin:reports -->
<p>No reports found. Run <code>uv run vf-eval vf-continuation-quality -a '{"key": "value"}'</code> to generate one.</p>
<!-- vf:end:reports -->
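The arguments in the table above map directly onto the `load_environment` keyword arguments defined in `vf_continuation_quality.py` (shown further below), so the same configuration can also be set programmatically. A minimal sketch, assuming `vf.load_environment` forwards extra keyword arguments to the environment; the values shown are just the documented defaults:

```python
# Sketch only: configuring the environment in Python with the arguments
# documented above (values are the defaults, purely illustrative).
import verifiers as vf

vf_env = vf.load_environment(
    env_id="vf-continuation-quality",
    judge_model="gpt-4.1-mini",
    judge_base_url="https://api.openai.com/v1",
    judge_api_key_var="OPENAI_API_KEY",
)
```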
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
[project]
name = "vf-continuation-quality"
version = "0.1.0"
dependencies = [
    "verifiers>=0.1.2",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
include = ["vf_continuation_quality.py"]
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
import os
import random

from datasets import load_dataset
from openai import OpenAI

import verifiers as vf

_rand = random.Random(777)
def make_cut(text: str) -> dict[str, str]:
    """Makes a random cut somewhere in the paragraph"""
    n_spaces = text.count(" ")
    # mostly split near the middle
    split_space = int(_rand.normalvariate(0.5, 0.15) * n_spaces)
    # make sure there's at least ~25 words before and after the split point
    split_space = min(n_spaces - 25, max(25, split_space))
    idx = -1
    for _ in range(split_space):
        idx = text.find(" ", idx + 1)
    return { "prompt": text[:idx], "answer": text[idx:] }


def load_environment(
    dataset_name: str = "agentlans/wikipedia-paragraphs",
    dataset_split: str | None = "train",
    dataset_key: str = "text",
    judge_model: str = "gpt-4.1-mini",
    judge_base_url: str = "https://api.openai.com/v1",
    judge_api_key_var: str = "OPENAI_API_KEY",
) -> vf.Environment:
    dataset = load_dataset(dataset_name, split=dataset_split)
    # only accept examples with >~100 words or so
    dataset = dataset.filter(lambda x: x[dataset_key].count(" ") > 100)
    dataset = dataset.map(lambda x: make_cut(x[dataset_key]))
    dataset = dataset.shuffle(seed=777)

    judge_client = OpenAI(api_key=os.getenv(judge_api_key_var), base_url=judge_base_url)
    judge_prompt = """Evaluate this base model continuation from a prefix, compared to the true continuation from Wikipedia.

<prefix>
{question}
</prefix>

<true_continuation>
{answer}
</true_continuation>

<model_continuation>
{response}
</model_continuation>

Provide a letter grade from A-F where:
- A: Smooth prose, facts are mostly accurate w.r.t the true continuation
- B: Smooth prose, regardless of factual accuracy
- C: Some awkward wording, spacing, or punctuation
- D: Inclusions of awkward or glitchy text along with promising prose, some coherent sentences
- F: Incoherent text

Think aloud in a <scratchpad> for a few lines, then respond with the letter grade in <grade> ... </grade> tags."""
    rubric = vf.JudgeRubric(
        judge_client=judge_client,
        judge_model=judge_model,
        judge_prompt=judge_prompt,
    )

    grade_parser = vf.XMLParser(fields=["grade"], answer_field="grade")
    def grade_reward(prompt, completion, answer, state, **kwargs) -> float:
        judge_response = rubric.judge(prompt, completion, answer, state, **kwargs)
        judge_grade = (
            (grade_parser.parse_answer(judge_response) or "F")
            .strip()
            .replace("+", "")
            .replace("-", "")
            .upper()
        )
        return {
            "A": 1.0,
            "B": 0.75,
            "C": 0.5,
            "D": 0.25,
        }.get(judge_grade, 0.0)

    rubric.add_reward_func(grade_reward, weight=1.0)

    return vf.SingleTurnEnv(
        message_type="completion",
        dataset=dataset,
        parser=vf.Parser(),
        rubric=rubric,
        sampling_args={
            "stop": ["\n"],
        },
    )
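The splitting contract of `make_cut` above can be checked directly; a minimal sketch, assuming `make_cut` is imported from this module and using a synthetic 120-word paragraph:

```python
# Sketch only: make_cut splits without rewriting the text and keeps at least
# ~25 words on each side of the cut (the dataset is filtered to >100 spaces).
sample = " ".join(f"word{i}" for i in range(120))  # 119 spaces, like a filtered paragraph
cut = make_cut(sample)
assert cut["prompt"] + cut["answer"] == sample  # prefix and answer reassemble the original
assert len(cut["prompt"].split()) >= 25         # enough context before the split
assert len(cut["answer"].split()) >= 25         # enough ground-truth continuation after it
```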
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
import verifiers as vf

"""
# install
vf-install vf-continuation-quality (-p /path/to/environments)

# quick eval
vf-eval vf-continuation-quality (-m model_name in endpoints.py)

inference:
CUDA_VISIBLE_DEVICES=0 vf-vllm --model Qwen/Qwen2.5-0.5B \
    --enforce-eager --disable-log-requests

training:
CUDA_VISIBLE_DEVICES=1 accelerate launch --num-processes 1 \
    --config-file configs/zero3.yaml examples/grpo/train_continuation_quality.py
"""

model_name = "Qwen/Qwen2.5-0.5B"
vf_env = vf.load_environment(env_id="vf-continuation-quality")
model, tokenizer = vf.get_model_and_tokenizer(model_name)
trainer = vf.GRPOTrainer(
    env=vf_env,
    model=model,
    processing_class=tokenizer,
    args=vf.grpo_defaults(run_name="continuation-quality"),
)
trainer.train()

verifiers/envs/environment.py

Lines changed: 117 additions & 6 deletions
@@ -12,6 +12,7 @@
 from verifiers.parsers.parser import Parser
 from verifiers.rubrics.rubric import Rubric
 from verifiers.types import (
+    Completion,
     ChatCompletion,
     ChatCompletionToolParam,
     ChatMessage,
@@ -656,6 +657,21 @@ def parse_chat_completion_logprobs(
         ]
         return logprobs
 
+    def parse_completion_logprobs(
+        self, completion: Completion
+    ) -> List[float]:
+        """Parses the completion logprobs from a vLLM completion"""
+        assert len(completion.choices) == 1, (
+            "Response should always have one choice"
+        )
+        assert completion.choices[0].logprobs is not None, (
+            "Logprobs should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/completions"
+        )
+        assert completion.choices[0].logprobs.token_logprobs is not None, (
+            "Logprob token_logprobs should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/completions"
+        )
+        return completion.choices[0].logprobs.token_logprobs
+
     def parse_chat_completion_tokens(
         self, chat_completion: ChatCompletion
     ) -> list[int]:
@@ -670,11 +686,32 @@ def parse_chat_completion_tokens(
             "Logprob content should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/chat/completions"
         )
         tokens = [
+            # tokens are token_id:<int> because we request `return_tokens_as_token_ids` from vllm in GRPOTrainer
             int(token.token.split(":")[-1])
             for token in chat_completion.choices[0].logprobs.content
         ]
         return tokens
 
+    def parse_completion_tokens(
+        self, completion: Completion
+    ) -> List[int]:
+        """Parses the output token ids from a completion returned by the vLLM OAI server."""
+        assert len(completion.choices) == 1, (
+            "Response should always have one choice"
+        )
+        assert completion.choices[0].logprobs is not None, (
+            "Logprobs should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/completions"
+        )
+        assert completion.choices[0].logprobs.tokens is not None, (
+            "Logprob tokens should not be None. Make sure to set logprobs=True in the extra body when making the request to /v1/completions"
+        )
+        tokens = [
+            # tokens are token_id:<int> because we request `return_tokens_as_token_ids` from vllm in GRPOTrainer
+            int(token.split(":")[-1])
+            for token in completion.choices[0].logprobs.tokens
+        ]
+        return tokens
+
     def process_chat_format_vllm(
         self,
         prompt: list[ChatMessage],
@@ -759,6 +796,77 @@ def process_chat_format_vllm(
             completion_logprobs,
         )
 
+    def process_completion_format_vllm(
+        self,
+        prompt: str,
+        completion: str,
+        state: State,
+        processing_class: "PreTrainedTokenizerBase",
+        mask_env_responses: bool = False,
+    ) -> Tuple[List[int], List[int], List[int], List[int], List[float]]:
+        """
+        Process completion format conversations using incremental prefixes.
+        """
+        responses: list[Completion] = state["responses"]
+        responses_start_idx: list[int] = state["responses_start_idx"]
+        assert len(responses) == len(responses_start_idx), "Should have an index for each completion response"
+
+        idx = 0
+        zipped: list[tuple[str, Completion | None]] = []
+        for response, response_start_idx in zip(responses, responses_start_idx):
+            if response_start_idx > idx:
+                # non-model-generated section
+                zipped.append((completion[idx:response_start_idx], None))
+            response_text = response.choices[0].text or ""
+            zipped.append((response_text, response))
+            idx = response_start_idx + len(response_text)
+        assert idx == len(completion), "Completion not fully consumed"
+
+        prompt_ids: list[int] = processing_class.encode(prompt)
+        rollout_consumed = prompt
+        prompt_mask: list[int] = [0] * len(prompt_ids)
+        completion_ids: list[int] = []
+        completion_mask: list[int] = []
+        completion_logprobs: list[float] = []
+        i = 0
+        while i < len(zipped):
+            text, response = zipped[i]
+            # model-generated case -- use response
+            if response is not None:
+                completion_turn_ids = self.parse_completion_tokens(response)
+                completion_turn_mask = [1] * len(completion_turn_ids)
+                completion_turn_logprobs = self.parse_completion_logprobs(response)
+                completion_ids.extend(completion_turn_ids)
+                completion_mask.extend(completion_turn_mask)
+                completion_logprobs.extend(completion_turn_logprobs)
+                rollout_consumed += text
+                i += 1
+            # non-model-generated (user/tool case) -- use text
+            else:
+                token_prefix: list[int] = processing_class.encode(rollout_consumed)
+                token_prefix_with_turn: list[int] = processing_class.encode(rollout_consumed + text)
+                assert token_prefix_with_turn[: len(token_prefix)] == token_prefix, (
+                    f"Token prefix mismatch. Token prefix: {token_prefix}, token prefix with turn: {token_prefix_with_turn}"
+                )
+                completion_turn_ids = token_prefix_with_turn[len(token_prefix) :]
+                if mask_env_responses:
+                    completion_turn_mask = [0] * len(completion_turn_ids)
+                else:
+                    completion_turn_mask = [1] * len(completion_turn_ids)
+                completion_turn_logprobs = [0.0] * len(completion_turn_ids)
+                completion_ids.extend(completion_turn_ids)
+                completion_mask.extend(completion_turn_mask)
+                completion_logprobs.extend(completion_turn_logprobs)
+                rollout_consumed += text
+                i += 1
+        return (
+            prompt_ids,
+            prompt_mask,
+            completion_ids,
+            completion_mask,
+            completion_logprobs,
+        )
+
     def process_env_results_vllm(
         self,
         prompts: list[Messages],
@@ -775,10 +883,8 @@ def process_env_results_vllm(
         Process results with vLLM tokens/logprobs.
         """
         # Determine format from first prompt
+        # TODO: why not from self.message_type?
         is_chat_format = isinstance(prompts[0], list)
-        assert is_chat_format, (
-            "vLLM output parsing is not yet supported for completion format"
-        )
 
         all_prompt_ids = []
         all_prompt_masks = []
@@ -803,10 +909,15 @@ def process_env_results_vllm(
                 )
             else:
                 assert isinstance(prompt, str) and isinstance(completion, str)
-                prompt_ids, prompt_mask, completion_ids, completion_mask = (
-                    self.process_completion_format(prompt, completion, processing_class)
+                (
+                    prompt_ids,
+                    prompt_mask,
+                    completion_ids,
+                    completion_mask,
+                    completion_logprobs,
+                ) = self.process_completion_format_vllm(
+                    prompt, completion, state, processing_class, mask_env_responses
                 )
-                completion_logprobs = [0] * len(completion_ids)
             is_truncated = False
             if max_seq_len > 0 and len(prompt_ids) + len(completion_ids) > max_seq_len:
                 if len(prompt_ids) > max_seq_len:
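Both `parse_chat_completion_tokens` and the new `parse_completion_tokens` rely on vLLM returning each token as a `token_id:<int>` string (because `return_tokens_as_token_ids` is requested from vLLM in GRPOTrainer). A minimal sketch of that recovery step, with made-up token values:

```python
# Sketch only: recovering integer token ids from vLLM's "token_id:<int>" strings.
# The literal values below are hypothetical.
logprob_tokens = ["token_id:15339", "token_id:1917"]
token_ids = [int(tok.split(":")[-1]) for tok in logprob_tokens]
assert token_ids == [15339, 1917]
```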
