Skip to content

Commit c054ff9

Browse files
Allow unsetting max_tokens in eval script (#241)
* Improve max_tokens handling and parsing across environments and scripts

Co-authored-by: lakshyajannu <[email protected]>

* Refactor max_tokens argument handling in eval.py to accept integer input directly and simplify parsing logic

---------

Co-authored-by: Cursor Agent <[email protected]>
1 parent fcc0267 commit c054ff9

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

verifiers/envs/environment.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,27 +199,39 @@ async def get_model_response(
199199
Returns special error messages for context length issues.
200200
"""
201201
sampling_args = sampling_args or {}
202+
# Resolve message type first
203+
if message_type is None:
204+
message_type = self.message_type
205+
# Normalize sampling args:
206+
# - If max_tokens is provided for chat, rename to max_completion_tokens
207+
# - Drop any None-valued entries to avoid sending them to the client
202208
if "max_tokens" in sampling_args:
203-
sampling_args["max_completion_tokens"] = sampling_args.pop("max_tokens")
209+
if sampling_args["max_tokens"] is None:
210+
sampling_args.pop("max_tokens")
211+
elif message_type == "chat":
212+
sampling_args["max_completion_tokens"] = sampling_args.pop("max_tokens")
213+
if (
214+
"max_completion_tokens" in sampling_args
215+
and sampling_args["max_completion_tokens"] is None
216+
):
217+
sampling_args.pop("max_completion_tokens")
218+
clean_sampling_args = {k: v for k, v in sampling_args.items() if v is not None}
204219

205220
try:
206-
if message_type is None:
207-
message_type = self.message_type
208-
209221
if message_type == "chat":
210222
assert isinstance(prompt, list)
211223
if oai_tools:
212224
response = await client.chat.completions.create(
213225
model=model,
214226
messages=prompt, # type: ignore
215227
tools=oai_tools,
216-
**sampling_args,
228+
**clean_sampling_args,
217229
)
218230
else:
219231
response = await client.chat.completions.create(
220232
model=model,
221233
messages=prompt, # type: ignore
222-
**sampling_args,
234+
**clean_sampling_args,
223235
)
224236
return response
225237
elif message_type == "completion":
@@ -229,7 +241,7 @@ async def get_model_response(
229241
)
230242
assert isinstance(prompt, str)
231243
response = await client.completions.create(
232-
model=model, prompt=prompt, **sampling_args
244+
model=model, prompt=prompt, **clean_sampling_args
233245
)
234246
return response
235247
except Exception as e:

verifiers/rubrics/judge_rubric.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,23 @@ async def judge(
6969
cached = state.get("judge_response")
7070
if isinstance(cached, dict) and judge_prompt in cached:
7171
return cached[judge_prompt]
72+
# Normalize judge sampling args for chat API
73+
judge_args = dict(self.judge_sampling_args or {})
74+
if "max_tokens" in judge_args:
75+
if judge_args["max_tokens"] is None:
76+
judge_args.pop("max_tokens")
77+
else:
78+
judge_args["max_completion_tokens"] = judge_args.pop("max_tokens")
79+
if (
80+
"max_completion_tokens" in judge_args
81+
and judge_args["max_completion_tokens"] is None
82+
):
83+
judge_args.pop("max_completion_tokens")
84+
judge_args = {k: v for k, v in judge_args.items() if v is not None}
7285
judge_response = await self.judge_client.chat.completions.create(
7386
model=self.judge_model,
7487
messages=[{"role": "user", "content": judge_prompt}],
75-
**self.judge_sampling_args,
88+
**judge_args,
7689
)
7790
judge_response = str(judge_response.choices[0].message.content)
7891
if not isinstance(cached, dict):

verifiers/scripts/eval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def eval_environment(
2626
num_examples: int,
2727
rollouts_per_example: int,
2828
max_concurrent_requests: int,
29-
max_tokens: int,
29+
max_tokens: int | None,
3030
temperature: float | None,
3131
verbose: bool,
3232
save_dataset: bool,
@@ -252,8 +252,8 @@ def main():
252252
"--max-tokens",
253253
"-t",
254254
type=int,
255-
default=1024,
256-
help="Maximum number of tokens to generate",
255+
default=None,
256+
help="Maximum number of tokens to generate (unset to use model default)",
257257
)
258258
parser.add_argument(
259259
"--temperature", "-T", type=float, default=None, help="Temperature for sampling"

0 commit comments

Comments (0)