Skip to content

Commit aef9f21

Browse files
authored
fix saving dataset to HF, toolcall sanitizing (#246)
1 parent a4311d8 commit aef9f21

File tree

6 files changed

+47
-130
lines changed

6 files changed

+47
-130
lines changed

tests/test_environment_extra.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from verifiers.envs.environment import Environment
2121
from verifiers.parsers.parser import Parser
2222
from verifiers.rubrics.rubric import Rubric
23-
from verifiers.types import GenerateOutputs
24-
from verifiers.utils.tool_utils import sanitize_tool_calls
23+
from verifiers.types import GenerateOutputs, Info, Messages, SamplingArgs
24+
from verifiers.utils.message_utils import sanitize_tool_calls
2525

2626

2727
# Local simple concrete Environment for testing
@@ -30,16 +30,17 @@ async def rollout(
3030
self,
3131
client,
3232
model,
33-
prompt,
33+
prompt: Messages,
3434
answer: str = "",
3535
task: str = "default",
36-
info: dict = {},
37-
sampling_args: dict = {},
36+
info: Info | None = {},
37+
sampling_args: SamplingArgs | None = None,
3838
**kwargs,
3939
):
4040
response = await self.get_model_response(
4141
prompt=prompt, client=client, model=model, sampling_args=sampling_args
4242
)
43+
assert response is not None
4344
if self.message_type == "chat":
4445
completion = [
4546
{"role": "assistant", "content": response.choices[0].message.content}

verifiers/envs/environment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
SamplingArgs,
2828
State,
2929
)
30-
from verifiers.utils.message_utils import cleanup_messages
31-
from verifiers.utils.tool_utils import sanitize_tool_calls
30+
from verifiers.utils.message_utils import cleanup_messages, sanitize_tool_calls
3231

3332
if TYPE_CHECKING:
3433
from transformers.tokenization_utils_base import ( # type: ignore
@@ -519,6 +518,7 @@ def make_dataset(
519518
"""
520519
Make a dataset from the evaluation results.
521520
"""
521+
# TODO: enable saving of multimodal datasets
522522
state_columns = state_columns or []
523523

524524
if push_to_hub and hub_name is None:

verifiers/scripts/eval.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from openai import OpenAI
1313

1414
import verifiers as vf
15-
from verifiers.utils.message_utils import messages_to_printable
15+
from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls
1616

1717

1818
def eval_environment(
@@ -124,8 +124,8 @@ def eval_environment(
124124
tasks = results.task
125125
data_dict = {
126126
"id": ids,
127-
"prompt": printable_prompts,
128-
"completion": printable_completions,
127+
"prompt": [sanitize_tool_calls(p) for p in printable_prompts],
128+
"completion": [sanitize_tool_calls(c) for c in printable_completions],
129129
"task": tasks,
130130
}
131131
if results.info[0] != {}:
@@ -170,9 +170,7 @@ def eval_environment(
170170
print(f"Saved dataset to {results_path}")
171171
if save_to_hf_hub:
172172
if hf_hub_dataset_name == "":
173-
dataset_name = (
174-
f"{env}_{model}_n={num_examples}_r={rollouts_per_example}"
175-
)
173+
dataset_name = f"{env}_{model.replace('/', '-')}_n{num_examples}_r{rollouts_per_example}"
176174
else:
177175
dataset_name = hf_hub_dataset_name
178176
dataset.push_to_hub(dataset_name)

verifiers/utils/logging_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def _format_messages(messages) -> Text:
6161
out.append(content, style=style)
6262
if "tool_calls" in msg:
6363
for tool_call in msg["tool_calls"]:
64-
tool_call_str = json.dumps(dict(tool_call["function"]), indent=2)
64+
name = getattr(tool_call.function, "name", "")
65+
args = getattr(tool_call.function, "arguments", {})
66+
tool_call_str = json.dumps({"name": name, "args": args}, indent=2)
6567
out.append(f"\n\n[tool call]\n{tool_call_str}", style=style)
6668
return out
6769

verifiers/utils/message_utils.py

Lines changed: 32 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,17 @@
1-
from collections.abc import Iterable
1+
import json
22
from typing import cast
33

44
from verifiers.types import ChatMessage, Messages
55

66

7-
def sanitize_object(obj: object):
8-
"""
9-
Recursively convert Pydantic/OpenAI SDK objects to plain Python types
10-
(dict/list/str/bool/number). Leaves primitives unchanged.
11-
"""
12-
if isinstance(obj, (str, bytes, bytearray, int, float, bool)) or obj is None:
13-
return obj
14-
dump = getattr(obj, "model_dump", None)
15-
if callable(dump):
16-
obj = dump()
17-
if isinstance(obj, dict):
18-
return {k: sanitize_object(v) for k, v in obj.items()}
19-
# check if obj is iterable
20-
if isinstance(obj, Iterable):
21-
return [sanitize_object(x) for x in obj]
22-
return obj
23-
24-
25-
def sanitize_chat_message(message: ChatMessage):
26-
"""
27-
input: chat message (dict or object)
28-
output: chat message (dict)
29-
"""
30-
# TODO: debug for multimodal messages; content can get consumed as an iterator
31-
new_message = {}
32-
dump = getattr(message, "model_dump", None)
33-
if callable(dump):
34-
new_message = dump()
35-
return new_message
36-
assert isinstance(message, dict)
37-
assert isinstance(new_message, dict)
38-
new_message["role"] = message["role"]
39-
if "content" in message and message["content"]:
40-
content = message["content"]
41-
if isinstance(content, str):
42-
new_message["content"] = content
43-
else:
44-
new_message["content"] = []
45-
parts = list(content) if not isinstance(content, str) else content
46-
for c in parts:
47-
if isinstance(c, str):
48-
new_message["content"].append(c)
49-
else:
50-
new_message["content"].append(sanitize_object(c))
51-
if "tool_calls" in message and message["tool_calls"]:
52-
tool_calls = list(message["tool_calls"])
53-
new_message["tool_calls"] = [
54-
sanitize_object(tool_call) for tool_call in tool_calls
55-
]
56-
return new_message
57-
58-
59-
def sanitize_messages(messages: Messages) -> str | list:
60-
"""
61-
input: list of dicts or Pydantic models, or str
62-
output: list of dicts, or str
63-
"""
64-
if isinstance(messages, str):
65-
return messages
66-
sanitized_list = [sanitize_chat_message(m) for m in list(messages)]
67-
return sanitized_list
68-
69-
70-
def content_to_printable(content: object) -> str:
7+
def message_to_printable(message: ChatMessage) -> ChatMessage:
718
"""
72-
Render content to readable text, handling multimodal lists.
73-
- Text parts: return their text
74-
- Image-like parts: return "[image]"
75-
Falls back to str(content).
9+
Removes image_url objects from message content.
7610
"""
77-
print(str(content)[:100])
78-
if isinstance(content, str):
79-
return content
80-
if isinstance(content, dict):
81-
if "type" in content and content["type"] == "text":
82-
return content["text"]
83-
if "type" in content and content["type"] in {
84-
"image_url",
85-
"input_image",
86-
"image",
87-
}:
88-
return "[image]"
89-
if isinstance(content, (list, tuple)):
90-
out = []
91-
for x in content:
92-
out.append(content_to_printable(x))
93-
return "\n\n".join(out)
94-
return str(content)
95-
96-
97-
def message_to_printable(message: ChatMessage) -> ChatMessage:
9811
new_message = {}
9912
new_message["role"] = message["role"]
10013
new_message["content"] = []
101-
if "tool_calls" in message and message["tool_calls"]:
14+
if "tool_calls" in message:
10215
new_message["tool_calls"] = message["tool_calls"]
10316
content = message.get("content")
10417
if content is None:
@@ -121,6 +34,9 @@ def message_to_printable(message: ChatMessage) -> ChatMessage:
12134

12235

12336
def messages_to_printable(messages: Messages) -> Messages:
37+
"""
38+
Removes image_url objects from messages.
39+
"""
12440
if isinstance(messages, str):
12541
return messages
12642
return [message_to_printable(m) for m in messages]
@@ -129,6 +45,8 @@ def messages_to_printable(messages: Messages) -> Messages:
12945
def cleanup_message(message: ChatMessage) -> ChatMessage:
13046
new_message = {}
13147
new_message["role"] = message["role"]
48+
if "tool_calls" in message:
49+
new_message["tool_calls"] = message["tool_calls"]
13250
new_message["content"] = []
13351
content = message.get("content")
13452
if content is None:
@@ -161,3 +79,26 @@ def cleanup_messages(messages: Messages) -> Messages:
16179
for m in messages:
16280
new_messages.append(cleanup_message(m))
16381
return new_messages
82+
83+
84+
def sanitize_tool_calls(messages: Messages):
85+
"""
86+
Sanitize tool calls from messages.
87+
"""
88+
if not isinstance(messages, list):
89+
return messages
90+
sanitized_messages = []
91+
for m in messages:
92+
if "tool_calls" in m:
93+
new_m = {
94+
"role": m["role"],
95+
"content": m.get("content", ""),
96+
"tool_calls": [
97+
json.dumps(tc.model_dump()) # type: ignore
98+
for tc in m.get("tool_calls", [])
99+
],
100+
}
101+
sanitized_messages.append(new_m)
102+
else:
103+
sanitized_messages.append(m)
104+
return sanitized_messages

verifiers/utils/tool_utils.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from __future__ import annotations
22

33
import inspect
4-
import json
54
import re
65
from typing import Any, Literal, Union, get_args, get_origin
76

87
from verifiers.types import (
98
ChatCompletionToolParam,
109
FunctionParameters,
1110
JsonPrimitive,
12-
Messages,
1311
)
1412

1513
_JSON_PRIMITIVE_MAP: dict[type, JsonPrimitive] = {
@@ -178,26 +176,3 @@ def convert_func_to_oai_tool(func: Any) -> ChatCompletionToolParam:
178176
"parameters": parameters_schema,
179177
},
180178
}
181-
182-
183-
def sanitize_tool_calls(messages: Messages):
184-
"""
185-
Sanitize tool calls from messages.
186-
"""
187-
if not isinstance(messages, list):
188-
return messages
189-
sanitized_messages = []
190-
for m in messages:
191-
if "tool_calls" in m:
192-
new_m = {
193-
"role": m["role"],
194-
"content": m.get("content", ""),
195-
"tool_calls": [
196-
json.dumps(tc.model_dump()) # type: ignore
197-
for tc in m.get("tool_calls", [])
198-
],
199-
}
200-
sanitized_messages.append(new_m)
201-
else:
202-
sanitized_messages.append(m)
203-
return sanitized_messages

0 commit comments

Comments
 (0)