
Commit 9d3514a

vvhg-code-infill (#1)
1 parent 7eb4117 commit 9d3514a

7 files changed: +314 −7 lines


common/common.cpp

Lines changed: 2 additions & 0 deletions
@@ -356,6 +356,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.interactive_first = true;
         } else if (arg == "-ins" || arg == "--instruct") {
             params.instruct = true;
+        } else if (arg == "--infill") {
+            params.infill = true;
         } else if (arg == "--multiline-input") {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
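
With the flag wired up, infill mode in examples/main reuses the existing `--in-prefix` and `--in-suffix` options to supply the code before and after the gap; the main.cpp changes below build the actual infill prompt from them.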

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -118,6 +118,7 @@ struct gpt_params {
     bool numa = false; // attempt optimizations that help on some NUMA systems
     bool export_cgraph = false; // export the computation graph
     bool verbose_prompt = false; // print prompt tokens before generation
+    bool infill = false; // use infill mode
 };

 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

examples/main/main.cpp

Lines changed: 69 additions & 7 deletions
@@ -239,8 +239,15 @@ int main(int argc, char ** argv) {
     LOG("add_bos: %d\n", add_bos);

     std::vector<llama_token> embd_inp;
-
-    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
+    if(params.infill) {
+        std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
+        std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
+        inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+        inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
+        embd_inp = inp_pfx;
+        embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+        embd_inp.push_back(llama_token_middle(ctx));
+    } else if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
         embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
     } else {
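
The new branch assembles a fill-in-the-middle (FIM) prompt: the special prefix token, the tokenized prefix, the special suffix token, the tokenized suffix, and finally the middle token, after which the model generates the missing span. A minimal standalone sketch of the same assembly (the helper name and wrapper are illustrative, not part of the commit; the token helpers are the ones used in the diff above):

// Illustrative helper (not in the commit): builds a FIM prompt of the form
// <prefix marker> prefix <suffix marker> suffix <middle marker>, with
// generation continuing after the middle marker.
static std::vector<llama_token> build_infill_prompt(llama_context * ctx,
                                                    const std::string & prefix,
                                                    const std::string & suffix,
                                                    bool add_bos) {
    std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, prefix, add_bos);
    std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, suffix, add_bos);

    inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx)); // prefix marker
    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx)); // suffix marker

    std::vector<llama_token> out = inp_pfx;                   // marker + prefix
    out.insert(out.end(), inp_sfx.begin(), inp_sfx.end());    // marker + suffix
    out.push_back(llama_token_middle(ctx));                   // model fills the gap from here
    return out;
}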
@@ -709,9 +716,58 @@ int main(int argc, char ** argv) {
                     LOG("found antiprompt: %s\n", last_output.c_str());
                 }
             }
-
+            // deal with eot token in infill mode
+            if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.infill && params.interactive){
+                if(is_interacting && !params.interactive_first) {
+                    // print an eot token
+                    printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
+                }
+                fflush(stdout);
+                printf("\n");
+                console::set_display(console::user_input);
+                std::string buffer;
+                std::string line;
+                bool another_line=true;
+                // set a new prefix via stdin
+                do {
+                    another_line = console::readline(line, params.multiline_input);
+                    buffer += line;
+                } while (another_line);
+                // check if we got an empty line, if so we use the old input
+                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                    params.input_prefix = buffer;
+                }
+                buffer.clear();
+                // set a new suffix via stdin
+                do {
+                    another_line = console::readline(line, params.multiline_input);
+                    buffer += line;
+                } while (another_line);
+                // check if we got an empty line
+                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                    params.input_suffix = buffer;
+                }
+                buffer.clear();
+                // done taking input, reset color
+                console::set_display(console::reset);
+                // tokenize new prefix and suffix
+                std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
+                std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
+                inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+                inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
+                embd_inp = inp_pfx;
+                embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+                embd_inp.push_back(llama_token_middle(ctx));
+                embd.clear();
+                embd_guidance.clear();
+                n_remain = params.n_predict;
+                n_past = 0;
+                n_consumed = 0;
+                // LOG_TEE("took new input\n");
+                is_interacting = false;
+            }
             // deal with end of text token in interactive mode
-            if (last_tokens.back() == llama_token_eos(ctx)) {
+            else if (last_tokens.back() == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");

                 if (params.interactive) {
@@ -731,7 +787,7 @@ int main(int argc, char ** argv) {
                 }
             }

-            if (n_past > 0 && is_interacting) {
+            if (n_past > 0 && is_interacting && !(params.infill && params.interactive)) {
                 LOG("waiting for user input\n");

                 if (params.instruct) {
@@ -825,17 +881,23 @@ int main(int argc, char ** argv) {

         // end of text token
         if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
-            LOG_TEE(" [end of text]\n");
+            if (!params.infill){
+                LOG_TEE(" [end of text]\n");
+            }
             break;
         }

         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
         // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
-        if (params.interactive && n_remain <= 0 && params.n_predict >= 0) {
+        if (params.interactive && n_remain <= 0 && params.n_predict >= 0 ) {
             n_remain = params.n_predict;
             is_interacting = true;
         }
     }
+    if (params.infill && !params.interactive && n_remain <= 0) {
+        printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
+        fflush(stdout);
+    }

     if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
         LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());

examples/server/README.md

Lines changed: 10 additions & 0 deletions
@@ -176,6 +176,16 @@ node index.js

     `content`: Set the text to process.

+**POST** `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as a stream.
+
+    *Options:*
+
+    `input_prefix`: Set the prefix of the code to infill.
+
+    `input_suffix`: Set the suffix of the code to infill.
+
+    It also accepts all the options of `/completion` except `stream` and `prompt`.
+
 ## More examples

 ### Interactive mode
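
For reference, the new endpoint can be exercised like `/completion`. Below is a minimal client sketch using cpp-httplib and nlohmann/json (both already used by the server example); the host, port, and field values are illustrative assumptions, and the response arrives as `text/event-stream` chunks of the form `data: {...}`:

#include "httplib.h"
#include "json.hpp"
#include <cstdio>

int main() {
    // Assumes a server built from this commit is listening locally
    // (e.g. ./server -m <model> --port 8080; port is an assumption).
    httplib::Client cli("localhost", 8080);

    const nlohmann::json body = {
        {"input_prefix", "def add(a, b):\n    "}, // code before the gap
        {"input_suffix", "\n    return c\n"},     // code after the gap
        {"n_predict",    64}                      // any /completion option except stream and prompt
    };

    // This simple call buffers the whole SSE stream and returns it as one body.
    auto res = cli.Post("/infill", body.dump(), "application/json");
    if (res && res->status == 200) {
        printf("%s\n", res->body.c_str());
    }
    return 0;
}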

examples/server/server.cpp

Lines changed: 207 additions & 0 deletions
@@ -341,6 +341,71 @@ struct llama_server_context
         return true;
     }

+    void loadInfill()
+    {
+        auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS
+        auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS
+        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
+        prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
+        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+        prefix_tokens.push_back(llama_token_middle(ctx));
+        auto prompt_tokens = prefix_tokens;
+
+        num_prompt_tokens = prompt_tokens.size();
+
+        if (params.n_keep < 0)
+        {
+            params.n_keep = (int)num_prompt_tokens;
+        }
+        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+        // if input prompt is too big, truncate like normal
+        if (num_prompt_tokens >= (size_t)params.n_ctx)
+        {
+            printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
+            // todo we probably want to cut from both sides
+            const int n_left = (params.n_ctx - params.n_keep) / 2;
+            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
+            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
+            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+            LOG_VERBOSE("input truncated", {
+                {"n_ctx", params.n_ctx},
+                {"n_keep", params.n_keep},
+                {"n_left", n_left},
+                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+            });
+
+            truncated = true;
+            prompt_tokens = new_tokens;
+        }
+        else
+        {
+            const size_t ps = num_prompt_tokens;
+            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
+            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+        }
+
+        // compare the evaluated prompt with the new prompt
+        n_past = common_part(embd, prompt_tokens);
+        printf("n_past: %d\n", n_past);
+        embd = prompt_tokens;
+        if (n_past == num_prompt_tokens)
+        {
+            // we have to evaluate at least 1 token to generate logits.
+            printf("we have to evaluate at least 1 token to generate logits\n");
+            n_past--;
+        }
+
+        LOG_VERBOSE("prompt ingested", {
+            {"n_past", n_past},
+            {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
+            {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
+        });
+
+        has_next_token = true;
+    }
     void loadPrompt()
     {
         auto prompt_tokens = tokenize(prompt, true); // always add BOS
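
The truncation branch in loadInfill above keeps the first `n_keep` tokens and then drops whole `n_left`-sized blocks from the middle so that a recent tail still fits. As an illustrative example (numbers not from the commit): with `n_ctx = 2048`, `n_keep = 256` and a 3000-token infill prompt, `n_left = (2048 - 256) / 2 = 896` and `erased_blocks = (3000 - 256 - 896 - 1) / 896 = 2`, so the rebuilt prompt keeps tokens [0, 256) plus tokens [2048, 3000), i.e. 1208 of the original 3000.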
@@ -1199,6 +1264,27 @@ static void parse_options_completion(const json &body, llama_server_context &llama)
    LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}

+static void parse_options_infill(const json &body, llama_server_context &llama)
+{
+    if (body.count("input_prefix") != 0)
+    {
+        llama.params.input_prefix = body["input_prefix"];
+    }
+    else
+    {
+        llama.params.input_prefix = "";
+    }
+    if (body.count("input_suffix") != 0)
+    {
+        llama.params.input_suffix = body["input_suffix"];
+    }
+    else
+    {
+        llama.params.input_suffix = "";
+    }
+    parse_options_completion(body, llama);
+}
+
static void log_server_request(const Request &req, const Response &res)
{
    LOG_INFO("request", {
@@ -1498,6 +1584,127 @@ int main(int argc, char **argv)
         res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
     } });

+    svr.Post("/infill", [&llama](const Request &req, Response &res)
+            {
+        auto lock = llama.lock();
+
+        llama.rewind();
+
+        llama_reset_timings(llama.ctx);
+
+        parse_options_infill(json::parse(req.body), llama);
+
+        if (!llama.loadGrammar())
+        {
+            res.status = 400;
+            return;
+        }
+        llama.loadInfill();
+        llama.beginCompletion();
+        const auto chunked_content_provider = [&](size_t, DataSink & sink) {
+            size_t sent_count = 0;
+            size_t sent_token_probs_index = 0;
+
+            while (llama.has_next_token) {
+                const completion_token_output token_with_probs = llama.doCompletion();
+                if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
+                    continue;
+                }
+                const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
+
+                size_t pos = std::min(sent_count, llama.generated_text.size());
+
+                const std::string str_test = llama.generated_text.substr(pos);
+                bool is_stop_full = false;
+                size_t stop_pos =
+                    llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
+                if (stop_pos != std::string::npos) {
+                    is_stop_full = true;
+                    llama.generated_text.erase(
+                        llama.generated_text.begin() + pos + stop_pos,
+                        llama.generated_text.end());
+                    pos = std::min(sent_count, llama.generated_text.size());
+                } else {
+                    is_stop_full = false;
+                    stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
+                        STOP_PARTIAL);
+                }
+
+                if (
+                    stop_pos == std::string::npos ||
+                    // Send rest of the text if we are at the end of the generation
+                    (!llama.has_next_token && !is_stop_full && stop_pos > 0)
+                ) {
+                    const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
+
+                    sent_count += to_send.size();
+
+                    std::vector<completion_token_output> probs_output = {};
+
+                    if (llama.params.n_probs > 0) {
+                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+                        if (probs_pos < probs_stop_pos) {
+                            probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                        }
+                        sent_token_probs_index = probs_stop_pos;
+                    }
+
+                    const json data = format_partial_response(llama, to_send, probs_output);
+
+                    const std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
+
+                    LOG_VERBOSE("data stream", {
+                        { "to_send", str }
+                    });
+
+                    if (!sink.write(str.data(), str.size())) {
+                        LOG_VERBOSE("stream closed", {});
+                        llama_print_timings(llama.ctx);
+                        return false;
+                    }
+                }
+
+                if (!llama.has_next_token) {
+                    // Generation is done, send extra information.
+                    const json data = format_final_response(
+                        llama,
+                        "",
+                        std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
+                    );
+
+                    const std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
+
+                    LOG_VERBOSE("data stream", {
+                        { "to_send", str }
+                    });
+
+                    if (!sink.write(str.data(), str.size())) {
+                        LOG_VERBOSE("stream closed", {});
+                        llama_print_timings(llama.ctx);
+                        return false;
+                    }
+                }
+            }
+
+            llama_print_timings(llama.ctx);
+            sink.done();
+            return true;
+        };
+        const auto on_complete = [&](bool) {
+            llama.mutex.unlock();
+        };
+        lock.release();
+        res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+    });
+
     svr.Get("/model.json", [&llama](const Request &, Response &res)
             {
         const json data = format_generation_settings(llama);