@@ -125,7 +125,7 @@ enum slot_command {
 struct slot_params {
     bool stream = true;
     uint32_t seed = -1; // RNG seed
-    int n_keep = 0; // RNG seed
+    int n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_predict = -1; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     bool cache_prompt = false; // remember a the prompt to avoid reprocessing all prompt
@@ -262,6 +262,34 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
     return out;
 }
 
+struct llama_sampling_context * llama_sampling_init_srv(const struct llama_sampling_params sparams, std::string grammar, int n_ctx) {
+    struct llama_sampling_context * result = new llama_sampling_context();
+
+    result->params = sparams;
+    result->grammar = nullptr;
+
+    // if there is a grammar, parse it
+    if (!grammar.empty()) {
+        result->parsed_grammar = grammar_parser::parse(grammar.c_str());
+
+        // will be empty (default) if there are parse errors
+        if (result->parsed_grammar.rules.empty()) {
+            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+            return nullptr;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
+
+        result->grammar = llama_grammar_init(
+                grammar_rules.data(),
+                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    result->prev.resize(n_ctx);
+
+    return result;
+}
+
 struct slot_image {
     clip_image_u8 img_data;
     bool request_encode_image = false;
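
Note: the new llama_sampling_init_srv() helper bundles grammar parsing and per-slot sampling state into one call. A rough usage sketch, using only names from this diff (the error handling shown here is illustrative, not part of the patch):

    // create the per-slot sampling context; returns nullptr if the grammar fails to parse
    llama_sampling_context * ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
    if (ctx_sampling == nullptr) {
        // reject the request: the supplied grammar is invalid
    }
    // ... use it while decoding, then release the grammar and the context:
    llama_sampling_free(ctx_sampling);
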
@@ -287,7 +315,6 @@ struct llama_client_slot
     int num_tokens_predicted = 0;
     llama_token sampled;
     std::vector<llama_token> cache_tokens;
-    std::vector<llama_token> last_n_tokens;
     std::vector<completion_token_output> generated_token_probs;
     int sent_tokens = 0;
     slot_state state = IDLE;
@@ -307,13 +334,12 @@ struct llama_client_slot
     double t_token_generation; // ms
 
     struct slot_params params;
+
+    // sampling
     struct llama_sampling_params sparams;
-    llama_sampling_context ctx_sampling;
+    llama_sampling_context* ctx_sampling = nullptr;
     bool has_next_token = true;
-
-    // grammar props
-    grammar_parser::parse_state parsed_grammar;
-    llama_grammar *grammar = nullptr;
+    int max_context_size = 0;
 
     // multimodal
     std::vector<slot_image> images;
@@ -332,47 +358,26 @@ struct llama_client_slot
         infill = false;
         clean_tokens();
 
-        if (grammar != nullptr) {
-            llama_grammar_free(grammar);
-            grammar = nullptr;
-            ctx_sampling.params = sparams;
-            ctx_sampling.grammar = NULL;
+        if (ctx_sampling != nullptr) {
+            llama_sampling_free(ctx_sampling);
         }
 
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+
         for (slot_image img : images) {
             free(img.image_embedding);
             delete[] img.img_data.data;
             img.prefix_prompt = "";
         }
+
         images.clear();
         // llama_set_rng_seed(ctx, params.seed); in batched the seed matter???????
     }
 
     bool loadGrammar(llama_token eos)
     {
-        if (!params.grammar.empty()) {
-            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-            // will be empty (default) if there are parse errors
-            if (parsed_grammar.rules.empty()) {
-                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
-                return false;
-            }
-            grammar_parser::print_grammar(stderr, parsed_grammar);
-
-            {
-                auto it = sparams.logit_bias.find(eos);
-                if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
-                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
-                }
-            }
-
-            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-            grammar = llama_grammar_init(
-                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-        }
-        ctx_sampling.params = sparams;
-        ctx_sampling.grammar = grammar;
-        return true;
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+        return ctx_sampling != nullptr;
     }
 
     bool hasBudget(gpt_params &global_params) {
@@ -448,7 +453,6 @@ struct llama_server_context
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
     llama_batch batch;
-    std::vector<llama_token_data> candidates;
     bool all_slots_are_idle = false;
     gpt_params params;
     int n_ctx;
@@ -468,11 +472,6 @@ struct llama_server_context
             llama_free_model(model);
             model = nullptr;
         }
-        for (auto &slot : slots) {
-            if (slot.grammar) {
-                llama_grammar_free(slot.grammar);
-            }
-        }
     }
 
     bool loadModel(const gpt_params &params_)
@@ -510,7 +509,6 @@ struct llama_server_context
         }
         n_ctx = llama_n_ctx(ctx);
         n_vocab = llama_n_vocab(model);
-        candidates.reserve(n_vocab);
         return true;
     }
 
@@ -529,13 +527,12 @@ struct llama_server_context
         {
             llama_client_slot slot;
             slot.id = i;
-            slot.last_n_tokens.resize(max_ctx_per_slot);
-            std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+            slot.max_context_size = max_ctx_per_slot;
             slot.reset();
             LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
             slots.push_back(slot);
         }
-        batch = llama_batch_init(n_ctx, 0);
+        batch = llama_batch_init(n_ctx, 0, 1);
         // empty system prompt
         system_prompt = "";
         num_tokens_system = 0;
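
Note: llama_batch_init() gained a third argument (the maximum number of sequence ids per token). Its signature at this revision is assumed to be:

    // n_tokens: batch capacity, embd: 0 for token batches, n_seq_max: max seq ids per token
    struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max);
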
@@ -626,10 +623,7 @@ struct llama_server_context
 
         for (int32_t i = 0; i < batch.n_tokens; ++i)
         {
-            batch.token[i] = tokens_system[i];
-            batch.pos[i] = i;
-            batch.seq_id[i] = 0;
-            batch.logits[i] = false;
+            llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
         }
 
         if (llama_decode(ctx, batch) != 0)
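
Note: the manual batch.token/pos/seq_id/logits writes are replaced by the common llama_batch_add() helper. Roughly, it does the following (a sketch for readability; see common/common.cpp for the actual implementation):

    static void batch_add_sketch(llama_batch & batch, llama_token id, llama_pos pos,
                                 const std::vector<llama_seq_id> & seq_ids, bool logits) {
        batch.token   [batch.n_tokens] = id;
        batch.pos     [batch.n_tokens] = pos;
        batch.n_seq_id[batch.n_tokens] = (int32_t) seq_ids.size();
        for (size_t s = 0; s < seq_ids.size(); ++s) {
            batch.seq_id[batch.n_tokens][s] = seq_ids[s];
        }
        batch.logits  [batch.n_tokens] = logits;
        batch.n_tokens += 1;
    }
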
@@ -726,8 +720,6 @@ struct llama_server_context
 
     bool processToken(completion_token_output & result, llama_client_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        slot.last_n_tokens.erase(slot.last_n_tokens.begin());
-        slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
 
@@ -859,11 +851,12 @@ struct llama_server_context
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
             llama_batch batch_view = {
                 n_tokens,
-                batch.token + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
             if (llama_decode(ctx, batch_view)) {
@@ -878,8 +871,8 @@ struct llama_server_context
                 if (n_eval > n_batch) {
                     n_eval = n_batch;
                 }
-                llama_batch batch = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
-                if (llama_decode(ctx, batch)) {
+                llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
+                if (llama_decode(ctx, batch_img)) {
                     LOG_TEE("%s : failed to eval image\n", __func__);
                     return false;
                 }
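
Note: the aggregate initializers above depend on the llama_batch field order at this revision of llama.h, which now includes the n_seq_id array and a per-token seq_id pointer (worth re-checking against the header if it has changed since):

    typedef struct llama_batch {
        int32_t      n_tokens;
        llama_token  * token;
        float        * embd;
        llama_pos    * pos;
        int32_t      * n_seq_id;
        llama_seq_id ** seq_id;
        int8_t       * logits;
        // legacy fields, used when the arrays above are NULL
        llama_pos    all_pos_0;
        llama_pos    all_pos_1;
        llama_seq_id all_seq_id;
    } llama_batch;
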
@@ -894,10 +887,7 @@ struct llama_server_context
                 (json)(slot.images[image_idx].prefix_prompt);
             std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
             for (int i = 0; i < append_tokens.size(); ++i) {
-                batch.token[batch.n_tokens] = append_tokens[i];
-                batch.pos[batch.n_tokens] = slot.n_past;
-                batch.seq_id[batch.n_tokens] = slot.id;
-                batch.logits[batch.n_tokens] = false;
+                llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
                 slot.n_past += 1;
                 batch.n_tokens += 1;
             }
@@ -922,7 +912,6 @@ struct llama_server_context
             std::this_thread::sleep_for(std::chrono::milliseconds(5));
         }
 
-        // context shift takes effect only when there is a single slot
         for (llama_client_slot &slot : slots) {
             if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
             {
@@ -976,16 +965,12 @@ struct llama_server_context
                 continue;
             }
 
-            batch.token[batch.n_tokens] = slot.sampled;
-            batch.pos[batch.n_tokens] = num_tokens_system + slot.n_past;
-            batch.seq_id[batch.n_tokens] = slot.id;
-            batch.logits[batch.n_tokens] = true;
+            slot.i_batch = batch.n_tokens;
+
+            llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);
 
             slot.n_decoded += 1;
-            slot.i_batch = batch.n_tokens;
             slot.n_past += 1;
-
-            batch.n_tokens += 1;
         }
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;
@@ -1026,7 +1011,7 @@ struct llama_server_context
             slot.num_prompt_tokens = prompt_tokens.size();
 
             if (!slot.params.cache_prompt) {
-                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
                 slot.n_past = 0;
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             } else {
@@ -1038,23 +1023,27 @@ struct llama_server_context
                 // if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
                 {
+                    // applied bug of #3661
                     const int n_left = max_ctx_per_slot - slot.params.n_keep;
+                    const int n_block_size = n_left / 2;
+                    const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
                     std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
                     // Use half the left-over space in the context for the prompt
-                    new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
+                    new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
                     LOG_VERBOSE("input truncated", {
-                        {"n_ctx", n_ctx},
-                        {"n_keep", params.n_keep},
+                        {"n_ctx", max_ctx_per_slot},
+                        {"n_keep", slot.params.n_keep},
                         {"n_left", n_left},
                         {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
                     slot.num_prompt_tokens = prompt_tokens.size();
+                    GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot);
                 }
                 const size_t ps = slot.num_prompt_tokens;
-                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
-                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
+                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
                 slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
                 LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
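
Note: truncation now drops whole blocks of n_block_size tokens after the n_keep prefix instead of simply keeping the last n_left/2 tokens. A worked example with hypothetical numbers:

    // max_ctx_per_slot = 512, slot.params.n_keep = 64, prompt of 1000 tokens
    // n_left        = 512 - 64                 = 448
    // n_block_size  = 448 / 2                  = 224
    // erased_blocks = (1000 - 64 - 224) / 224  = 3   (integer division)
    // the kept suffix starts at token 64 + 3*224 = 736, so the truncated prompt has
    // 64 + (1000 - 736) = 328 tokens, which satisfies the GGML_ASSERT (328 < 512)
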
@@ -1081,11 +1070,7 @@ struct llama_server_context
             // process the prefix of first image
             std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
             for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
-                batch.token[batch.n_tokens] = prefix_tokens[slot.n_past];
-                batch.pos[batch.n_tokens] = slot.n_past + num_tokens_system;
-                batch.seq_id[batch.n_tokens] = slot.id;
-                batch.logits[batch.n_tokens] = false;
-                batch.n_tokens += 1;
+                llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false);
             }
 
             if (ingest_images && !ingestImages(slot, n_batch)) {
@@ -1113,11 +1098,12 @@ struct llama_server_context
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
             llama_batch batch_view = {
                 n_tokens,
-                batch.token + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
 
@@ -1150,25 +1136,27 @@ struct llama_server_context
             }
 
             completion_token_output result;
-            const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i);
+            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+
+            llama_sampling_accept(slot.ctx_sampling, ctx, id);
 
             if (slot.n_decoded == 1) {
                 slot.t_start_genereration = ggml_time_us();
                 slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
             }
 
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
             result.tok = id;
             const int32_t n_probs = slot.sparams.n_probs;
             if (slot.sparams.temp <= 0 && n_probs > 0)
             {
                 // For llama_sample_token_greedy we need to sort candidates
-                llama_sample_softmax(ctx, &candidates_p);
+                llama_sample_softmax(ctx, &cur_p);
             }
 
-            for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
+            for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
             {
-                result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+                result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
             }
 
             if (!processToken(result, slot)) {
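
Note: sampling now goes through the per-slot llama_sampling_context. The two calls used above are assumed to have the common/sampling.h signatures of this revision:

    // pick a token from the logits at index idx of the last decoded batch
    llama_token llama_sampling_sample(llama_sampling_context * ctx_sampling,
                                      llama_context * ctx_main,
                                      llama_context * ctx_cfg,
                                      int idx);
    // record the token in ctx_sampling->prev (repetition penalties) and advance the grammar state
    void llama_sampling_accept(llama_sampling_context * ctx_sampling,
                               llama_context * ctx_main,
                               llama_token id);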