@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
6464 // first run
6565 printf (" \n %s" , params.prompt .c_str ());
6666 for (auto i = 0 ; i < params.n_predict ; i++) {
67- auto next_token = llama_sample_top_p_top_k (
68- ctx,
69- &last_n_tokens_data.back () - params.repeat_last_n ,
70- params.repeat_last_n ,
71- 40 ,
72- 1.0 ,
73- 1.0 ,
74- 1.1 );
67+ auto logits = llama_get_logits (ctx);
68+ auto n_vocab = llama_n_vocab (ctx);
69+ std::vector<llama_token_data> candidates;
70+ candidates.reserve (n_vocab);
71+ for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
72+ candidates.emplace_back (llama_token_data{token_id, logits[token_id], 0.0f });
73+ }
74+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
75+ auto next_token = llama_sample_token (ctx, &candidates_p);
7576 auto next_token_str = llama_token_to_str (ctx, next_token);
7677 last_n_tokens_data.push_back (next_token);
7778 printf (" %s" , next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
106107
107108 // second run
108109 for (auto i = 0 ; i < params.n_predict ; i++) {
109- auto next_token = llama_sample_top_p_top_k (
110- ctx2,
111- &last_n_tokens_data.back () - params.repeat_last_n ,
112- params.repeat_last_n ,
113- 40 ,
114- 1.0 ,
115- 1.0 ,
116- 1.1 );
110+ auto logits = llama_get_logits (ctx2);
111+ auto n_vocab = llama_n_vocab (ctx2);
112+ std::vector<llama_token_data> candidates;
113+ candidates.reserve (n_vocab);
114+ for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
115+ candidates.emplace_back (llama_token_data{token_id, logits[token_id], 0.0f });
116+ }
117+ llama_token_data_array candidates_p = { candidates.data (), candidates.size (), false };
118+ auto next_token = llama_sample_token (ctx2, &candidates_p);
117119 auto next_token_str = llama_token_to_str (ctx2, next_token);
118120 last_n_tokens_data.push_back (next_token);
119121 printf (" %s" , next_token_str);
0 commit comments