@@ -341,6 +341,71 @@ struct llama_server_context
         return true;
     }
 
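+    // Build a fill-in-the-middle prompt from the request's input_prefix and
+    // input_suffix: [prefix sentinel] prefix [suffix sentinel] suffix [middle sentinel],
+    // so the model generates the missing middle section.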
+    void loadInfill()
+    {
+        auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS
+        auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS
+        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
+        prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
+        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+        prefix_tokens.push_back(llama_token_middle(ctx));
+        auto prompt_tokens = prefix_tokens;
+
+        num_prompt_tokens = prompt_tokens.size();
+
+        if (params.n_keep < 0)
+        {
+            params.n_keep = (int)num_prompt_tokens;
+        }
+        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+        // if input prompt is too big, truncate like normal
+        if (num_prompt_tokens >= (size_t)params.n_ctx)
+        {
+            printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
+            // todo we probably want to cut from both sides
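+            // current approach: keep the first n_keep tokens, drop whole n_left-sized
+            // blocks from the middle, and keep the tail that still fits in the context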
+            const int n_left = (params.n_ctx - params.n_keep) / 2;
+            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
+            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
+            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+            LOG_VERBOSE("input truncated", {
+                {"n_ctx", params.n_ctx},
+                {"n_keep", params.n_keep},
+                {"n_left", n_left},
+                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+            });
+
+            truncated = true;
+            prompt_tokens = new_tokens;
+        }
+        else
+        {
+            const size_t ps = num_prompt_tokens;
+            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
+            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+        }
+
+        // compare the evaluated prompt with the new prompt
+        n_past = common_part(embd, prompt_tokens);
+        printf("n_past: %d\n", n_past);
+        embd = prompt_tokens;
+        if (n_past == num_prompt_tokens)
+        {
+            // we have to evaluate at least 1 token to generate logits.
+            printf("we have to evaluate at least 1 token to generate logits\n");
+            n_past--;
+        }
+
+        LOG_VERBOSE("prompt ingested", {
+            {"n_past", n_past},
+            {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
+            {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
+        });
+
+        has_next_token = true;
+    }
     void loadPrompt()
     {
         auto prompt_tokens = tokenize(prompt, true); // always add BOS
@@ -1199,6 +1264,27 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }
 
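+// Pull the infill-specific fields out of the request body (defaulting to empty
+// strings), then reuse the regular completion option parsing for everything else.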
+static void parse_options_infill(const json &body, llama_server_context &llama)
+{
+    if (body.count("input_prefix") != 0)
+    {
+        llama.params.input_prefix = body["input_prefix"];
+    }
+    else
+    {
+        llama.params.input_prefix = "";
+    }
+    if (body.count("input_suffix") != 0)
+    {
+        llama.params.input_suffix = body["input_suffix"];
+    }
+    else
+    {
+        llama.params.input_suffix = "";
+    }
+    parse_options_completion(body, llama);
+}
+
 static void log_server_request(const Request &req, const Response &res)
 {
     LOG_INFO("request", {
@@ -1498,6 +1584,127 @@ int main(int argc, char **argv)
         res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
     } });
 
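+    // /infill mirrors the /completion streaming flow, but builds the prompt via
+    // loadInfill() from input_prefix/input_suffix and pushes partial results to
+    // the client as "data: ..." server-sent events.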
+    svr.Post("/infill", [&llama](const Request &req, Response &res)
+    {
+        auto lock = llama.lock();
+
+        llama.rewind();
+
+        llama_reset_timings(llama.ctx);
+
+        parse_options_infill(json::parse(req.body), llama);
+
+        if (!llama.loadGrammar())
+        {
+            res.status = 400;
+            return;
+        }
+        llama.loadInfill();
+        llama.beginCompletion();
+        const auto chunked_content_provider = [&](size_t, DataSink &sink) {
+            size_t sent_count = 0;
+            size_t sent_token_probs_index = 0;
+
+            while (llama.has_next_token) {
+                const completion_token_output token_with_probs = llama.doCompletion();
+                if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
+                    continue;
+                }
+                const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
+
+                size_t pos = std::min(sent_count, llama.generated_text.size());
+
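+                // scan the not-yet-sent tail of generated_text for stop strings: a full
+                // match truncates the output, a partial match holds the tail back until
+                // later tokens resolve it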
+                const std::string str_test = llama.generated_text.substr(pos);
+                bool is_stop_full = false;
+                size_t stop_pos =
+                    llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
+                if (stop_pos != std::string::npos) {
+                    is_stop_full = true;
+                    llama.generated_text.erase(
+                        llama.generated_text.begin() + pos + stop_pos,
+                        llama.generated_text.end());
+                    pos = std::min(sent_count, llama.generated_text.size());
+                } else {
+                    is_stop_full = false;
+                    stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
+                                                         STOP_PARTIAL);
+                }
+
+                if (
+                    stop_pos == std::string::npos ||
+                    // Send rest of the text if we are at the end of the generation
+                    (!llama.has_next_token && !is_stop_full && stop_pos > 0)
+                ) {
+                    const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
+
+                    sent_count += to_send.size();
+
+                    std::vector<completion_token_output> probs_output = {};
+
+                    if (llama.params.n_probs > 0) {
+                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+                        if (probs_pos < probs_stop_pos) {
+                            probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                        }
+                        sent_token_probs_index = probs_stop_pos;
+                    }
+
+                    const json data = format_partial_response(llama, to_send, probs_output);
+
+                    const std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
+
+                    LOG_VERBOSE("data stream", {
+                        { "to_send", str }
+                    });
+
+                    if (!sink.write(str.data(), str.size())) {
+                        LOG_VERBOSE("stream closed", {});
+                        llama_print_timings(llama.ctx);
+                        return false;
+                    }
+                }
+
+                if (!llama.has_next_token) {
+                    // Generation is done, send extra information.
+                    const json data = format_final_response(
+                        llama,
+                        "",
+                        std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
+                    );
+
+                    const std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
+
+                    LOG_VERBOSE("data stream", {
+                        { "to_send", str }
+                    });
+
+                    if (!sink.write(str.data(), str.size())) {
+                        LOG_VERBOSE("stream closed", {});
+                        llama_print_timings(llama.ctx);
+                        return false;
+                    }
+                }
+            }
+
+            llama_print_timings(llama.ctx);
+            sink.done();
+            return true;
+        };
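+        // lock.release() drops RAII ownership without unlocking, so the model stays
+        // reserved while the response streams; on_complete unlocks it when the
+        // stream finishes or the client disconnects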
+        const auto on_complete = [&](bool) {
+            llama.mutex.unlock();
+        };
+        lock.release();
+        res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+    });
+
     svr.Get("/model.json", [&llama](const Request &, Response &res)
     {
         const json data = format_generation_settings(llama);