Skip to content

Commit 0e6098b

Browse files
committed
reorg funcs in env
1 parent f0d2ce5 commit 0e6098b

File tree

1 file changed

+107
-107
lines changed

1 file changed

+107
-107
lines changed

verifiers/envs/environment.py

Lines changed: 107 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,113 @@ def generate(
455455
# shutdown the executor to prevent thread leaks
456456
executor.shutdown(wait=False)
457457

458+
#########################################################
459+
# Helper functions for evaluation and dataset generation
460+
#########################################################
461+
462+
def evaluate(
    self,
    client: AsyncOpenAI | OpenAI,
    model: str,
    sampling_args: SamplingArgs | None = None,
    num_examples: int = -1,
    rollouts_per_example: int = 1,
    score_rollouts: bool = True,
    max_concurrent: int = -1,
    **kwargs,
) -> GenerateOutputs:
    """Run the model over the Environment evaluation dataset.

    Uses the eval dataset when one is configured, otherwise falls back to
    the train dataset. All generation options are forwarded to ``generate``.
    """
    if self.eval_dataset is not None:
        eval_inputs = self.get_eval_dataset(n=num_examples)
    else:
        self.logger.info("eval_dataset is not set, falling back to train dataset")
        assert self.dataset is not None
        eval_inputs = self.get_dataset(n=num_examples)
    assert eval_inputs is not None, "No dataset found"
    if rollouts_per_example > 1:
        # Repeat each example so generate() yields multiple rollouts per example.
        eval_inputs = eval_inputs.repeat(rollouts_per_example)
    return self.generate(
        eval_inputs,
        client,
        model,
        sampling_args,
        score_rollouts,
        max_concurrent,
        **kwargs,
    )
495+
496+
def _sanitize_tool_calls(self, completion: Messages) -> Messages:
497+
"""
498+
Sanitize tool calls from a completion.
499+
"""
500+
501+
assert isinstance(completion, list)
502+
sanitized_completion = []
503+
for m in completion:
504+
if "tool_calls" in m:
505+
new_m = {
506+
"role": m["role"],
507+
"content": m.get("content", ""),
508+
"tool_calls": [
509+
json.dumps(tc.model_dump()) # type: ignore
510+
for tc in m.get("tool_calls", [])
511+
],
512+
}
513+
sanitized_completion.append(new_m)
514+
else:
515+
sanitized_completion.append(m)
516+
return sanitized_completion
517+
518+
def make_dataset(
    self,
    results: GenerateOutputs,
    push_to_hub: bool = False,
    hub_name: str | None = None,
    state_columns: list[str] | None = None,
    **kwargs,
) -> Dataset:
    """
    Make a dataset from the evaluation results.

    Args:
        results: Generation outputs to convert into a dataset.
        push_to_hub: If True, push the resulting dataset to the Hub.
        hub_name: Hub repo name; required when push_to_hub is True.
        state_columns: State keys to include as extra columns (when present
            in the rollout state).
        **kwargs: Unused; accepted for interface compatibility.

    Returns:
        A Dataset with prompt/completion/answer/info/task/reward columns,
        one column per metric, and any requested state columns.

    Raises:
        ValueError: If push_to_hub is True but hub_name is None.
    """
    state_columns = state_columns or []

    if push_to_hub and hub_name is None:
        raise ValueError("hub_name must be provided if push_to_hub is True")

    cols = ["prompt", "completion", "answer", "info", "task", "reward"]

    results_dict = {
        "prompt": results.prompt,
        # Tool calls are serialized to JSON strings so the column is
        # representable in the dataset.
        "completion": [
            self._sanitize_tool_calls(completion)
            for completion in results.completion
        ],
        "answer": results.answer,
        "info": results.info,
        "task": results.task,
        "reward": results.reward,
    }
    results_dict.update(results.metrics)
    cols.extend(results.metrics.keys())
    # Guard against empty results before probing state[0] (the original
    # indexed unconditionally and raised IndexError on empty results).
    if results.state and results.state[0] is not None:
        for col in state_columns:
            if col in results.state[0]:
                results_dict[col] = [state[col] for state in results.state]
                cols.append(col)
            else:
                self.logger.warning(
                    f"Column {col} not found in state, skipping from dataset."
                )
    dataset = Dataset.from_dict({col: results_dict[col] for col in cols})
    if push_to_hub:
        assert hub_name is not None
        dataset.push_to_hub(hub_name)
    return dataset
564+
458565
#########################################################
459566
# Optional helper functions for parsing vLLM completions
460567
#########################################################
@@ -777,110 +884,3 @@ def process_env_results_vllm(
777884

778885
# alias for process_env_results_vllm
779886
process_env_results = process_env_results_vllm
780-
781-
#########################################################
782-
# Helper functions for evaluation and dataset generation
783-
#########################################################
784-
785-
def evaluate(
    self,
    client: AsyncOpenAI | OpenAI,
    model: str,
    sampling_args: SamplingArgs | None = None,
    num_examples: int = -1,
    rollouts_per_example: int = 1,
    score_rollouts: bool = True,
    max_concurrent: int = -1,
    **kwargs,
) -> GenerateOutputs:
    """
    Evaluate model on the Environment evaluation dataset.

    Args:
        client: OpenAI-compatible client forwarded to ``generate``.
        model: Model name forwarded to ``generate``.
        sampling_args: Optional sampling parameters forwarded to ``generate``.
        num_examples: Passed as ``n`` to the dataset getter
            (presumably -1 means "all examples" — confirm against get_dataset).
        rollouts_per_example: How many times each example is repeated before
            generation.
        score_rollouts: Forwarded positionally to ``generate``.
        max_concurrent: Forwarded positionally to ``generate``.
        **kwargs: Extra options forwarded to ``generate``.

    Returns:
        The GenerateOutputs produced by ``generate``.
    """
    if self.eval_dataset is None:
        # No eval split configured: fall back to the train dataset.
        self.logger.info("eval_dataset is not set, falling back to train dataset")
        assert self.dataset is not None
        inputs = self.get_dataset(n=num_examples)
    else:
        inputs = self.get_eval_dataset(n=num_examples)
    assert inputs is not None, "No dataset found"
    if rollouts_per_example > 1:
        # Repeat each example so generate() produces multiple rollouts per example.
        inputs = inputs.repeat(rollouts_per_example)
    results = self.generate(
        inputs,
        client,
        model,
        sampling_args,
        score_rollouts,
        max_concurrent,
        **kwargs,
    )
    return results
818-
819-
def _sanitize_tool_calls(self, completion: Messages) -> Messages:
    """
    Sanitize tool calls from a completion.

    Rebuilds each message that carries a "tool_calls" key so the tool calls
    become JSON strings (via ``model_dump`` — presumably pydantic objects;
    confirm against the tool-call type). Messages without tool calls are
    passed through unchanged.
    """

    assert isinstance(completion, list)
    sanitized_completion = []
    for m in completion:
        if "tool_calls" in m:
            new_m = {
                "role": m["role"],
                "content": m.get("content", ""),
                # Serialize each tool call to a JSON string.
                "tool_calls": [
                    json.dumps(tc.model_dump())  # type: ignore
                    for tc in m.get("tool_calls", [])
                ],
            }
            sanitized_completion.append(new_m)
        else:
            sanitized_completion.append(m)
    return sanitized_completion
840-
841-
def make_dataset(
    self,
    results: GenerateOutputs,
    push_to_hub: bool = False,
    hub_name: str | None = None,
    state_columns: list[str] | None = None,
    **kwargs,
) -> Dataset:
    """
    Make a dataset from the evaluation results.

    Args:
        results: Generation outputs to convert into a dataset.
        push_to_hub: If True, push the resulting dataset to the Hub.
        hub_name: Hub repo name; required when push_to_hub is True.
        state_columns: State keys to include as extra columns when present.
        **kwargs: Unused in this body; accepted for interface compatibility.

    Returns:
        A Dataset with prompt/completion/answer/info/task/reward columns,
        one column per metric, and any requested state columns.

    Raises:
        ValueError: If push_to_hub is True but hub_name is None.
    """
    state_columns = state_columns or []

    if push_to_hub and hub_name is None:
        raise ValueError("hub_name must be provided if push_to_hub is True")

    cols = ["prompt", "completion", "answer", "info", "task", "reward"]

    results_dict = {
        "prompt": results.prompt,
        "completion": [],
        "answer": results.answer,
        "info": results.info,
        "task": results.task,
        "reward": results.reward,
    }
    # Completions are sanitized so tool calls are JSON strings, not objects.
    for i in range(len(results.completion)):
        results_dict["completion"].append(
            self._sanitize_tool_calls(results.completion[i])
        )
    results_dict.update(results.metrics)
    cols.extend(results.metrics.keys())
    # NOTE(review): results.state[0] raises IndexError when results are empty —
    # consider guarding with `results.state and ...`.
    if results.state[0] is not None:
        for col in state_columns:
            if col in results.state[0]:
                results_dict[col] = [state[col] for state in results.state]
                cols.append(col)
            else:
                self.logger.warning(
                    f"Column {col} not found in state, skipping from dataset."
                )
    dataset = Dataset.from_dict({col: results_dict[col] for col in cols})
    if push_to_hub:
        assert hub_name is not None
        dataset.push_to_hub(hub_name)
    return dataset

0 commit comments

Comments
 (0)