@@ -197,10 +197,10 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
197197 sampler_tester tester (n_vocab);
198198
199199 llama_token min_token_id = 0 ;
200- const llama_token max_token_id = n_vocab- 1 ;
200+ const llama_token max_token_id = n_vocab - 1 ;
201201
202202 for (auto s : samplers_sequence) {
203- switch (s){
203+ switch (s) {
204204 case ' k' : tester.apply (llama_sampler_init_top_k (top_k)); break ;
205205 case ' y' : GGML_ABORT (" typical test not implemented" );
206206 case ' p' : tester.apply (llama_sampler_init_top_p (top_p, 1 )); break ;
@@ -243,10 +243,10 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
243243 }
244244
245245 GGML_ASSERT (size == expected_size);
246- GGML_ASSERT (cur_p.data [0 ].id == max_token_id);
247- GGML_ASSERT (cur_p.data [expected_size-1 ].id == min_token_id);
246+ GGML_ASSERT (!cur_p. sorted || cur_p.data [0 ].id == max_token_id);
247+ GGML_ASSERT (!cur_p. sorted || cur_p.data [expected_size-1 ].id == min_token_id);
248248 } else if (s == ' m' ) {
249- int expected_size = ceilf ((1 .0f - min_p) * n_vocab);
249+ int expected_size = ceilf ((1 .0f - min_p) * n_vocab);
250250 expected_size = std::max (expected_size, 1 );
251251 expected_size = std::min (expected_size, size);
252252
@@ -256,14 +256,14 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
256256 min_token_id = std::min (min_token_id, (llama_token)(n_vocab - 1 ));
257257
258258 GGML_ASSERT (size == expected_size);
259- GGML_ASSERT (cur_p.data [0 ].id == max_token_id);
260- GGML_ASSERT (cur_p.data [expected_size-1 ].id == min_token_id);
259+ GGML_ASSERT (!cur_p. sorted || cur_p.data [0 ].id == max_token_id);
260+ GGML_ASSERT (!cur_p. sorted || cur_p.data [expected_size-1 ].id == min_token_id);
261261 } else {
262262 GGML_ABORT (" fatal error" );
263263 }
264264 }
265265
266- printf (" Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n " ,
266+ printf (" Sampler queue %3s OK with n_vocab=%05zu top_k=%5d top_p=%f min_p=%f\n " ,
267267 samplers_sequence.c_str (), n_vocab, top_k, top_p, min_p);
268268}
269269
@@ -308,28 +308,28 @@ static void test_perf() {
308308int main (void ) {
309309 ggml_time_init ();
310310
311- test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 1 .0f );
312- test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f , 0 .0f , 0 .0f , 0 .0f }, 0 .0f );
311+ test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 1 .0f );
312+ test_temp ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .0f , 0 .0f , 0 .0f , 1 .0f }, 0 .0f );
313313
314- test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 1 .0f , 0 .0f , 1 .0f );
315- test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f , 0 .0f , 0 .0f , 0 .0f }, 0 .0f , 0 .0f , 1 .0f );
314+ test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 1 .0f , 0 .0f , 1 .0f );
315+ test_temp_ext ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .0f , 0 .0f , 0 .0f , 1 .0f }, 0 .0f , 0 .0f , 1 .0f );
316316
317317 test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f }, 1 );
318318 test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .44444f , 0 .33333f , 0 .22222f }, 3 );
319319 test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 4 );
320- test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 0 );
320+ test_top_k ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 0 );
321321
322322 test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {1 .0f }, 0 );
323323 test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .571429f , 0 .428571f }, 0 .7f );
324324 test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .44444f , 0 .33333f , 0 .22222f }, 0 .8f );
325- test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 1 .0f );
326-
327- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /1 .0f , 0 .3f /1 .0f , 0 .2f /1 .0f , 0 .1f /1 .0f }, 0 .00f );
328- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /1 .0f , 0 .3f /1 .0f , 0 .2f /1 .0f , 0 .1f /1 .0f }, 0 .24f );
329- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .9f , 0 .3f /0 .9f , 0 .2f /0 .9f }, 0 .26f );
330- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .9f , 0 .3f /0 .9f , 0 .2f /0 .9f }, 0 .49f );
331- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .7f , 0 .3f /0 .7f }, 0 .51f );
332- test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .7f , 0 .3f /0 .7f }, 0 .74f );
325+ test_top_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 1 .0f );
326+
327+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f /1 .0f , 0 .2f /1 .0f , 0 .3f /1 .0f , 0 .4f /1 .0f }, 0 .00f );
328+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f /1 .0f , 0 .2f /1 .0f , 0 .3f /1 .0f , 0 .4f /1 .0f }, 0 .24f );
329+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .2f /0 .9f , 0 .3f /0 .9f , 0 .4f /0 .9f }, 0 .26f );
330+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .2f /0 .9f , 0 .3f /0 .9f , 0 .4f /0 .9f }, 0 .49f );
331+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .3f /0 .7f , 0 .4f /0 .7f }, 0 .51f );
332+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .3f /0 .7f , 0 .4f /0 .7f }, 0 .74f );
333333 test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 0 .76f );
334334 test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 1 .00f );
335335 test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 1 .05f );
@@ -345,23 +345,23 @@ int main(void) {
345345 test_typical ({0 .97f , 0 .01f , 0 .01f , 0 .01f }, {0 .97f }, 0 .5f );
346346 test_typical ({0 .4f , 0 .2f , 0 .2f , 0 .2f }, {0 .2f , 0 .2f , 0 .2f }, 0 .5f );
347347
348- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 . 25f , 0 .25f , 0 .25f , 0 .25f , 0 }, 50 .0f , 0 .0f , 0 .0f );
349- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 . 5f , 0 . 5f , 0 , 0 , 0 }, 50 .0f , 0 .0f , 0 .0f );
350- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 . 5f , 0 . 5f , 0 , 0 , 0 }, 50 .0f , 0 .0f , 0 .0f );
348+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 , 0 .25f , 0 .25f , 0 .25f , 0 . 25f }, 50 .0f , 0 .0f , 0 .0f );
349+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 , 0 , 0 , 0 . 5f , 0 . 5f }, 50 .0f , 0 .0f , 0 .0f );
350+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 , 0 , 0 , 0 . 5f , 0 . 5f }, 50 .0f , 0 .0f , 0 .0f );
351351
352- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 .249997f , 0 .249997f , 0 .249997f , 0 .249997f , 0 .000011f }, 1 .0f , 5 .0f , 5 .0f );
353- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 .499966f , 0 .499966f , 0 .000023f , 0 .000023f , 0 .000023f }, 1 .0f , 5 .0f , 5 .0f );
354- test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 .499977f , 0 .499977f , 0 .000023f , 0 .000023f , 0 .000000f }, 1 .0f , 5 .0f , 5 .0f );
352+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 .000011f , 0 .249997f , 0 .249997f , 0 .249997f , 0 .249997f }, 1 .0f , 5 .0f , 5 .0f );
353+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 .000023f , 0 .000023f , 0 .000023f , 0 .499966f , 0 .499966f }, 1 .0f , 5 .0f , 5 .0f );
354+ test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 0 }, {0 .000000f , 0 .000023f , 0 .000023f , 0 .499977f , 0 .499977f }, 1 .0f , 5 .0f , 5 .0f );
355355
356356
357357 test_dry ({0 .25f , 0 .25f , 0 .25f , 0 .25f }, {0 , 1 }, {0 .25f , 0 .25f , 0 .25f , 0 .25f }, 1 .0f , 1 .1f , 2 , 4 , {});
358- test_dry ({0 .25f , 0 .25f , 0 .25f , 0 .25f }, {0 , 1 , 2 , 0 , 1 }, {0 .296923f , 0 .296923f , 0 .296923f , 0 .109232f }, 1 .0f , 1 .1f , 2 , 5 , {});
358+ test_dry ({0 .25f , 0 .25f , 0 .25f , 0 .25f }, {0 , 1 , 2 , 0 , 1 }, {0 .296923f , 0 .296923f , 0 .109232f , 0 .296923f }, 1 .0f , 1 .1f , 2 , 5 , {});
359359 test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 3 , 4 , 0 , 1 }, {0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, 1 .0f , 1 .1f , 2 , 6 , {{3 }});
360- test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 1 }, {0 .241818f , 0 .241818f , 0 .241818f , 0 .241818f , 0 .032727f }, 2 .0f , 1 .1f , 2 , 5 , {});
360+ test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 0 , 1 }, {0 .241818f , 0 .241818f , 0 .032727f , 0 .241818f , 0 .241818f }, 2 .0f , 1 .1f , 2 , 5 , {});
361361 test_dry ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 , 3 , 4 , 0 , 1 }, {0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, 1 .0f , 1 .1f , 4 , 7 , {});
362362
363363 test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .571429f , 0 .428571f , 0 .0f , 0 .0f }, 1 .00f );
364- test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 0 .00f ); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
364+ test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .1f , 0 .2f , 0 .3f , 0 .4f }, 0 .00f ); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
365365 test_top_n_sigma ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 3 .00f );
366366
367367 test_sampler_queue (10000 , " k" , 10000 , 1 .0f , 1 .0f );
0 commit comments