@@ -556,6 +556,7 @@ struct slot_params {
     std::vector<std::string> antiprompt;

     bool timings_per_token = false;
+    bool post_sampling_probs = false;
     json input_prefix;
     json input_suffix;

@@ -1545,6 +1546,8 @@ struct server_context {
         slot.sparams.n_probs  = json_value(data, "n_probs",  default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

+        slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);
+
         // speculative decoding parameters
         slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
         slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
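
For reference, a request that exercises the new flag could look like the sketch below. Only the field names mirror the `json_value()` lookups above; the prompt, the endpoint name, and the numeric values are illustrative assumptions, not taken from the PR.

```cpp
// Hypothetical client-side request body for the completion endpoint.
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::json;

int main() {
    json data = {
        {"prompt",              "Hello"},
        {"n_predict",           8},
        {"n_probs",             5},     // report top-5 candidates per generated token
        {"post_sampling_probs", true},  // probabilities measured after the sampling chain
    };
    // data.dump() is what would go into the HTTP request body
    std::cout << data.dump(2) << std::endl;
    return 0;
}
```
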
@@ -1947,26 +1950,7 @@ struct server_context {
         }

         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
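
The inline byte-scanning loop is replaced by a call to `validate_utf8()`, whose definition is not part of this hunk. From the usage above, the assumed contract is that it returns the number of bytes forming a complete UTF-8 prefix, so a result smaller than the string size means the text currently ends mid-character. A minimal sketch under that assumption (not the PR's actual helper):

```cpp
#include <string>

// Sketch only: length of the longest prefix of `s` that does not end in a
// truncated multi-byte UTF-8 sequence.
static size_t validate_utf8(const std::string & s) {
    size_t i = 0;
    while (i < s.size()) {
        const unsigned char c = static_cast<unsigned char>(s[i]);
        size_t len = 1;                       // 0xxxxxxx: 1-byte character (or invalid byte)
        if      ((c & 0xE0) == 0xC0) len = 2; // 110xxxxx: 2-byte character
        else if ((c & 0xF0) == 0xE0) len = 3; // 1110xxxx: 3-byte character
        else if ((c & 0xF8) == 0xF0) len = 4; // 11110xxx: 4-byte character
        if (i + len > s.size()) {
            break; // the last character is cut off by the chunk boundary
        }
        i += len;
    }
    return i; // number of bytes that are safe to send
}
```
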
@@ -2062,6 +2046,56 @@ struct server_context {
         return slot.has_next_token; // continue
     }

+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+        size_t n_probs = slot.sparams.n_probs;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        if (post_sampling) {
+            const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
+            const size_t max_probs = cur_p->size;
+
+            // set probability for sampled token
+            for (size_t i = 0; i < max_probs; i++) {
+                if (cur_p->data[i].id == result.tok) {
+                    result.prob = cur_p->data[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    llama_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+            }
+        } else {
+            // TODO: optimize this with min-p optimization
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+            // set probability for sampled token
+            for (size_t i = 0; i < n_vocab; i++) {
+                if (cur[i].id == result.tok) {
+                    result.prob = cur[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+                result.probs.push_back({
+                    cur[i].id,
+                    llama_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+            }
+        }
+    }
+
     json get_formated_generation(const server_slot & slot) const {
         const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
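
`populate_token_probs()` draws from two different sources: with `post_sampling` it reads the candidate list left in the sampling context, i.e. probabilities after the sampler chain has filtered and re-normalized the distribution; otherwise it calls `get_token_probabilities()`, whose body is not shown in this hunk. A plausible sketch of the latter, assuming it is simply a sorted softmax over the raw logits at position `idx` (an assumption, not the PR's verbatim code):

```cpp
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <vector>

// Assumed implementation: turn the raw logits at position `idx` into a
// vocabulary-sized, probability-sorted list, independent of sampler settings.
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
    const int     n_vocab = llama_n_vocab(llama_get_model(ctx));
    const float * logits  = llama_get_logits_ith(ctx, idx);

    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; id++) {
        cur.push_back({id, logits[id], 0.0f});
    }

    // sort descending by logit so cur[0..n_probs) are the most likely tokens
    std::sort(cur.begin(), cur.end(),
              [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; });

    // numerically stable softmax over the full vocabulary
    const float max_l = cur.front().logit;
    float sum = 0.0f;
    for (auto & td : cur) {
        td.p = std::exp(td.logit - max_l);
        sum += td.p;
    }
    for (auto & td : cur) {
        td.p /= sum;
    }

    return cur;
}
```

Sorting and normalizing the whole vocabulary is what the `// TODO: optimize this with min-p optimization` comment refers to; only the top `n_probs` entries end up in the response.
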
@@ -2159,38 +2193,19 @@ struct server_context {
         res.stop    = false;
         res.stream  = slot.params.stream;
         res.content = tkn.text_to_send;
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat         = slot.params.oaicompat;
         res.oaicompat_model   = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
         res.n_decoded       = slot.n_decoded;
         res.n_prompt_tokens = slot.n_prompt_tokens;
-        res.data = json {
-            {"content",    tkn.text_to_send},
-            {"stop",       false},
-            {"id_slot",    slot.id},
-            {"multimodal", false}
-        };
         slot.update_chat_msg(res.oaicompat_msg_diffs);
-        if (slot.sparams.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs,                       slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<completion_token_output>(
-                    slot.generated_token_probs.begin() + probs_pos,
-                    slot.generated_token_probs.begin() + probs_stop_pos);
-            }
-            slot.n_sent_token_probs = probs_stop_pos;

-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
+        // populate res.probs_output
+        if (slot.sparams.n_probs > 0) {
+            res.probs_output = {tkn}; // copy the token probs
         }

-        if (slot.oaicompat) {
-            res.data["oaicompat_token_ctr"] = slot.n_decoded;
-            res.data["model"]               = slot.oaicompat_model;
-        }
         // populate timings if this is final response or timings_per_token is enabled
         if (slot.params.timings_per_token) {
             res.timings = slot.get_timings();
@@ -2207,56 +2222,30 @@ struct server_context {
         res.stop    = true; // to do: set value
         res.stream  = slot.params.stream;
         res.content = slot.generated_text;
+        res.timings = slot.get_timings();
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat         = slot.params.oaicompat;
         res.oaicompat_model   = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
         res.oaicompat_msg     = slot.update_chat_msg(res.oaicompat_msg_diffs);
         res.n_decoded       = slot.n_decoded;
         res.n_prompt_tokens = slot.n_prompt_tokens;
         res.oaicompat_model = slot.oaicompat_model;
-        res.data = json {
-            {"content",             !slot.params.stream ? slot.generated_text : ""},
-            {"generated_text",      slot.generated_text}, // Always include full text for finish_reason logic
-            {"id_slot",             slot.id},
-            {"stop",                true},
-            {"model",               params.model_alias},
-            {"tokens_predicted",    slot.n_decoded},
-            {"tokens_evaluated",    slot.n_prompt_tokens},
-            {"generation_settings", get_formated_generation(slot)},
-            {"prompt",              slot.prompt},
-            {"truncated",           slot.truncated},
-            {"stopped_eos",         slot.stopped_eos},
-            {"stopped_word",        slot.stopped_word},
-            {"stopped_limit",       slot.stopped_limit},
-            {"stopping_word",       slot.stopping_word},
-            {"tokens_cached",       slot.n_past},
-            {"timings",             slot.get_formated_timings()},
-            // {"oaicompat_chat_format", slot.params.oaicompat_chat_format},
-        };

+        // populate res.probs_output
         if (slot.sparams.n_probs > 0) {
-            std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
                 const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);

                 size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end() - safe_offset);
             } else {
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
             }
-            // res.generation_params = slot.params;
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
-        }
-
-        res.timings = slot.get_timings();
-
-        if (slot.oaicompat) {
-            res.data["oaicompat_token_ctr"] = slot.n_decoded;
-            res.data["model"]               = slot.oaicompat_model;
         }

         queue_results.send(std::move(res));
@@ -3194,7 +3183,8 @@ struct server_context {
             }

             completion_token_output result;
-            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+            const int tok_idx = slot.i_batch - i;
+            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);

             llama_sampling_accept(slot.ctx_sampling, ctx, id, true);

@@ -3210,35 +3200,12 @@ struct server_context {
32123202
3213- llama_token_data_array cur_p = { slot.ctx_sampling ->cur .data (), slot.ctx_sampling ->cur .size (), false };
32143203 result.tok = id;
3204+ result.prob = 1 .0f ; // TODO: set it here instead of doing inside populate_token_probs
32153205 result.text_to_send = llama_token_to_piece (ctx, result.tok , accept_special_token (slot, result.tok ));
32163206
3217- const size_t n_probs = std::min (cur_p.size , (size_t ) slot.sparams .n_probs );
3218- if (n_probs > 0 ) {
3219- const size_t n_valid = slot.ctx_sampling ->n_valid ;
3220-
3221- // Make sure at least n_probs top tokens are at the front of the vector:
3222- if (slot.sparams .temp == 0 .0f && n_probs > n_valid) {
3223- llama_sample_top_k (ctx, &cur_p, n_probs, 0 );
3224- }
3225-
3226- if (slot.sparams .temp == 0 .0f ) {
3227- // With greedy sampling the probabilities have possibly not been calculated.
3228- for (size_t i = 0 ; i < n_probs; ++i) {
3229- result.probs .push_back ({
3230- cur_p.data [i].id ,llama_detokenize (ctx, {cur_p.data [i].id }, params.special ),
3231- i == 0 ? 1 .0f : 0 .0f
3232- });
3233- }
3234- } else {
3235- for (size_t i = 0 ; i < n_probs; ++i) {
3236- result.probs .push_back ({
3237- cur_p.data [i].id , llama_detokenize (ctx, {cur_p.data [i].id }, params.special ),
3238- i >= n_valid ? 0 .0f : cur_p.data [i].p // Tokens filtered out due to e.g. top_k have 0 probability.
3239- });
3240- }
3241- }
3207+ if (slot.sparams .n_probs > 0 ) {
3208+ populate_token_probs (slot, result, slot.params .post_sampling_probs , params.special , tok_idx);
32423209 }
32433210
32443211 if (!process_token (result, slot)) {
@@ -3343,7 +3310,11 @@ struct server_context {

                 result.tok          = ids[i];
                 result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
-                // result.prob = 1.0f; // set later
+                result.prob = 1.0f; // set later
+
+                if (slot.sparams.n_probs > 0) {
+                    populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
+                }

                 if (!process_token(result, slot)) {
                     // release slot because of stop condition