@@ -27,20 +27,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
 
     int count = 0;
     int seq_count = tokens.size() / params.n_ctx;
+    int n_vocab = llama_n_vocab(ctx);
 
     double nll = 0.0;
-
-    fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
+    fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
 
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
-        int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
-                                            // it is better to always be power of 2 for better performance
-        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
+        int end = start + params.n_ctx;
+
+        std::vector<float> logits;
+        int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
         auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
+        for (int j = 0; j < num_batches; ++j) {
+            int batch_start = start + j * params.n_batch;
+            int batch_size = std::min(end - batch_start, params.n_batch);
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+            auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
        }
        auto end_t = std::chrono::high_resolution_clock::now();
        if (i == 0) {
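For orientation, the new inner loop splits each n_ctx-token chunk into ceil(n_ctx / n_batch) calls to llama_eval and concatenates the logits from each call. A minimal standalone sketch of just that index arithmetic (the n_ctx and n_batch values below are arbitrary illustrations, not values taken from this change):

```cpp
// Sketch of the batch decomposition used by the loop above: one chunk of
// n_ctx tokens is fed in pieces of at most n_batch tokens, with n_past
// advancing by n_batch each call. Values are illustrative only.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx   = 512; // tokens per perplexity chunk (example value)
    const int n_batch = 128; // tokens per eval call (example value)

    const int num_batches = (n_ctx + n_batch - 1) / n_batch; // ceiling division

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = j * n_batch;                            // offset within the chunk
        const int batch_size  = std::min(n_ctx - batch_start, n_batch); // last batch may be short
        const int n_past      = j * n_batch;                            // tokens already evaluated
        printf("batch %d: start=%d size=%d n_past=%d\n", j, batch_start, batch_size, n_past);
    }
    return 0;
}
```

With n_ctx = 512 and n_batch = 128 this prints four batches of 128 tokens; with n_batch = 512 it reduces to a single call.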
@@ -59,15 +66,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // Example, we have a context window of 512, we will compute perplexity for each of the
         // last 256 tokens. Then, we split the input up into context window size chunks to
         // process the entire prompt.
-
-        auto logits = llama_get_logits(ctx);
-        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+        for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
            // Calculate probability of next token, given the previous ones.
-           int n_vocab = llama_n_vocab(ctx);
            std::vector<float> tok_logits(
-               logits + j * n_vocab,
-               logits + (j + 1) * n_vocab);
-           const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+               logits.begin() + j * n_vocab,
+               logits.begin() + (j + 1) * n_vocab);
+           float prob = softmax(tok_logits)[tokens[start + j + 1]];
            nll += -std::log(prob);
            ++count;
        }
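For context, the scoring loop above turns each position's logits into probabilities with the file's softmax helper and accumulates negative log-likelihood; the reported perplexity is then exp(nll / count). A toy, self-contained sketch of that accumulation, using a made-up four-token vocabulary and a simple stand-in softmax (not the helper defined in this file):

```cpp
// Toy illustration of the NLL accumulation done in the loop above:
// softmax the logits at position j, take -log of the probability the
// model assigned to the token that actually followed, and average.
// Perplexity is exp(nll / count). All values here are made up.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> softmax(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    float max_logit = logits[0];
    for (float v : logits) max_logit = std::max(max_logit, v);
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit); // subtract max for numerical stability
        sum += probs[i];
    }
    for (float & p : probs) p /= sum;
    return probs;
}

int main() {
    // Pretend vocab of 4 tokens: logits for two positions, plus the token
    // that actually followed each position.
    std::vector<std::vector<float>> logits = {{2.0f, 0.5f, -1.0f, 0.0f},
                                              {0.1f, 3.0f,  0.2f, 0.3f}};
    std::vector<int> next_token = {0, 1};

    double nll = 0.0;
    int count = 0;
    for (size_t j = 0; j < logits.size(); ++j) {
        float prob = softmax(logits[j])[next_token[j]];
        nll += -std::log(prob);
        ++count;
    }
    printf("perplexity = %f\n", std::exp(nll / count));
    return 0;
}
```

Subtracting the maximum logit before exponentiating keeps the exponentials from overflowing; whether the helper in this file does the same is not shown in this hunk.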
@@ -82,11 +86,13 @@ int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
 
+    params.n_batch = 512;
     if (gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }
 
    params.perplexity = true;
+   params.n_batch = std::min(params.n_batch, params.n_ctx);
 
    if (params.n_ctx > 2048) {
        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
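Finally, note that the default of 512 is assigned before gpt_params_parse, so a command-line value can still override it, and the clamp runs after parsing. A trivial sketch of the clamp's effect (toy values; it simply keeps the per-call batch from exceeding the context size):

```cpp
// Effect of the clamp added in main(): a batch size larger than the
// context is reduced to n_ctx. Toy values, for illustration only.
#include <algorithm>
#include <cstdio>

int main() {
    int n_ctx   = 256; // e.g. a small context chosen by the user
    int n_batch = 512; // default set before argument parsing
    n_batch = std::min(n_batch, n_ctx);
    printf("effective n_batch = %d\n", n_batch); // prints 256
    return 0;
}
```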