@@ -230,8 +230,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }
-    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
-            params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, alpha_presence = %f, alpha_frequency = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f\n",
+            params.repeat_last_n, params.repeat_penalty, params.alpha_presence, params.alpha_frequency, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
@@ -304,23 +304,69 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
-            const int32_t top_k           = params.top_k;
-            const float   top_p           = params.top_p;
             const float   temp            = params.temp;
+            const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+            const float   top_p           = params.top_p;
+            const float   tfs_z           = params.tfs_z;
+            const float   typical_p       = params.typical_p;
+            const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
             const float   repeat_penalty  = params.repeat_penalty;
+            const float   alpha_presence  = params.alpha_presence;
+            const float   alpha_frequency = params.alpha_frequency;
 
             llama_token id = 0;
 
             {
                 auto logits  = llama_get_logits(ctx);
+                auto n_vocab = llama_n_vocab(ctx);
 
                 if (params.ignore_eos) {
-                    logits[llama_token_eos()] = 0;
+                    logits[llama_token_eos()] = -INFINITY;
+                }
+
+                std::vector<llama_token_data> candidates;
+                candidates.reserve(n_vocab);
+                for (size_t i = 0; i < n_vocab; i++) {
+                    candidates.emplace_back(i, logits[i], 0.0f);
                 }
 
-                id = llama_sample_top_p_top_k(ctx,
-                        last_n_tokens.data() + n_ctx - params.repeat_last_n,
-                        params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+                llama_token_data_array candidates_p = { candidates.data(), candidates.size() };
+
+                // Apply penalties
+                auto last_n_repeat = std::min(std::min((int) last_n_tokens.size(), repeat_last_n), n_ctx);
+                llama_sample_repetition_penalty(&candidates_p,
+                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                    last_n_repeat, repeat_penalty);
+                llama_sample_frequency_and_presence_penalties(&candidates_p,
+                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                    last_n_repeat, alpha_frequency, alpha_presence);
+
+
+#if 1
+                if (temp <= 0) {
+                    // Greedy sampling
+                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                } else {
+                    // Temperature sampling
+                    llama_sample_top_k(&candidates_p, top_k);
+                    llama_sample_tail_free(&candidates_p, tfs_z);
+                    llama_sample_typical(&candidates_p, typical_p);
+                    llama_sample_top_p(&candidates_p, top_p);
+
+                    llama_sample_temperature(&candidates_p, temp);
+                    // printf("`%d`", candidates_p.size);
+                    id = llama_sample_token(ctx, &candidates_p);
+                }
+#else
+                const float tau = 5.0f;
+                static float mu = 2.0f * tau;
+                static int k = 40;
+                const float eta = 0.1f;
+                const int m = 100;
+                const float N = n_vocab;
+                id = llama_sample_mirostat(ctx, &candidates_p, tau, eta, m, N, &k, &mu);
+                // id = llama_sample_mirostat_v2(ctx, &candidates_p, tau, eta, &mu);
+#endif
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
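
The added path above reads as a small pipeline: build one candidate per vocabulary token from the logits, apply the repetition and frequency/presence penalties over the recent-token window, then either pick greedily (temp <= 0) or narrow the candidates with top-k, tail-free, typical and top-p filtering before temperature sampling. A minimal standalone sketch of that flow, assuming the API signatures exactly as they appear in this commit (the helper name, the params struct name, and its fields are illustrative, not part of the library):

// Sketch only: assumes llama.h as of this commit; sample_next_token and the
// params type/fields are illustrative names, not library API.
#include <algorithm>
#include <cmath>
#include <vector>
#include "llama.h"

struct sampling_params {              // assumed mirror of the example's params fields
    int32_t top_k; float top_p, tfs_z, typical_p, temp;
    int32_t repeat_last_n; float repeat_penalty, alpha_presence, alpha_frequency;
};

static llama_token sample_next_token(llama_context * ctx,
                                     const std::vector<llama_token> & last_n_tokens,
                                     const sampling_params & params, int n_ctx) {
    const float * logits  = llama_get_logits(ctx);
    const int     n_vocab = llama_n_vocab(ctx);

    // One candidate per vocabulary entry, initialized from the raw logits.
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token i = 0; i < n_vocab; i++) {
        candidates.emplace_back(llama_token_data{i, logits[i], 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size() };

    // Penalize tokens seen in the recent window (repetition, frequency, presence).
    const int repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
    const int last_n_repeat = std::min(std::min((int) last_n_tokens.size(), repeat_last_n), n_ctx);
    llama_sample_repetition_penalty(&candidates_p,
        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
        last_n_repeat, params.repeat_penalty);
    llama_sample_frequency_and_presence_penalties(&candidates_p,
        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
        last_n_repeat, params.alpha_frequency, params.alpha_presence);

    if (params.temp <= 0) {
        // Greedy: take the single most likely candidate.
        return llama_sample_token_greedy(ctx, &candidates_p);
    }

    // Narrow the candidate set, then sample from what remains.
    const int top_k = params.top_k <= 0 ? n_vocab : params.top_k;
    llama_sample_top_k(&candidates_p, top_k);
    llama_sample_tail_free(&candidates_p, params.tfs_z);
    llama_sample_typical(&candidates_p, params.typical_p);
    llama_sample_top_p(&candidates_p, params.top_p);
    llama_sample_temperature(&candidates_p, params.temp);
    return llama_sample_token(ctx, &candidates_p);
}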