@@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
672672 if (params.logdir .back () != DIRECTORY_SEPARATOR) {
673673 params.logdir += DIRECTORY_SEPARATOR;
674674 }
675+ } else if (arg == " --save-all-logits" || arg == " --kl-divergence-base" ) {
676+ if (++i >= argc) {
677+ invalid_param = true ;
678+ break ;
679+ }
680+ params.logits_file = argv[i];
675681 } else if (arg == " --perplexity" || arg == " --all-logits" ) {
676682 params.logits_all = true ;
677683 } else if (arg == " --ppl-stride" ) {
@@ -716,6 +722,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
716722 break ;
717723 }
718724 params.multiple_choice_tasks = std::stoi (argv[i]);
725+ } else if (arg == " --kl-divergence" ) {
726+ params.kl_divergence = true ;
719727 } else if (arg == " --ignore-eos" ) {
720728 params.ignore_eos = true ;
721729 } else if (arg == " --no-penalize-nl" ) {
@@ -967,6 +975,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
967975 printf (" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n " , params.winogrande_tasks );
968976 printf (" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n " );
969977 printf (" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n " , params.multiple_choice_tasks ); // show the multiple-choice default, not the Winogrande one
978+ printf (" --kl-divergence computes KL-divergence to logits provided via --kl-divergence-base\n" ); // newline-terminated so the next option starts on its own line
970979 printf (" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n " , params.n_keep );
971980 printf (" --draft N number of tokens to draft for speculative decoding (default: %d)\n " , params.n_draft );
972981 printf (" --chunks N max number of chunks to process (default: %d, -1 = all)\n " , params.n_chunks );
0 commit comments