@@ -92,6 +92,7 @@ enum server_task_type {
 enum server_task_cmpl_type {
     SERVER_TASK_CMPL_TYPE_NORMAL,
     SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_RERANK,
     SERVER_TASK_CMPL_TYPE_INFILL,
 };
 
@@ -172,6 +173,7 @@ struct server_slot {
     std::vector<completion_token_output> generated_token_probs;
 
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
     bool has_next_token = true;
     bool truncated      = false;
     bool stopped_eos    = false;
@@ -954,8 +956,17 @@ struct server_context {
             slot.prompt = *prompt;
         } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
             slot.prompt = prompt->at(0);
+        } else if (prompt->is_array() && prompt->size() > 1) {
+            // array of strings
+            for (const auto & el : *prompt) {
+                if (!el.is_string()) {
+                    send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                    return false;
+                }
+            }
+            slot.prompt = *prompt;
         } else {
-            send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+            send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
             return false;
         }
     }
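For reference, the prompt shapes accepted after this change; a minimal sketch using the nlohmann::json alias that server.cpp already uses, with illustrative values:

    json p1 = "lorem ipsum";                          // plain string
    json p2 = json::array({1, 2, 3});                 // array of token ids
    json p3 = json::array({json::array({1, 2, 3})});  // single nested array of token ids
    json p4 = json::array({"query", "doc0", "doc1"}); // array of strings (new, used by rerank)
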
@@ -1389,6 +1400,7 @@ struct server_context {
 
             res.data = json {
                 {"embedding", std::vector<float>(n_embd, 0.0f)},
+                {"index",     slot.index},
             };
 
             continue;
@@ -1407,6 +1419,44 @@ struct server_context {
         queue_results.send(res);
     }
 
+    void send_rank(const server_slot & slot, const llama_batch & batch) {
+        server_task_result res;
+        res.id    = slot.id_task;
+        res.error = false;
+        res.stop  = true;
+
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+                continue;
+            }
+
+            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            if (embd == NULL) {
+                embd = llama_get_embeddings_ith(ctx, i);
+            }
+
+            if (embd == NULL) {
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+
+                res.data = json {
+                    {"index", slot.index},
+                    {"rank",  -1e6},
+                };
+
+                continue;
+            }
+
+            res.data = json {
+                {"index", slot.index},
+                {"rank",  embd[0]},
+            };
+        }
+
+        SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
+
+        queue_results.send(res);
+    }
+
     //
     // Functions to create new task(s) and receive result(s)
     //
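For a reranking model the pooled "embedding" is a single relevance logit, which is why send_rank() reads embd[0]; -1e6 serves as a sentinel score when a slot's embeddings cannot be fetched. A client may want to squash the raw logit into [0, 1]; a minimal sketch under that assumption, not part of this patch:

    // hypothetical client-side helper: map the raw "rank" logit to [0, 1]
    #include <cmath>

    static float rank_to_probability(float rank_logit) {
        return 1.0f / (1.0f + std::exp(-rank_logit)); // sigmoid
    }
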
@@ -1442,13 +1492,23 @@ struct server_context {
     // otherwise, it's a multiple-prompt task, we break it into smaller tasks
     else if (prompt.is_array()) {
         std::vector<json> prompts = prompt;
-        for (size_t i = 0; i < prompts.size(); i++) {
-            const auto & e = prompts[i];
-            if (e.is_string() || json_is_array_of_numbers(e)) {
-                data["index"] = i;
-                create_task(data, true, e);
-            } else {
-                throw std::runtime_error(error_msg);
+        if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+            for (size_t i = 1; i < prompts.size(); i++) {
+                json qd;
+                qd.push_back(prompts[0]);
+                qd.push_back(prompts[i]);
+                data["index"] = i - 1;
+                create_task(data, true, qd);
+            }
+        } else {
+            for (size_t i = 0; i < prompts.size(); i++) {
+                const auto & e = prompts[i];
+                if (e.is_string() || json_is_array_of_numbers(e)) {
+                    data["index"] = i;
+                    create_task(data, true, e);
+                } else {
+                    throw std::runtime_error(error_msg);
+                }
             }
         }
     }
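The rerank branch treats prompts[0] as the query and pairs it with each document that follows; a worked example, assuming the same json type as above, with illustrative values:

    json prompt = json::array({"q", "d0", "d1"});
    // the rerank branch expands this into two tasks:
    //   qd = ["q", "d0"], data["index"] = 0
    //   qd = ["q", "d1"], data["index"] = 1
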
@@ -1492,7 +1552,9 @@ struct server_context {
             return;
         }
 
-        size_t idx = result.data["index"];
+        const size_t idx = result.data["index"];
+        GGML_ASSERT(idx < results.size() && "index out of range");
+
         results[idx] = result;
     }
     result_handler(results);
@@ -1951,6 +2013,29 @@ struct server_context {
                     }
 
                     prompt_tokens = embd_inp;
+                } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                    // require slot.prompt to be an array of 2 strings
+                    if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                        SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                        slot.release();
+                        send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                        continue;
+                    }
+
+                    // prompt: <s>query</s><s>doc</s>
+                    prompt_tokens.clear();
+                    prompt_tokens.push_back(llama_token_bos(model));
+                    {
+                        const auto part = tokenize(slot.prompt[0], false);
+                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                    }
+                    prompt_tokens.push_back(llama_token_eos(model));
+                    prompt_tokens.push_back(llama_token_bos(model));
+                    {
+                        const auto part = tokenize(slot.prompt[1], false);
+                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                    }
+                    prompt_tokens.push_back(llama_token_eos(model));
                 } else {
                     prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                 }
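Schematically, for the query "what is panda?" and the document "hi", the slot's prompt becomes the pair template described by the comment above, where BOS/EOS stand for the model's own <s>/</s> tokens; whether this matches a given reranker's training template is model-dependent:

    // prompt_tokens = [BOS, tokens("what is panda?")..., EOS,
    //                  BOS, tokens("hi")...,             EOS]
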
@@ -1970,7 +2055,7 @@ struct server_context {
                     continue;
                 }
 
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                     // this prompt is too large to process - discard it
                     if (slot.n_prompt_tokens > n_ubatch) {
                         slot.release();
@@ -2048,15 +2133,18 @@ struct server_context {
                     slot.n_prompt_tokens_processed = 0;
                 }
 
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;
                     }
                 }
 
                 // check that we are in the right batch_type, if not defer the slot
-                bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
+                const bool slot_type =
+                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+
                 if (batch_type == -1) {
                     batch_type = slot_type;
                 } else if (batch_type != slot_type) {
@@ -2229,6 +2317,13 @@ struct server_context {
                         continue; // continue loop of slots
                     }
 
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        send_rank(slot, batch_view);
+                        slot.release();
+                        slot.i_batch = -1;
+                        continue; // continue loop of slots
+                    }
+
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
                 } else if (slot.state != SLOT_STATE_GENERATING) {
@@ -3023,6 +3118,81 @@ int main(int argc, char ** argv) {
         res_ok(res, root);
     };
 
+    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        const json body = json::parse(req.body);
+
+        // TODO: implement
+        // int top_n = 1;
+        // if (body.count("top_n") == 1) {
+        //     top_n = body.at("top_n");
+        // } else {
+        //     res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+        //     return;
+        // }
+
+        json query;
+        if (body.count("query") == 1) {
+            query = body.at("query");
+            if (!query.is_string()) {
+                res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        } else {
+            res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        json documents;
+        if (body.count("documents") != 0) {
+            documents = body.at("documents");
+            if (!documents.is_array() || documents.size() == 0) {
+                res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        } else {
+            res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // construct prompt object: array of ["query", "doc0", "doc1", ...]
+        json prompt;
+        prompt.push_back(query);
+        for (const auto & doc : documents) {
+            prompt.push_back(doc);
+        }
+
+        LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+        // create and queue the task
+        json responses = json::array();
+        bool error = false;
+        {
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
+
+            // get the result
+            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                for (const auto & res : results) {
+                    responses.push_back(res.data);
+                }
+            }, [&](const json & error_data) {
+                res_error(res, error_data);
+                error = true;
+            });
+        }
+
+        if (error) {
+            return;
+        }
+
+        // write JSON response
+        json root = format_response_rerank(body, responses);
+        res_ok(res, root);
+    };
+
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
         for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
@@ -3119,6 +3290,7 @@ int main(int argc, char ** argv) {
     svr->Post("/embedding",           handle_embeddings); // legacy
     svr->Post("/embeddings",          handle_embeddings);
     svr->Post("/v1/embeddings",       handle_embeddings);
+    svr->Post("/v1/rerank",           handle_rerank);
     svr->Post("/tokenize",            handle_tokenize);
     svr->Post("/detokenize",          handle_detokenize);
     // LoRA adapters hotswap
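
A request to the new endpoint needs only the two fields the handler validates; a sketch of the body as an nlohmann::json literal, with illustrative values:

    // illustrative request body for POST /v1/rerank
    json body = {
        {"query",     "what is panda?"},
        {"documents", {"hi", "the giant panda is a bear species endemic to china"}},
    };

Each returned entry carries the document "index" and its raw "rank" score as assembled in send_rank(); the final response shape comes from format_response_rerank(), which is outside this diff.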