1111#include < random>
1212#include < unordered_map>
1313
14+ static int llama_sample_dist (llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float > & probs) {
15+ probs.resize (cur_p->size );
16+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
17+ probs[i] = cur_p->data [i].p ;
18+ }
19+
20+ std::discrete_distribution<size_t > dist (probs.begin (), probs.end ());
21+
22+ return dist (rng);
23+ }
24+
1425static void llama_log_softmax (float * array, size_t size) {
1526 float max_l = *std::max_element (array, array + size);
1627 float sum = 0 .f ;
@@ -456,22 +467,16 @@ struct llama_sampler_context_dist {
456467 const uint32_t seed;
457468
458469 std::mt19937 rng;
470+
471+ std::vector<float > probs; // work array
459472};
460473
461474static struct llama_sampler_i llama_sampler_dist_i = {
462475 /* .name = */ [](const struct llama_sampler * /* smpl*/ ) { return " dist" ; },
463476 /* .accept = */ nullptr ,
464477 /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
465478 auto * ctx = (llama_sampler_context_dist *) smpl->ctx ;
466- std::vector<float > probs;
467- probs.reserve (cur_p->size );
468- for (size_t i = 0 ; i < cur_p->size ; ++i) {
469- probs.push_back (cur_p->data [i].p );
470- }
471-
472- std::discrete_distribution<size_t > dist (probs.begin (), probs.end ());
473-
474- cur_p->selected = dist (ctx->rng );
479+ cur_p->selected = llama_sample_dist (cur_p, ctx->rng , ctx->probs );
475480 },
476481 /* .reset = */ nullptr ,
477482 /* .clone = */ [](const struct llama_sampler * smpl) {
@@ -489,6 +494,7 @@ struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) {
489494 /* .ctx = */ new llama_sampler_context_dist {
490495 /* .seed = */ seed,
491496 /* .rng = */ std::mt19937 (seed),
497+ /* .probs = */ {},
492498 },
493499 };
494500}
@@ -761,35 +767,23 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta,
// Per-sampler state of the mirostat sampler.
struct llama_sampler_context_mirostat {
    const struct llama_vocab * vocab; // vocabulary whose n_vocab is read when estimating k

    const uint32_t seed; // seed the rng is built from; reused on reset and when cloning

    const float tau; // subtracted from the observed surprise to form the error term
    const float eta; // learning rate applied to that error when updating mu

    const int32_t m; // NOTE(review): its use is not visible in this part of the file

    float mu; // running value, starts (and is reset to) 2*tau, updated after every sampled token

    std::mt19937 rng; // random engine used to draw the token

