@@ -32,6 +32,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_model * m
3232 lparams.penalize_nl = params.penalize_nl ;
3333 lparams.ignore_eos = params.ignore_eos ;
3434
35+ lparams.n_samplers = params.samplers .size ();
36+
3537 result->smpl = llama_sampling_init (model, lparams);
3638
3739 llama_sampling_set_grammar (result->smpl , params.grammar .c_str (), " root" );
@@ -101,7 +103,7 @@ std::string llama_sampling_print(const gpt_sampling_params & params) {
101103std::string llama_sampling_order_print (const gpt_sampling_params & params) {
102104 std::string result = " CFG -> Penalties " ;
103105 if (params.mirostat == 0 ) {
104- for (auto sampler_type : params.samplers_sequence ) {
106+ for (auto sampler_type : params.samplers ) {
105107 const auto sampler_type_name = llama_sampling_type_to_str (sampler_type);
106108 if (!sampler_type_name.empty ()) {
107109 result += " -> " + sampler_type_name + " " ;
@@ -114,6 +116,18 @@ std::string llama_sampling_order_print(const gpt_sampling_params & params) {
114116 return result;
115117}
116118
119+ char llama_sampling_type_to_chr (llama_sampler_type sampler_type) {
120+ switch (sampler_type) {
121+ case LLAMA_SAMPLER_TYPE_TOP_K: return ' k' ;
122+ case LLAMA_SAMPLER_TYPE_TFS_Z: return ' f' ;
123+ case LLAMA_SAMPLER_TYPE_TYPICAL_P: return ' y' ;
124+ case LLAMA_SAMPLER_TYPE_TOP_P: return ' p' ;
125+ case LLAMA_SAMPLER_TYPE_MIN_P: return ' m' ;
126+ case LLAMA_SAMPLER_TYPE_TEMPERATURE: return ' t' ;
127+ default : return ' ?' ;
128+ }
129+ }
130+
117131std::string llama_sampling_type_to_str (llama_sampler_type sampler_type) {
118132 switch (sampler_type) {
119133 case LLAMA_SAMPLER_TYPE_TOP_K: return " top_k" ;
@@ -128,26 +142,26 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
128142
129143std::vector<llama_sampler_type> llama_sampling_types_from_names (const std::vector<std::string> & names, bool allow_alt_names) {
130144 std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
131- {" top_k" , LLAMA_SAMPLER_TYPE_TOP_K},
132- {" top_p" , LLAMA_SAMPLER_TYPE_TOP_P},
133- {" typical_p" , LLAMA_SAMPLER_TYPE_TYPICAL_P},
134- {" min_p" , LLAMA_SAMPLER_TYPE_MIN_P},
135- {" tfs_z" , LLAMA_SAMPLER_TYPE_TFS_Z},
136- {" temperature" , LLAMA_SAMPLER_TYPE_TEMPERATURE}
145+ { " top_k" , LLAMA_SAMPLER_TYPE_TOP_K },
146+ { " top_p" , LLAMA_SAMPLER_TYPE_TOP_P },
147+ { " typical_p" , LLAMA_SAMPLER_TYPE_TYPICAL_P },
148+ { " min_p" , LLAMA_SAMPLER_TYPE_MIN_P },
149+ { " tfs_z" , LLAMA_SAMPLER_TYPE_TFS_Z },
150+ { " temperature" , LLAMA_SAMPLER_TYPE_TEMPERATURE },
137151 };
138152
139153 // since samplers names are written multiple ways
140154 // make it ready for both system names and input names
141155 std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
142- {" top-k" , LLAMA_SAMPLER_TYPE_TOP_K},
143- {" top-p" , LLAMA_SAMPLER_TYPE_TOP_P},
144- {" nucleus" , LLAMA_SAMPLER_TYPE_TOP_P},
145- {" typical-p" , LLAMA_SAMPLER_TYPE_TYPICAL_P},
146- {" typical" , LLAMA_SAMPLER_TYPE_TYPICAL_P},
147- {" min-p" , LLAMA_SAMPLER_TYPE_MIN_P},
148- {" tfs-z" , LLAMA_SAMPLER_TYPE_TFS_Z},
149- {" tfs" , LLAMA_SAMPLER_TYPE_TFS_Z},
150- {" temp" , LLAMA_SAMPLER_TYPE_TEMPERATURE}
156+ { " top-k" , LLAMA_SAMPLER_TYPE_TOP_K },
157+ { " top-p" , LLAMA_SAMPLER_TYPE_TOP_P },
158+ { " nucleus" , LLAMA_SAMPLER_TYPE_TOP_P },
159+ { " typical-p" , LLAMA_SAMPLER_TYPE_TYPICAL_P },
160+ { " typical" , LLAMA_SAMPLER_TYPE_TYPICAL_P },
161+ { " min-p" , LLAMA_SAMPLER_TYPE_MIN_P },
162+ { " tfs-z" , LLAMA_SAMPLER_TYPE_TFS_Z },
163+ { " tfs" , LLAMA_SAMPLER_TYPE_TFS_Z },
164+ { " temp" , LLAMA_SAMPLER_TYPE_TEMPERATURE },
151165 };
152166
153167 std::vector<llama_sampler_type> sampler_types;
@@ -172,12 +186,12 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
172186
173187std::vector<llama_sampler_type> llama_sampling_types_from_chars (const std::string & names_string) {
174188 std::unordered_map<char , llama_sampler_type> sampler_name_map {
175- {' k ' , LLAMA_SAMPLER_TYPE_TOP_K},
176- {' p ' , LLAMA_SAMPLER_TYPE_TOP_P },
177- {' y ' , LLAMA_SAMPLER_TYPE_TYPICAL_P},
178- {' m ' , LLAMA_SAMPLER_TYPE_MIN_P },
179- {' f ' , LLAMA_SAMPLER_TYPE_TFS_Z },
180- {' t ' , LLAMA_SAMPLER_TYPE_TEMPERATURE}
189+ { llama_sampling_type_to_chr (LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K },
190+ { llama_sampling_type_to_chr (LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z },
191+ { llama_sampling_type_to_chr (LLAMA_SAMPLER_TYPE_TYPICAL_P), LLAMA_SAMPLER_TYPE_TYPICAL_P },
192+ { llama_sampling_type_to_chr (LLAMA_SAMPLER_TYPE_TOP_P), LLAMA_SAMPLER_TYPE_TOP_P },
193+ { llama_sampling_type_to_chr (LLAMA_SAMPLER_TYPE_MIN_P), LLAMA_SAMPLER_TYPE_MIN_P },
194+ { llama_sampling_type_to_chr (LLAMA_SAMPLER_TYPE_TEMPERATURE) , LLAMA_SAMPLER_TYPE_TEMPERATURE }
181195 };
182196
183197 std::vector<llama_sampler_type> sampler_types;
@@ -199,10 +213,10 @@ static void sampler_queue(
199213
200214 const gpt_sampling_params & params = ctx_sampling->params ;
201215
202- const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence ;
216+ const std::vector<llama_sampler_type> & samplers = params.samplers ;
203217
204- for (auto sampler_type : samplers_sequence ) {
205- switch (sampler_type ) {
218+ for (const auto & sampler : samplers ) {
219+ switch (sampler ) {
206220 case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k (smpl, cur_p); break ;
207221 case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free (smpl, cur_p); break ;
208222 case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical (smpl, cur_p); break ;
0 commit comments