Skip to content

Commit e3bc050

Browse files
authored
adding stateful toolenv, moving tool json loading to env_response (#224)
* adding stateful toolenv, moving tool json loading to env_response * fix statefultoolenv constructor
1 parent 8c1b7fe commit e3bc050

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

verifiers/envs/stateful_tool_env.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import json
2+
from abc import abstractmethod
3+
from typing import Callable
4+
5+
from verifiers.envs.tool_env import ToolEnv
6+
from verifiers.types import ChatCompletionMessageToolCall, Message, Messages, State
7+
from verifiers.utils.tool_utils import convert_func_to_oai_tool
8+
9+
10+
class StatefulToolEnv(ToolEnv):
11+
def __init__(
12+
self,
13+
tools: list[Callable] | None = None,
14+
max_turns: int = 10,
15+
error_formatter: Callable[[Exception], str] = lambda e: f"{str(e)}",
16+
**kwargs,
17+
):
18+
super().__init__(
19+
tools=tools,
20+
max_turns=max_turns,
21+
error_formatter=error_formatter,
22+
**kwargs,
23+
)
24+
self.tools = tools or []
25+
self.max_turns = max_turns
26+
self.error_formatter = error_formatter
27+
self.oai_tools = [convert_func_to_oai_tool(tool) for tool in self.tools]
28+
self.tool_map = {tool.__name__: tool for tool in self.tools}
29+
30+
@abstractmethod
31+
def update_tool_args(
32+
self, tool_args: dict, messages: Messages, state: State, **kwargs
33+
) -> dict:
34+
"""Update tool arguments and/or state (in-place) based on messages and state."""
35+
pass
36+
37+
def call_tool(
38+
self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs
39+
) -> Message:
40+
"""Call a tool based on JSON command."""
41+
try:
42+
tool_func = self.tool_map[tool_name]
43+
result = str(tool_func(**tool_args))
44+
return {
45+
"role": "tool",
46+
"content": str(result),
47+
"tool_call_id": tool_call_id,
48+
}
49+
except Exception as e:
50+
return {
51+
"role": "tool",
52+
"content": self.error_formatter(e),
53+
"tool_call_id": tool_call_id,
54+
}
55+
56+
def env_response(
57+
self, messages: Messages, state: State, **kwargs
58+
) -> tuple[Messages, State]:
59+
assert isinstance(messages, list)
60+
assert "tool_calls" in messages[-1]
61+
tool_messages = []
62+
for tool_call in messages[-1]["tool_calls"]:
63+
assert isinstance(tool_call, ChatCompletionMessageToolCall)
64+
tool_name: str = tool_call.function.name
65+
tool_args: dict = json.loads(tool_call.function.arguments)
66+
tool_call_id: str = tool_call.id or ""
67+
tool_args = self.update_tool_args(tool_args, messages, state, **kwargs)
68+
tool_message: Message = self.call_tool(tool_name, tool_args, tool_call_id)
69+
tool_messages.append(tool_message)
70+
return tool_messages, state

verifiers/envs/tool_env.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(
1919
self.error_formatter = error_formatter
2020
self.oai_tools = [convert_func_to_oai_tool(tool) for tool in self.tools]
2121
self.tool_map = {tool.__name__: tool for tool in self.tools}
22-
super().__init__(oai_tools=self.oai_tools, **kwargs)
22+
super().__init__(oai_tools=self.oai_tools, max_turns=max_turns, **kwargs)
2323

2424
def is_completed(self, messages: Messages, state: State, **kwargs: Any) -> bool:
2525
assert isinstance(messages, list)
@@ -30,12 +30,12 @@ def is_completed(self, messages: Messages, state: State, **kwargs: Any) -> bool:
3030
return is_assistant_message and no_tool_calls
3131

3232
def call_tool(
33-
self, tool_name: str, tool_args: str, tool_call_id: str, **kwargs
33+
self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs
3434
) -> Message:
3535
"""Call a tool based on JSON command."""
3636
try:
3737
tool_func = self.tool_map[tool_name]
38-
result = str(tool_func(**json.loads(tool_args)))
38+
result = str(tool_func(**tool_args))
3939
return {
4040
"role": "tool",
4141
"content": str(result),
@@ -57,7 +57,7 @@ def env_response(
5757
for tool_call in messages[-1]["tool_calls"]:
5858
assert isinstance(tool_call, ChatCompletionMessageToolCall)
5959
tool_name: str = tool_call.function.name
60-
tool_args: str = tool_call.function.arguments
60+
tool_args: dict = json.loads(tool_call.function.arguments)
6161
tool_call_id: str = tool_call.id or ""
6262
tool_message: Message = self.call_tool(tool_name, tool_args, tool_call_id)
6363
tool_messages.append(tool_message)

0 commit comments

Comments
 (0)