    std::vector<float> probs; // work array handed to llama_sample_dist
};
773783
774784static struct llama_sampler_i llama_sampler_mirostat_i = {
775785 /* .name = */ [](const struct llama_sampler * /* smpl*/ ) { return " mirostat" ; },
776- /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
777- auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx ;
778-
779- int32_t idx = -1 ;
780- for (size_t i = 0 ; i < ctx->cur .size (); ++i) {
781- if (ctx->cur [i].id == token) {
782- idx = i;
783- break ;
784- }
785- }
786-
787- float observed_surprise = -log2f (ctx->cur [idx].p );
788- float e = observed_surprise - ctx->tau ;
789-
790- // Update mu using the learning rate and error
791- ctx->mu = ctx->mu - ctx->eta * e;
792- },
786+ /* .accept = */ nullptr ,
793787 /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
794788 auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx ;
795789
@@ -812,70 +806,66 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
812806 float k = powf ((epsilon_hat * powf (2 , ctx->mu )) / (1 - powf (ctx->vocab ->n_vocab , -epsilon_hat)), 1 / s_hat);
813807
814808 llama_sampler_top_k_impl (cur_p, std::max (int (k), 1 ));
809+ llama_sampler_softmax_impl (cur_p);
815810
816- // remember the order to be able to compute the distance later when accepting the token
817- ctx->cur .resize (cur_p->size );
818- for (size_t i = 0 ; i < cur_p->size ; ++i) {
819- ctx->cur [i] = cur_p->data [i];
820- }
811+ const int idx = llama_sample_dist (cur_p, ctx->rng , ctx->probs );
812+
813+ cur_p->selected = idx;
814+
815+ float observed_surprise = -log2f (cur_p->data [idx].p );
816+ float e = observed_surprise - ctx->tau ;
817+
818+ // Update mu using the learning rate and error
819+ ctx->mu = ctx->mu - ctx->eta * e;
821820 },
822821 /* .reset = */ [](struct llama_sampler * smpl) {
823822 auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx ;
824823 ctx->mu = 2 .0f *ctx->tau ;
824+ ctx->rng = std::mt19937 (ctx->seed );
825825 },
826826 /* .clone = */ [](const struct llama_sampler * smpl) {
827827 const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx ;
828- return llama_sampler_init_mirostat_impl (*ctx->vocab , ctx->tau , ctx->eta , ctx->m );
828+ return llama_sampler_init_mirostat_impl (*ctx->vocab , ctx->seed , ctx-> tau , ctx->eta , ctx->m );
829829 },
830830 /* .free = */ [](struct llama_sampler * smpl) {
831831 delete (llama_sampler_context_mirostat *) smpl->ctx ;
832832 },
833833};
834834
835- struct llama_sampler * llama_sampler_init_mirostat_impl (const struct llama_vocab & vocab, float tau, float eta, int32_t m) {
835+ struct llama_sampler * llama_sampler_init_mirostat_impl (const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) {
836836 return new llama_sampler {
837837 /* .iface = */ &llama_sampler_mirostat_i,
838838 /* .ctx = */ new llama_sampler_context_mirostat {
839839 /* .vocab = */ &vocab,
840+ /* .seed = */ seed,
840841 /* .tau = */ tau,
841842 /* .eta = */ eta,
842843 /* .m = */ m,
843844 /* .mu = */ 2 .0f *tau,
844- /* .cur = */ {},
845+ /* .rng = */ std::mt19937 (seed),
846+ /* .probs = */ {},
845847 },
846848 };
847849}
848850
849851// mirostat v2
850852
// Per-sampler state of the mirostat-v2 sampler.
struct llama_sampler_context_mirostat_v2 {
    const uint32_t seed; // seed the rng is built from; reused on reset and when cloning

    const float tau; // subtracted from the observed surprise to form the error term
    const float eta; // learning rate applied to that error when updating mu

    float mu; // running value, starts (and is reset to) 2*tau, updated after every sampled token

    std::mt19937 rng; // random engine used to draw the token

    std::vector<float> probs; // work array handed to llama_sample_dist
};
859865
860866static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
861867 /* .name = */ [](const struct llama_sampler * /* smpl*/ ) { return " mirostat-v2" ; },
862- /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
863- auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx ;
864-
865- int32_t idx = -1 ;
866- for (size_t i = 0 ; i < ctx->cur .size (); ++i) {
867- if (ctx->cur [i].id == token) {
868- idx = i;
869- break ;
870- }
871- }
872-
873- float observed_surprise = -log2f (ctx->cur [idx].p );
874- float e = observed_surprise - ctx->tau ;
875-
876- // Update mu using the learning rate and error
877- ctx->mu = ctx->mu - ctx->eta * e;
878- },
868+ /* .accept = */ nullptr ,
879869 /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
880870 auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx ;
881871
@@ -893,33 +883,40 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
893883 // Normalize the probabilities of the remaining words
894884 llama_sampler_softmax_impl (cur_p);
895885
896- // remember the order to be able to compute the distance later when accepting the token
897- ctx->cur .resize (cur_p->size );
898- for (size_t i = 0 ; i < cur_p->size ; ++i) {
899- ctx->cur [i] = cur_p->data [i];
900- }
886+ const int idx = llama_sample_dist (cur_p, ctx->rng , ctx->probs );
887+
888+ cur_p->selected = idx;
889+
890+ float observed_surprise = -log2f (cur_p->data [idx].p );
891+ float e = observed_surprise - ctx->tau ;
892+
893+ // Update mu using the learning rate and error
894+ ctx->mu = ctx->mu - ctx->eta * e;
901895 },
902896 /* .reset = */ [](struct llama_sampler * smpl) {
903897 auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx ;
904898 ctx->mu = 2 .0f *ctx->tau ;
899+ ctx->rng = std::mt19937 (ctx->seed );
905900 },
906901 /* .clone = */ [](const struct llama_sampler * smpl) {
907902 const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx ;
908- return llama_sampler_init_mirostat_v2_impl (ctx->tau , ctx->eta );
903+ return llama_sampler_init_mirostat_v2_impl (ctx->seed , ctx-> tau , ctx->eta );
909904 },
910905 /* .free = */ [](struct llama_sampler * smpl) {
911906 delete (llama_sampler_context_mirostat_v2 *) smpl->ctx ;
912907 },
913908};
914909
915- struct llama_sampler * llama_sampler_init_mirostat_v2_impl (float tau, float eta) {
910+ struct llama_sampler * llama_sampler_init_mirostat_v2_impl (uint32_t seed, float tau, float eta) {
916911 return new llama_sampler {
917912 /* .iface = */ &llama_sampler_mirostat_v2_i,
918913 /* .ctx = */ new llama_sampler_context_mirostat_v2 {
919- /* .tau = */ tau,
920- /* .eta = */ eta,
921- /* .mu = */ 2 .0f *tau,
922- /* .cur = */ {},
914+ /* .seed = */ seed,
915+ /* .tau = */ tau,
916+ /* .eta = */ eta,
917+ /* .mu = */ 2 .0f *tau,
918+ /* .rng = */ std::mt19937 (seed),
919+ /* .probs = */ {},
923920 },
924921 };
925922}
@@ -1154,9 +1151,15 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl(
11541151
11551152static struct llama_sampler_i llama_sampler_chain_i = {
11561153 /* .name = */ [](const struct llama_sampler * /* smpl*/ ) { return " chain" ; },
1157- /* .accept = */ [](struct llama_sampler * smpl, llama_token /* token*/ ) {
1154+ /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
11581155 auto * chain = (llama_sampler_chain *) smpl->ctx ;
11591156
1157+ time_meas tm (chain->t_sample_us , chain->params .no_timing );
1158+
1159+ for (auto * smpl : chain->samplers ) {
1160+ llama_sampler_accept_impl (*smpl, token);
1161+ }
1162+
11601163 chain->n_sample ++;
11611164 },
11621165 /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
0 commit comments