Commit 8e38e7f

lakshyaag and willccbb authored
Add sampling_args flag to vf-eval (#240)
* Add `sampling_args` flag to `vf-eval`
* Update README to include usage of `sampling_args` in `vf-eval`
* ruff fix

---------

Co-authored-by: William Brown <[email protected]>
1 parent c054ff9 commit 8e38e7f

File tree

3 files changed: +152, -8 lines changed

* README.md
* tests/test_eval_cli.py
* verifiers/scripts/eval.py

README.md

Lines changed: 6 additions & 0 deletions

@@ -144,6 +144,12 @@ For tasks involving LLM judges, you may wish to use `vf.JudgeRubric()` for manag
 
 Note on concurrency: environment APIs accept `max_concurrent` to control parallel rollouts. The `vf-eval` CLI currently exposes `--max-concurrent-requests`; ensure this maps to your environment’s concurrency as expected.
 
+`vf-eval` also supports specifying `sampling_args` as a JSON object, which is sent to the vLLM inference engine:
+
+```bash
+vf-eval vf-environment-name --sampling-args '{"reasoning_effort": "low"}'
+```
+
 ### ToolEnv
 
 For many applications involving tool use, you can use `ToolEnv` to leverage models' native tool/function-calling capabilities in an agentic loop. Tools can be specified as generic Python functions (with type hints and docstrings), which will then be passed in JSON schema form to each inference request.
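
As a hedged illustration of the new flag's precedence (based on the merge logic and `--sampling-args` help text added in this commit; the environment name and values below are placeholders), keys supplied in the JSON override the individual `--max-tokens`/`--temperature` flags, which only fill in missing keys:

```bash
# max_tokens resolves to 256 (from the JSON); temperature stays 0.2 (from the flag)
vf-eval vf-environment-name \
  --max-tokens 1024 --temperature 0.2 \
  --sampling-args '{"enable_thinking": false, "max_tokens": 256}'
```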

tests/test_eval_cli.py

Lines changed: 124 additions & 0 deletions

@@ -0,0 +1,124 @@
+import verifiers.scripts.eval as vf_eval
+
+
+def _make_fake_env(captured):
+    class FakeEnv:
+        def evaluate(
+            self,
+            client,
+            model,
+            sampling_args=None,
+            num_examples=-1,
+            rollouts_per_example=1,
+            **kwargs,
+        ):
+            captured["sampling_args"] = dict(sampling_args or {})
+
+            class Result:
+                prompt = ["p"]
+                completion = ["c"]
+                reward = [1.0]
+                info = [{}]
+                task = ["default"]
+                answer = [""]
+                metrics = {}
+
+            return Result()
+
+    return FakeEnv()
+
+
+def test_cli_sampling_args_precedence_over_flags(monkeypatch):
+    captured = {}
+
+    # Patch environment loader to return our fake env
+    monkeypatch.setattr(
+        vf_eval.vf,
+        "load_environment",
+        lambda env_id, **env_args: _make_fake_env(captured),
+    )
+
+    # Patch OpenAI client used by the CLI to a simple dummy
+    class DummyOpenAI:
+        def __init__(self, api_key=None, base_url=None):
+            self.api_key = api_key
+            self.base_url = base_url
+
+    monkeypatch.setattr(vf_eval, "OpenAI", DummyOpenAI)
+
+    # Run evaluation with JSON sampling args overriding flags
+    vf_eval.eval_environment(
+        env="dummy-env",
+        env_args={},
+        env_dir_path="./environments",
+        endpoints_path="./configs/endpoints.py",
+        model="gpt-4.1-mini",
+        api_key_var="OPENAI_API_KEY",
+        api_base_url="https://api.openai.com/v1",
+        num_examples=1,
+        rollouts_per_example=1,
+        max_concurrent_requests=1,
+        max_tokens=42,
+        temperature=0.9,
+        sampling_args={
+            "enable_thinking": False,
+            "max_tokens": 77,
+            "temperature": 0.1,
+        },
+        verbose=False,
+        save_dataset=False,
+        save_to_hf_hub=False,
+        hf_hub_dataset_name="",
+    )
+
+    sa = captured["sampling_args"]
+    assert sa["max_tokens"] == 77
+    assert sa["temperature"] == 0.1
+    assert sa["enable_thinking"] is False
+
+
+def test_cli_sampling_args_fill_from_flags_when_missing(monkeypatch):
+    captured = {}
+
+    # Patch environment loader to return our fake env
+    monkeypatch.setattr(
+        vf_eval.vf,
+        "load_environment",
+        lambda env_id, **env_args: _make_fake_env(captured),
+    )
+
+    # Patch OpenAI client used by the CLI to a simple dummy
+    class DummyOpenAI:
+        def __init__(self, api_key=None, base_url=None):
+            self.api_key = api_key
+            self.base_url = base_url
+
+    monkeypatch.setattr(vf_eval, "OpenAI", DummyOpenAI)
+
+    # Run evaluation with JSON lacking max_tokens/temperature
+    vf_eval.eval_environment(
+        env="dummy-env",
+        env_args={},
+        env_dir_path="./environments",
+        endpoints_path="./configs/endpoints.py",
+        model="gpt-4.1-mini",
+        api_key_var="OPENAI_API_KEY",
+        api_base_url="https://api.openai.com/v1",
+        num_examples=1,
+        rollouts_per_example=1,
+        max_concurrent_requests=1,
+        max_tokens=55,
+        temperature=0.8,
+        sampling_args={
+            "enable_thinking": True,
+        },
+        verbose=False,
+        save_dataset=False,
+        save_to_hf_hub=False,
+        hf_hub_dataset_name="",
+    )
+
+    sa = captured["sampling_args"]
+    assert sa["max_tokens"] == 55
+    assert sa["temperature"] == 0.8
+    assert sa["enable_thinking"] is True

verifiers/scripts/eval.py

Lines changed: 22 additions & 8 deletions

@@ -28,6 +28,7 @@ def eval_environment(
     max_concurrent_requests: int,
     max_tokens: int | None,
     temperature: float | None,
+    sampling_args: dict | None,
     verbose: bool,
     save_dataset: bool,
     save_to_hf_hub: bool,
@@ -62,15 +63,18 @@ def eval_environment(
 
     client = OpenAI(api_key=os.getenv(api_key_var, "EMPTY"), base_url=api_base_url)
     vf_env = vf.load_environment(env_id=env, **env_args)
-    sampling_args: dict[str, int | float | None] = {
-        "max_tokens": max_tokens,
-    }
-    if temperature is not None:
-        sampling_args["temperature"] = temperature
+    # Merge sampling args with precedence to JSON payload over explicit flags
+    merged_sampling_args: dict = {}
+    if sampling_args is not None:
+        merged_sampling_args.update(sampling_args)
+    if "max_tokens" not in merged_sampling_args:
+        merged_sampling_args["max_tokens"] = max_tokens
+    if temperature is not None and "temperature" not in merged_sampling_args:
+        merged_sampling_args["temperature"] = temperature
     results = vf_env.evaluate(
         client=client,
         model=model,
-        sampling_args=sampling_args,
+        sampling_args=merged_sampling_args,
         num_examples=num_examples,
         rollouts_per_example=rollouts_per_example,
         max_concurrent_requests=max_concurrent_requests,
@@ -143,8 +147,7 @@ def eval_environment(
         "model": model,
         "num_examples": num_examples,
         "rollouts_per_example": rollouts_per_example,
-        "max_tokens": max_tokens,
-        "temperature": temperature,
+        "sampling_args": merged_sampling_args,
         "date": datetime.now().strftime("%Y-%m-%d"),
         "time": datetime.now().strftime("%H:%M:%S"),
         "avg_reward": sum(results.reward) / len(results.reward),
@@ -258,6 +261,16 @@ def main():
     parser.add_argument(
         "--temperature", "-T", type=float, default=None, help="Temperature for sampling"
     )
+    parser.add_argument(
+        "--sampling-args",
+        "-S",
+        type=json.loads,
+        default=None,
+        help=(
+            "Sampling arguments as JSON object. Keys here override --max-tokens/--temperature. "
+            'Example: \'{"enable_thinking": false, "max_tokens": 256}\''
+        ),
+    )
     parser.add_argument(
         "--verbose", "-v", default=False, action="store_true", help="Verbose output"
     )
@@ -297,6 +310,7 @@ def main():
         max_concurrent_requests=args.max_concurrent_requests,
         max_tokens=args.max_tokens,
         temperature=args.temperature,
+        sampling_args=args.sampling_args,
         verbose=args.verbose,
         save_dataset=args.save_dataset,
         save_to_hf_hub=args.save_to_hf_hub,
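
For quick reference, a minimal standalone sketch of the precedence rule this hunk implements; the `merge_sampling_args` helper below is illustrative only (the commit inlines this logic in `eval_environment`):

```python
# Illustrative helper, not part of the commit: JSON-supplied keys win;
# --max-tokens always fills in when missing, --temperature only when set.
def merge_sampling_args(sampling_args, max_tokens, temperature):
    merged = dict(sampling_args or {})
    if "max_tokens" not in merged:
        merged["max_tokens"] = max_tokens
    if temperature is not None and "temperature" not in merged:
        merged["temperature"] = temperature
    return merged


# Mirrors the two CLI tests above
assert merge_sampling_args({"max_tokens": 77, "temperature": 0.1}, 42, 0.9) == {
    "max_tokens": 77,
    "temperature": 0.1,
}
assert merge_sampling_args({"enable_thinking": True}, 55, 0.8) == {
    "enable_thinking": True,
    "max_tokens": 55,
    "temperature": 0.8,
}
```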
