@@ -149,6 +149,7 @@ struct task_server {
     task_type type;
     json data;
     bool infill_mode = false;
+    bool embedding_mode = false;
 };
 
 struct task_result {
@@ -371,6 +372,7 @@ struct llama_client_slot
     std::vector<completion_token_output> generated_token_probs;
 
     bool infill = false;
+    bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1244,13 +1246,14 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill)
+    int request_completion(json data, bool infill, bool embedding)
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.data = data;
         task.infill_mode = infill;
+        task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
         queue_tasks.push_back(task);
         return task.id;
@@ -1376,7 +1379,7 @@ struct llama_server_context
             {
                 LOG_TEE("slot unavailable\n");
                 // send error result
-                send_error(task.id, "slot unavaliable");
+                send_error(task.id, "slot unavailable");
                 return;
             }
 
@@ -1388,6 +1391,7 @@ struct llama_server_context
             slot->reset();
 
             slot->infill = task.infill_mode;
+            slot->embedding = task.embedding_mode;
             slot->task_id = task.id;
 
             if (!launch_slot_with_data(slot, task.data))
@@ -1695,7 +1699,7 @@ struct llama_server_context
                 }
 
                 // prompt evaluated for embedding
-                if (params.embedding)
+                if (slot.embedding)
                 {
                     send_embedding(slot);
                     slot.release();
@@ -2277,7 +2281,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false);
+                const int task_id = llama.request_completion(data, false, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2332,7 +2336,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true);
+                const int task_id = llama.request_completion(data, true, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2436,7 +2440,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, {"n_predict", 0} }, false);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, {"n_predict", 0} }, false, true);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });
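
For reference, here is a minimal client-side sketch of exercising the new per-request embedding path, written against the same cpp-httplib library the server example already uses. The host, port, and the "content" field of the request body are assumptions about the surrounding /embedding handler, not something this patch specifies.

// Hypothetical client sketch, not part of the patch. Assumes the server
// listens on localhost:8080 and that the /embedding handler reads the
// prompt from a "content" field (an assumption; the field name is outside
// the hunks shown above).
#include <iostream>
#include <string>
#include "httplib.h"

int main()
{
    httplib::Client cli("localhost", 8080);

    // The handler above forwards the prompt with n_predict = 0 and
    // embedding = true, so the server only computes the prompt embedding.
    const std::string body = R"({"content": "Hello, world"})";

    auto res = cli.Post("/embedding", body, "application/json");
    if (res && res->status == 200)
    {
        std::cout << res->body << std::endl; // JSON containing the embedding vector
    }
    return 0;
}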