@@ -455,6 +455,113 @@ def generate(
455
455
# shutdown the executor to prevent thread leaks
456
456
executor .shutdown (wait = False )
457
457
458
+ #########################################################
459
+ # Helper functions for evaluation and dataset generation
460
+ #########################################################
461
+
462
def evaluate(
    self,
    client: AsyncOpenAI | OpenAI,
    model: str,
    sampling_args: SamplingArgs | None = None,
    num_examples: int = -1,
    rollouts_per_example: int = 1,
    score_rollouts: bool = True,
    max_concurrent: int = -1,
    **kwargs,
) -> GenerateOutputs:
    """
    Evaluate model on the Environment evaluation dataset.

    Uses the eval dataset when one is configured; otherwise logs a notice
    and falls back to the train dataset. Each example is repeated
    ``rollouts_per_example`` times before generation so several rollouts
    can be scored per prompt. Extra ``kwargs`` are forwarded to
    :meth:`generate`.
    """
    # Pick the input split: dedicated eval split if present, train otherwise.
    if self.eval_dataset is not None:
        inputs = self.get_eval_dataset(n=num_examples)
    else:
        self.logger.info("eval_dataset is not set, falling back to train dataset")
        assert self.dataset is not None
        inputs = self.get_dataset(n=num_examples)
    assert inputs is not None, "No dataset found"
    # Duplicate each example so multiple rollouts are generated per prompt.
    if rollouts_per_example > 1:
        inputs = inputs.repeat(rollouts_per_example)
    return self.generate(
        inputs,
        client,
        model,
        sampling_args,
        score_rollouts,
        max_concurrent,
        **kwargs,
    )
495
+
496
def _sanitize_tool_calls(self, completion: Messages) -> Messages:
    """
    Sanitize tool calls from a completion.

    Returns a copy of ``completion`` where each message's ``tool_calls``
    entry is replaced by a list of JSON strings, so the messages can be
    stored in an Arrow-backed dataset. Messages without tool calls are
    passed through unchanged.

    Fix: tool calls may be pydantic objects (with ``model_dump``) or plain
    dicts (e.g. messages already serialized and reloaded); both forms are
    now handled instead of assuming ``model_dump`` exists.
    """
    assert isinstance(completion, list)
    sanitized_completion = []
    for m in completion:
        if "tool_calls" in m:
            sanitized_completion.append(
                {
                    "role": m["role"],
                    "content": m.get("content", ""),
                    "tool_calls": [
                        # Accept both already-dict and pydantic tool calls.
                        json.dumps(tc if isinstance(tc, dict) else tc.model_dump())  # type: ignore
                        for tc in m.get("tool_calls", [])
                    ],
                }
            )
        else:
            sanitized_completion.append(m)
    return sanitized_completion
517
+
518
def make_dataset(
    self,
    results: GenerateOutputs,
    push_to_hub: bool = False,
    hub_name: str | None = None,
    state_columns: list[str] | None = None,
    **kwargs,
) -> Dataset:
    """
    Make a dataset from the evaluation results.

    Args:
        results: Generation outputs to tabulate.
        push_to_hub: Whether to push the dataset to the Hugging Face Hub.
        hub_name: Hub repo id; required when ``push_to_hub`` is True.
        state_columns: Optional per-example state keys to include as columns.

    Returns:
        A ``Dataset`` with prompt/completion/answer/info/task/reward columns,
        one column per metric, and any requested state columns.

    Raises:
        ValueError: If ``push_to_hub`` is True but ``hub_name`` is None.

    Fix: guard against empty ``results.state`` before indexing ``[0]``,
    which previously raised IndexError for empty result sets.
    """
    state_columns = state_columns or []

    if push_to_hub and hub_name is None:
        raise ValueError("hub_name must be provided if push_to_hub is True")

    cols = ["prompt", "completion", "answer", "info", "task", "reward"]

    results_dict = {
        "prompt": results.prompt,
        # Tool calls inside completions are serialized so they can be stored.
        "completion": [self._sanitize_tool_calls(c) for c in results.completion],
        "answer": results.answer,
        "info": results.info,
        "task": results.task,
        "reward": results.reward,
    }
    # Each metric becomes its own column.
    results_dict.update(results.metrics)
    cols.extend(results.metrics.keys())
    # Guard against empty results before inspecting the first state entry.
    if results.state and results.state[0] is not None:
        for col in state_columns:
            if col in results.state[0]:
                results_dict[col] = [state[col] for state in results.state]
                cols.append(col)
            else:
                self.logger.warning(
                    f"Column {col} not found in state, skipping from dataset."
                )
    dataset = Dataset.from_dict({col: results_dict[col] for col in cols})
    if push_to_hub:
        assert hub_name is not None
        dataset.push_to_hub(hub_name)
    return dataset
564
+
458
565
#########################################################
459
566
# Optional helper functions for parsing vLLM completions
460
567
#########################################################
@@ -777,110 +884,3 @@ def process_env_results_vllm(
777
884
778
885
# alias for process_env_results_vllm
779
886
process_env_results = process_env_results_vllm
780
-
781
- #########################################################
782
- # Helper functions for evaluation and dataset generation
783
- #########################################################
784
-
785
def evaluate(
    self,
    client: AsyncOpenAI | OpenAI,
    model: str,
    sampling_args: SamplingArgs | None = None,
    num_examples: int = -1,
    rollouts_per_example: int = 1,
    score_rollouts: bool = True,
    max_concurrent: int = -1,
    **kwargs,
) -> GenerateOutputs:
    """
    Evaluate model on the Environment evaluation dataset.

    Uses the eval dataset when one is configured; otherwise logs a notice
    and falls back to the train dataset. Each example is repeated
    ``rollouts_per_example`` times before generation so several rollouts
    can be scored per prompt. Extra ``kwargs`` are forwarded to
    :meth:`generate`.
    """
    # Pick the input split: dedicated eval split if present, train otherwise.
    if self.eval_dataset is not None:
        inputs = self.get_eval_dataset(n=num_examples)
    else:
        self.logger.info("eval_dataset is not set, falling back to train dataset")
        assert self.dataset is not None
        inputs = self.get_dataset(n=num_examples)
    assert inputs is not None, "No dataset found"
    # Duplicate each example so multiple rollouts are generated per prompt.
    if rollouts_per_example > 1:
        inputs = inputs.repeat(rollouts_per_example)
    return self.generate(
        inputs,
        client,
        model,
        sampling_args,
        score_rollouts,
        max_concurrent,
        **kwargs,
    )
818
-
819
def _sanitize_tool_calls(self, completion: Messages) -> Messages:
    """
    Sanitize tool calls from a completion.

    Returns a copy of ``completion`` where each message's ``tool_calls``
    entry is replaced by a list of JSON strings, so the messages can be
    stored in an Arrow-backed dataset. Messages without tool calls are
    passed through unchanged.

    Fix: tool calls may be pydantic objects (with ``model_dump``) or plain
    dicts (e.g. messages already serialized and reloaded); both forms are
    now handled instead of assuming ``model_dump`` exists.
    """
    assert isinstance(completion, list)
    sanitized_completion = []
    for m in completion:
        if "tool_calls" in m:
            sanitized_completion.append(
                {
                    "role": m["role"],
                    "content": m.get("content", ""),
                    "tool_calls": [
                        # Accept both already-dict and pydantic tool calls.
                        json.dumps(tc if isinstance(tc, dict) else tc.model_dump())  # type: ignore
                        for tc in m.get("tool_calls", [])
                    ],
                }
            )
        else:
            sanitized_completion.append(m)
    return sanitized_completion
840
-
841
def make_dataset(
    self,
    results: GenerateOutputs,
    push_to_hub: bool = False,
    hub_name: str | None = None,
    state_columns: list[str] | None = None,
    **kwargs,
) -> Dataset:
    """
    Make a dataset from the evaluation results.

    Args:
        results: Generation outputs to tabulate.
        push_to_hub: Whether to push the dataset to the Hugging Face Hub.
        hub_name: Hub repo id; required when ``push_to_hub`` is True.
        state_columns: Optional per-example state keys to include as columns.

    Returns:
        A ``Dataset`` with prompt/completion/answer/info/task/reward columns,
        one column per metric, and any requested state columns.

    Raises:
        ValueError: If ``push_to_hub`` is True but ``hub_name`` is None.

    Fix: guard against empty ``results.state`` before indexing ``[0]``,
    which previously raised IndexError for empty result sets.
    """
    state_columns = state_columns or []

    if push_to_hub and hub_name is None:
        raise ValueError("hub_name must be provided if push_to_hub is True")

    cols = ["prompt", "completion", "answer", "info", "task", "reward"]

    results_dict = {
        "prompt": results.prompt,
        # Tool calls inside completions are serialized so they can be stored.
        "completion": [self._sanitize_tool_calls(c) for c in results.completion],
        "answer": results.answer,
        "info": results.info,
        "task": results.task,
        "reward": results.reward,
    }
    # Each metric becomes its own column.
    results_dict.update(results.metrics)
    cols.extend(results.metrics.keys())
    # Guard against empty results before inspecting the first state entry.
    if results.state and results.state[0] is not None:
        for col in state_columns:
            if col in results.state[0]:
                results_dict[col] = [state[col] for state in results.state]
                cols.append(col)
            else:
                self.logger.warning(
                    f"Column {col} not found in state, skipping from dataset."
                )
    dataset = Dataset.from_dict({col: results_dict[col] for col in cols})
    if push_to_hub:
        assert hub_name is not None
        dataset.push_to_hub(hub_name)
    return dataset
0 commit comments