@@ -371,8 +371,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
371371
372372 llama_batch_ext_clear (batch.get ());
373373 for (int i = 0 ; i < batch_size; i++) {
374- llama_seq_id seq_id = 0 ;
375- llama_batch_ext_add_text (batch.get (), tokens[batch_start + i], j*n_batch + i, &seq_id, 1 , true );
374+ batch.add_text (tokens[batch_start + i], j*n_batch + i, 0 , true );
376375 }
377376
378377 // LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
@@ -568,7 +567,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
568567 for (int k = 0 ; k < batch_size; ++k) {
569568 const llama_pos pos = j*n_batch + k;
570569 bool output = pos >= first;
571- llama_batch_ext_add_text ( batch.get (), tokens[seq_start + k], pos, & seq, 1 , output);
570+ batch.add_text ( tokens[seq_start + k], pos, seq, output);
572571
573572 n_outputs += output ? 1 : 0 ;
574573 }
@@ -864,7 +863,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
864863
865864 for (size_t i = 0 ; i < hs_cur.common_prefix ; ++i) {
866865 std::vector<llama_seq_id> seq_ids = { s0 + 0 , s0 + 1 , s0 + 2 , s0 + 3 };
867- llama_batch_ext_add_text ( batch.get (), hs_cur.seq_tokens [0 ][i], i, seq_ids. data (), seq_ids. size () , false );
866+ batch.add_text ( hs_cur.seq_tokens [0 ][i], i, seq_ids, false );
868867 }
869868 llama_batch_ext_set_output_last (batch.get ());
870869 n_logits += 1 ;
@@ -875,7 +874,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
875874 for (size_t i = hs_cur.common_prefix ; i < seq_tokens_size; ++i) {
876875 const bool needs_logits = i < seq_tokens_size - 1 ;
877876 llama_seq_id seq_id = s0 + s;
878- llama_batch_ext_add_text ( batch.get (), hs_cur.seq_tokens [s][i], i, & seq_id, 1 , needs_logits);
877+ batch.add_text ( hs_cur.seq_tokens [s][i], i, seq_id, needs_logits);
879878 n_logits += needs_logits;
880879 }
881880 }
@@ -1143,16 +1142,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
11431142
11441143 for (size_t i = 0 ; i < data[i1].common_prefix ; ++i) {
11451144 std::vector<llama_seq_id> seq_ids{ s0 + 0 , s0 + 1 };
1146- llama_batch_ext_add_text ( batch.get (), data[i1].seq_tokens [0 ][i], i, seq_ids. data (), seq_ids. size () , false );
1145+ batch.add_text ( data[i1].seq_tokens [0 ][i], i, seq_ids, false );
11471146 }
11481147 llama_batch_ext_set_output_last (batch.get ());
11491148 n_logits += 1 ;
11501149
11511150 for (int s = 0 ; s < 2 ; ++s) {
11521151 // TODO: end before the last token, no need to predict past the end of the sequences
11531152 for (size_t i = data[i1].common_prefix ; i < data[i1].seq_tokens [s].size (); ++i) {
1154- llama_seq_id seq_id = s0 + s;
1155- llama_batch_ext_add_text (batch.get (), data[i1].seq_tokens [s][i], i, &seq_id, 1 , true );
1153+ batch.add_text (data[i1].seq_tokens [s][i], i, s0 + s, true );
11561154 n_logits += 1 ;
11571155 }
11581156 }
@@ -1511,7 +1509,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
15111509 }
15121510
15131511 for (size_t i = 0 ; i < cur_task.common_prefix ; ++i) {
1514- llama_batch_ext_add_text ( batch.get (), cur_task.seq_tokens [0 ][i], i, batch_indeces. data (), batch_indeces. size () , false );
1512+ batch.add_text ( cur_task.seq_tokens [0 ][i], i, batch_indeces, false );
15151513 }
15161514 llama_batch_ext_set_output_last (batch.get ()); // we need logits for the last token of the common prefix
15171515 n_logits += 1 ;
@@ -1521,8 +1519,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
15211519 // TODO: don't evaluate the last token of each sequence
15221520 for (size_t i = cur_task.common_prefix ; i < seq_tokens_size; ++i) {
15231521 const bool needs_logits = i < seq_tokens_size - 1 ;
1524- llama_seq_id seq_id = { s0 + s };
1525- llama_batch_ext_add_text (batch.get (), cur_task.seq_tokens [s][i], i, &seq_id, 1 , needs_logits);
1522+ batch.add_text (cur_task.seq_tokens [s][i], i, s0 + s, needs_logits);
15261523 n_logits += needs_logits;
15271524 }
15281525 }
@@ -1749,8 +1746,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
17491746
17501747 llama_batch_ext_clear (batch.get ());
17511748 for (int i = 0 ; i < batch_size; i++) {
1752- llama_seq_id seq_id = 0 ;
1753- llama_batch_ext_add_text (batch.get (), tokens[batch_start + i], j*n_batch + i, &seq_id, 1 , true );
1749+ batch.add_text (tokens[batch_start + i], j*n_batch + i, 0 , true );
17541750 }
17551751
17561752 if (llama_decode_ext (ctx, batch.get ())) {
0 commit comments