@@ -149,6 +149,7 @@ struct task_server {
     task_type type;
     json data;
     bool infill_mode = false;
+    bool embedding_mode = false;
 };
 
 struct task_result {
@@ -371,6 +372,7 @@ struct llama_client_slot
     std::vector<completion_token_output> generated_token_probs;
 
     bool infill = false;
+    bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1244,13 +1246,14 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill)
+    int request_completion(json data, bool infill, bool embedding)
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.data = data;
         task.infill_mode = infill;
+        task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
         queue_tasks.push_back(task);
         return task.id;
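
request_completion() only enqueues work: it pushes a task under mutex_tasks and hands the caller an id, and the HTTP handlers later block on next_result(task_id) until the processing loop publishes a result carrying that id into queue_results. A minimal, self-contained sketch of that id-based hand-off, using illustrative demo_* names rather than the server's actual types, could look like this:

// Illustrative sketch only (demo_* names are not the server's): a producer
// enqueues a task and gets an id, a worker publishes a result tagged with the
// same id, and the caller polls for it. This is the shape that
// request_completion() and next_result() rely on.
#include <cstdio>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

struct demo_task   { int id; bool embedding_mode; std::string prompt; };
struct demo_result { int id; std::string json; };

static std::mutex               mutex_tasks;
static std::vector<demo_task>   queue_tasks;
static std::vector<demo_result> queue_results;
static int id_gen = 0;

// producer side, analogous to request_completion(): enqueue and return an id
static int request_demo(std::string prompt, bool embedding)
{
    std::lock_guard<std::mutex> lock(mutex_tasks);
    demo_task task{id_gen++, embedding, std::move(prompt)};
    queue_tasks.push_back(task);
    return task.id;
}

// consumer side, analogous to next_result(): poll until the matching id shows up
static demo_result next_demo_result(int task_id)
{
    for (;;)
    {
        std::lock_guard<std::mutex> lock(mutex_tasks);
        for (size_t i = 0; i < queue_results.size(); i++)
        {
            if (queue_results[i].id != task_id) { continue; }
            demo_result res = queue_results[i];
            queue_results.erase(queue_results.begin() + i);
            return res;
        }
    }
}

int main()
{
    const int task_id = request_demo("hello", /*embedding=*/true);

    // worker, analogous to the server's processing loop: drain the task queue
    // and publish a result carrying the same id
    std::thread worker([] {
        std::lock_guard<std::mutex> lock(mutex_tasks);
        demo_task task = queue_tasks.front();
        queue_tasks.erase(queue_tasks.begin());
        queue_results.push_back({task.id, task.embedding_mode ? "{\"embedding\": []}" : "{\"content\": \"...\"}"});
    });

    demo_result res = next_demo_result(task_id);
    std::printf("task %d -> %s\n", res.id, res.json.c_str());
    worker.join();
    return 0;
}
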
@@ -1376,7 +1379,7 @@ struct llama_server_context
                 {
                     LOG_TEE("slot unavailable\n");
                     // send error result
-                    send_error(task.id, "slot unavaliable");
+                    send_error(task.id, "slot unavailable");
                     return;
                 }
 
@@ -1388,6 +1391,7 @@ struct llama_server_context
                 slot->reset();
 
                 slot->infill = task.infill_mode;
+                slot->embedding = task.embedding_mode;
                 slot->task_id = task.id;
 
                 if (!launch_slot_with_data(slot, task.data))
@@ -1695,7 +1699,7 @@ struct llama_server_context
                 }
 
                 // prompt evaluated for embedding
-                if (params.embedding)
+                if (slot.embedding)
                 {
                     send_embedding(slot);
                     slot.release();
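
With the check moved from params.embedding to slot.embedding, only slots whose task came in through the embedding path take this branch, and send_embedding(slot) is then expected to read the embeddings back out of the llama context once the prompt batch has been decoded. A hedged sketch of what such a helper can do with the llama.cpp C API (llama_get_embeddings, llama_n_embd); the function name and wiring are illustrative, not the server's actual implementation:

// Hedged sketch, not the server's send_embedding(): package the context's
// embeddings into a JSON payload after the prompt has been evaluated.
#include "llama.h"
#include "json.hpp"
#include <vector>

using json = nlohmann::json;

static json embedding_result_for(llama_context * ctx, const llama_model * model)
{
    const int     n_embd = llama_n_embd(model);
    const float * data   = llama_get_embeddings(ctx); // null if embeddings were not enabled for this context
    if (data == nullptr)
    {
        // keep the response shape stable even when nothing is available
        return json{{"embedding", std::vector<float>(n_embd, 0.0f)}};
    }
    return json{{"embedding", std::vector<float>(data, data + n_embd)}};
}
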
@@ -2274,7 +2278,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false);
+                const int task_id = llama.request_completion(data, false, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2329,7 +2333,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true);
+                const int task_id = llama.request_completion(data, true, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2433,7 +2437,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0 } }, false);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0 } }, false, true);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });
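
Taken together, the three call sites map the endpoints onto the new pair of flags: /completion passes (false, false), /infill passes (true, false), and /embedding passes (false, true) with n_predict pinned to 0 so the prompt is only evaluated. A hedged client-side example against a locally running server, using cpp-httplib and nlohmann::json (host, port and prompts are placeholders):

// Hedged usage sketch: exercise the completion and embedding endpoints of a
// running server instance. Host, port and prompt values are placeholders.
#include "httplib.h"
#include "json.hpp"
#include <cstdio>

using json = nlohmann::json;

int main()
{
    httplib::Client cli("localhost", 8080);

    // regular completion: reaches request_completion(data, false, false) server-side
    auto completion = cli.Post("/completion", json{{"prompt", "Hello"}, {"n_predict", 16}}.dump(), "application/json");

    // embedding: reaches request_completion({prompt, n_predict = 0}, false, true) server-side
    auto embedding = cli.Post("/embedding", json{{"content", "Hello"}}.dump(), "application/json");

    if (completion && embedding)
    {
        std::printf("completion: %s\n", completion->body.c_str());
        std::printf("embedding has %zu values\n", json::parse(embedding->body)["embedding"].size());
    }
    return 0;
}
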