Commit a4311d8

Allow env/reward/tool functions to be maybe async; fix for image_url sanitization (#245)
* allow env methods, reward funcs, tools to be optionally async
* fix for image_url sanitization
1 parent 8e38e7f commit a4311d8

17 files changed: +158 / -121 lines
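
Most of this diff threads calls through a maybe_await helper imported from verifiers.utils.async_utils. The helper's body is not part of this commit's diff; as a minimal sketch of the usual pattern (an assumption, not the repository's actual code), such a helper calls the function and awaits the result only if it is awaitable:

import inspect
from typing import Any, Callable


async def maybe_await(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call func, awaiting the result if it is awaitable.

    Plain functions return their value directly; coroutine functions
    (or anything else returning an awaitable) are awaited first, so
    callers can treat sync and async callables uniformly.
    """
    result = func(*args, **kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result

This single indirection is what lets the environment methods, reward functions, and tools below be declared either way.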

environments/math_python/math_python.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def num_errors(parser, completion) -> float:
             if "error" in msg["content"].lower()
         ]
     )
-    return num_errors
+    return float(num_errors)
 
 rubric = vf.Rubric(
     funcs=[correct_answer_reward_func, num_turns, num_tool_calls, num_errors],
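
The local count is an int (it comes from len(...)), so returning float(num_errors) matches the function's -> float annotation and the float(...) coercion the rubric scoring path now applies (see verifiers/rubrics/rubric.py below).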

tests/test_singleturn_env.py

Lines changed: 10 additions & 8 deletions
@@ -48,23 +48,25 @@ def test_singleturn_env_initialization_completion(self, mock_openai_client):
         )
         assert env.message_type == "completion"
 
-    def test_is_completed_method(self, mock_singleturn_env):
+    @pytest.mark.asyncio
+    async def test_is_completed_method(self, mock_singleturn_env):
         """Test the is_completed method logic."""
         # No responses yet
         messages = [{"role": "user", "content": "Hello"}]
         state = {"responses": []}
-        assert not mock_singleturn_env.is_completed(messages, state)
+        assert not await mock_singleturn_env.is_completed(messages, state)
 
         # With responses
         state = {"responses": [MagicMock()]}
-        assert mock_singleturn_env.is_completed(messages, state)
+        assert await mock_singleturn_env.is_completed(messages, state)
 
-    def test_env_response_method(self, mock_singleturn_env):
+    @pytest.mark.asyncio
+    async def test_env_response_method(self, mock_singleturn_env):
         """Test the env_response method (which should never be called in practice)."""
         messages = [{"role": "user", "content": "Hello"}]
         state = {}
 
-        response, new_state = mock_singleturn_env.env_response(messages, state)
+        response, new_state = await mock_singleturn_env.env_response(messages, state)
 
         # Should return minimal response (env_response returns a list of messages)
         assert len(response) == 1
@@ -345,12 +347,12 @@ async def test_singleturn_stops_after_one_response(
 
         # Before any responses
         state = {"responses": []}
-        assert not env.is_completed([], state)
+        assert not await env.is_completed([], state)
 
         # After one response
         state = {"responses": [MagicMock()]}
-        assert env.is_completed([], state)
+        assert await env.is_completed([], state)
 
         # Even with multiple responses (shouldn't happen), it's still completed
         state = {"responses": [MagicMock(), MagicMock()]}
-        assert env.is_completed([], state)
+        assert await env.is_completed([], state)
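
These tests rely on the pytest-asyncio plugin: without the @pytest.mark.asyncio marker (or the plugin's auto mode), an async def test would be collected but its body would never actually run.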

verifiers/envs/environment.py

Lines changed: 3 additions & 1 deletion
@@ -27,6 +27,7 @@
     SamplingArgs,
     State,
 )
+from verifiers.utils.message_utils import cleanup_messages
 from verifiers.utils.tool_utils import sanitize_tool_calls
 
 if TYPE_CHECKING:
@@ -216,7 +217,6 @@ async def get_model_response(
         ):
             sampling_args.pop("max_completion_tokens")
         clean_sampling_args = {k: v for k, v in sampling_args.items() if v is not None}
-
         try:
             if message_type == "chat":
                 assert isinstance(prompt, list)
@@ -385,6 +385,8 @@ async def a_generate(
         if self.oai_tools and "oai_tools" not in info:
             info["oai_tools"] = self.oai_tools
 
+        results_dict["prompt"] = [cleanup_messages(p) for p in results_dict["prompt"]]
+
         # prepare GenerateOutputs and run rollouts
         results = GenerateOutputs(
             prompt=results_dict["prompt"],
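
The a_generate hunk is the image_url sanitization fix from the commit title: prompts are passed through cleanup_messages before GenerateOutputs is built. The helper itself lives in verifiers/utils/message_utils.py and is not shown in this diff; purely as a hypothetical illustration of the kind of normalization such a helper might perform (not the library's actual implementation), it could strip multimodal content parts down to the fields the chat-completions schema accepts:

# Hypothetical sketch only: the real cleanup_messages in
# verifiers/utils/message_utils.py is not shown in this diff.
def cleanup_messages(messages: list[dict]) -> list[dict]:
    """Normalize messages, keeping only schema-valid fields in
    multimodal content parts (e.g. image_url entries)."""
    cleaned = []
    for msg in messages:
        msg = dict(msg)  # shallow copy; don't mutate the caller's dicts
        content = msg.get("content")
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, dict) and part.get("type") == "image_url":
                    image_url = part.get("image_url") or {}
                    # keep only the url field of an image_url part
                    parts.append(
                        {"type": "image_url", "image_url": {"url": image_url.get("url")}}
                    )
                else:
                    parts.append(part)
            msg["content"] = parts
        cleaned.append(msg)
    return cleaned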

verifiers/envs/multiturn_env.py

Lines changed: 10 additions & 7 deletions
@@ -12,22 +12,23 @@
     SamplingArgs,
     State,
 )
+from verifiers.utils.async_utils import maybe_await
 
 
 class MultiTurnEnv(Environment):
     def __init__(self, max_turns: int = 10, **kwargs):
         super().__init__(**kwargs)
         self.max_turns = max_turns
 
-    def setup_state(self, state: State, **kwargs) -> State:
+    async def setup_state(self, state: State, **kwargs) -> State:
         return state
 
     @abstractmethod
-    def is_completed(self, messages: Messages, state: State, **kwargs) -> bool:
+    async def is_completed(self, messages: Messages, state: State, **kwargs) -> bool:
         pass
 
     @abstractmethod
-    def env_response(
+    async def env_response(
         self, messages: Messages, state: State, **kwargs
     ) -> tuple[Messages, State]:
         """
@@ -60,7 +61,7 @@ async def rollout(
             "responses": [],
             "turn": 0,
         }
-        state = self.setup_state(state)
+        state = await maybe_await(self.setup_state, state, **kwargs)
         if self.message_type == "chat":
             assert isinstance(prompt, list)
             completion = []
@@ -70,7 +71,7 @@ async def rollout(
             state["responses_start_idx"] = []
         rollout = list(prompt) if not isinstance(prompt, str) else prompt
         while not is_completed:
-            if self.is_completed(rollout, state, **kwargs):
+            if await maybe_await(self.is_completed, rollout, state, **kwargs):
                 is_completed = True
                 break
             response = await self.get_model_response(
@@ -107,12 +108,14 @@ async def rollout(
             completion += response_text
             state["turn"] += 1
             if (
-                self.is_completed(rollout, state, **kwargs)
+                await maybe_await(self.is_completed, rollout, state, **kwargs)
                 or state["turn"] >= self.max_turns
            ):
                 is_completed = True
             else:
-                env_msgs, state = self.env_response(rollout, state, **kwargs)
+                env_msgs, state = await maybe_await(
+                    self.env_response, rollout, state, **kwargs
+                )
                 if self.message_type == "chat":
                     assert isinstance(env_msgs, list)
                     assert isinstance(rollout, list)
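
Because rollout now resolves setup_state, is_completed, and env_response through maybe_await, subclasses may override these hooks as either plain methods or coroutines. A minimal sketch (EchoEnv and its method bodies are invented for illustration; only the base class and types come from this repository):

from verifiers.envs.multiturn_env import MultiTurnEnv
from verifiers.types import Messages, State


class EchoEnv(MultiTurnEnv):
    # Sync override: maybe_await returns the value directly.
    def is_completed(self, messages: Messages, state: State, **kwargs) -> bool:
        return state["turn"] >= 2

    # Async override: maybe_await awaits the coroutine.
    async def env_response(
        self, messages: Messages, state: State, **kwargs
    ) -> tuple[Messages, State]:
        last = messages[-1]["content"]
        return [{"role": "user", "content": f"you said: {last}"}], state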

verifiers/envs/singleturn_env.py

Lines changed: 2 additions & 2 deletions
@@ -7,10 +7,10 @@ class SingleTurnEnv(MultiTurnEnv):
     Environment for single-turn tasks (chat or completion).
     """
 
-    def is_completed(self, messages: Messages, state: State, **kwargs) -> bool:
+    async def is_completed(self, messages: Messages, state: State, **kwargs) -> bool:
         return len(state["responses"]) > 0
 
-    def env_response(
+    async def env_response(
         self, messages: Messages, state: State, **kwargs
     ) -> tuple[Messages, State]:
         # never called in MultiTurnEnv.rollout

verifiers/envs/stateful_tool_env.py

Lines changed: 5 additions & 3 deletions
@@ -34,7 +34,7 @@ def update_tool_args(
         """Update tool arguments and/or state (in-place) based on messages and state."""
         pass
 
-    def call_tool(
+    async def call_tool(
         self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs
     ) -> Message:
         """Call a tool based on JSON command."""
@@ -53,7 +53,7 @@ def call_tool(
             "tool_call_id": tool_call_id,
         }
 
-    def env_response(
+    async def env_response(
         self, messages: Messages, state: State, **kwargs
     ) -> tuple[Messages, State]:
         assert isinstance(messages, list)
@@ -65,6 +65,8 @@ def env_response(
             tool_args: dict = json.loads(tool_call.function.arguments)
             tool_call_id: str = tool_call.id or ""
             tool_args = self.update_tool_args(tool_args, messages, state, **kwargs)
-            tool_message: Message = self.call_tool(tool_name, tool_args, tool_call_id)
+            tool_message: Message = await self.call_tool(
+                tool_name, tool_args, tool_call_id
+            )
             tool_messages.append(tool_message)
         return tool_messages, state

verifiers/envs/textarena_env.py

Lines changed: 4 additions & 2 deletions
@@ -74,14 +74,16 @@ def __init__(
             **kwargs,
         )
 
-    def is_completed(self, messages: Messages, state: State, **kwargs: Any) -> bool:
+    async def is_completed(
+        self, messages: Messages, state: State, **kwargs: Any
+    ) -> bool:
         if "is_finished" in state and state["is_finished"]:
             state.pop("ta_env")
             return state["is_finished"]
         self.parser
         return False
 
-    def env_response(
+    async def env_response(
         self, messages: Messages, state: State, **kwargs: Any
     ) -> tuple[Messages, State]:
         # load env

verifiers/envs/tool_env.py

Lines changed: 10 additions & 5 deletions
@@ -3,6 +3,7 @@
 
 from verifiers.envs.multiturn_env import MultiTurnEnv
 from verifiers.types import ChatCompletionMessageToolCall, Message, Messages, State
+from verifiers.utils.async_utils import maybe_await
 from verifiers.utils.tool_utils import convert_func_to_oai_tool
 
 
@@ -21,21 +22,23 @@ def __init__(
         self.tool_map = {tool.__name__: tool for tool in self.tools}
         super().__init__(oai_tools=self.oai_tools, max_turns=max_turns, **kwargs)
 
-    def is_completed(self, messages: Messages, state: State, **kwargs: Any) -> bool:
+    async def is_completed(
+        self, messages: Messages, state: State, **kwargs: Any
+    ) -> bool:
         assert isinstance(messages, list)
         is_assistant_message = messages[-1]["role"] == "assistant"
         no_tool_calls = (
             "tool_calls" not in messages[-1] or messages[-1]["tool_calls"] is None
         )
         return is_assistant_message and no_tool_calls
 
-    def call_tool(
+    async def call_tool(
         self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs
     ) -> Message:
         """Call a tool based on JSON command."""
         try:
             tool_func = self.tool_map[tool_name]
-            result = str(tool_func(**tool_args))
+            result = str(await maybe_await(tool_func, **tool_args))
             return {
                 "role": "tool",
                 "content": str(result),
@@ -48,7 +51,7 @@ def call_tool(
                 "tool_call_id": tool_call_id,
             }
 
-    def env_response(
+    async def env_response(
         self, messages: Messages, state: State, **kwargs
     ) -> tuple[Messages, State]:
         assert isinstance(messages, list)
@@ -59,6 +62,8 @@ def env_response(
             tool_name: str = tool_call.function.name
             tool_args: dict = json.loads(tool_call.function.arguments)
             tool_call_id: str = tool_call.id or ""
-            tool_message: Message = self.call_tool(tool_name, tool_args, tool_call_id)
+            tool_message: Message = await self.call_tool(
+                tool_name, tool_args, tool_call_id
+            )
             tool_messages.append(tool_message)
         return tool_messages, state
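
Since call_tool now dispatches through maybe_await, the plain functions registered as tools can themselves be sync or async. A sketch under stated assumptions: add and slow_lookup are invented example tools, and the tools=... keyword follows the constructor visible in the hunk above (other Environment constructor arguments are omitted):

import asyncio

from verifiers.envs.tool_env import ToolEnv


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


async def slow_lookup(query: str) -> str:
    """Pretend to call an external service asynchronously."""
    await asyncio.sleep(0.1)  # stand-in for real async I/O
    return f"results for {query!r}"


# Sync and async tools can be mixed; each invocation is wrapped
# in maybe_await inside ToolEnv.call_tool.
env = ToolEnv(tools=[add, slow_lookup])  # plus whatever Environment kwargs your setup needs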

verifiers/rubrics/judge_rubric.py

Lines changed: 4 additions & 2 deletions
@@ -1,10 +1,11 @@
 from typing import Any
 
-from openai import OpenAI, AsyncOpenAI
+from openai import AsyncOpenAI, OpenAI
 
 from verifiers.parsers.parser import Parser
 from verifiers.rubrics.rubric import Rubric
 from verifiers.types import Messages, State
+from verifiers.utils.async_utils import maybe_await
 
 DEFAULT_JUDGE_PROMPT = """Given a ground truth answer \
 and a response, determine if the response is correct.
@@ -82,7 +83,8 @@ async def judge(
         ):
             judge_args.pop("max_completion_tokens")
         judge_args = {k: v for k, v in judge_args.items() if v is not None}
-        judge_response = await self.judge_client.chat.completions.create(
+        judge_response = await maybe_await(
+            self.judge_client.chat.completions.create,
             model=self.judge_model,
             messages=[{"role": "user", "content": judge_prompt}],
             **judge_args,
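
Routing the judge call through maybe_await lets judge_client be either an OpenAI or AsyncOpenAI instance (both are imported above): the sync client's create returns a response object that passes straight through, while the async client's coroutine gets awaited.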

verifiers/rubrics/rubric.py

Lines changed: 3 additions & 7 deletions
@@ -11,6 +11,7 @@
     RolloutScores,
     State,
 )
+from verifiers.utils.async_utils import maybe_await
 
 
 class Rubric:
@@ -94,22 +95,17 @@ def func(completion, answer, **kwargs):
             task=task,
             info=info,
         )
-        ans = 0.0
         merged = {**common, **kwargs}
         if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
             try:
-                ans = func(**merged)
-                if inspect.iscoroutinefunction(func):
-                    ans = await ans
+                ans = float(await maybe_await(func, **merged))
             except Exception as e:
                 self.logger.error(f"Error calling reward function {func.__name__}: {e}")
                 ans = 0.0
         else:
             allowed = {k: v for k, v in merged.items() if k in sig.parameters}
             try:
-                ans = func(**allowed)
-                if inspect.iscoroutinefunction(func):
-                    ans = await ans
+                ans = float(await maybe_await(func, **allowed))
             except Exception as e:
                 self.logger.error(f"Error calling reward function {func.__name__}: {e}")
                 ans = 0.0
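
With this change, a reward function registered on a Rubric may itself be a coroutine function; scoring resolves it via maybe_await and coerces the result with float(...). A sketch (the import alias vf and the funcs=... keyword follow the usage in math_python.py above; the reward functions themselves are invented):

import asyncio

import verifiers as vf  # alias as used in math_python.py


def length_penalty(completion, **kwargs) -> float:
    """Sync reward: penalize long completions."""
    return max(0.0, 1.0 - len(str(completion)) / 1000)


async def judged_correct(completion, answer, **kwargs) -> float:
    """Async reward: stand-in for an external judge call."""
    await asyncio.sleep(0.1)  # placeholder for real async I/O
    return 1.0 if answer in str(completion) else 0.0


# Sync and async reward functions can be mixed in one rubric.
rubric = vf.Rubric(funcs=[length_penalty, judged_correct])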
