@@ -168,76 +168,19 @@ static llama_token llama_sampling_sample_impl(
168168 bool is_resampling) { // Add a parameter to indicate if we are resampling
169169 const llama_sampling_params & params = ctx_sampling->params ;
170170
171- const int n_vocab = llama_n_vocab (llama_get_model (ctx_main));
172-
173171 const float temp = params.temp ;
174- const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n ;
175- const float penalty_repeat = params.penalty_repeat ;
176- const float penalty_freq = params.penalty_freq ;
177- const float penalty_present = params.penalty_present ;
178172 const int mirostat = params.mirostat ;
179173 const float mirostat_tau = params.mirostat_tau ;
180174 const float mirostat_eta = params.mirostat_eta ;
181- const bool penalize_nl = params.penalize_nl ;
182175
183- auto & prev = ctx_sampling->prev ;
184- auto & cur = ctx_sampling->cur ;
185-
186- llama_token id = 0 ;
187-
188- // Get a pointer to the logits
189- float * logits = llama_get_logits_ith (ctx_main, idx);
190-
191- // Declare original_logits at the beginning of the function scope
192176 std::vector<float > original_logits;
193-
177+ auto cur_p = llama_sampling_prepare (ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
194178 if (!is_resampling) {
195- // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
196- original_logits = std::vector<float >(logits, logits + llama_n_vocab (llama_get_model (ctx_main)));
197- }
198-
199- // apply params.logit_bias map
200- for (auto it = params.logit_bias .begin (); it != params.logit_bias .end (); it++) {
201- logits[it->first ] += it->second ;
202- }
203-
204- if (ctx_cfg) {
205- float * logits_guidance = llama_get_logits_ith (ctx_cfg, idx);
206- llama_sample_apply_guidance (ctx_main, logits, logits_guidance, params.cfg_scale );
207- }
208-
209- cur.clear ();
210-
211- for (llama_token token_id = 0 ; token_id < n_vocab; token_id++) {
212- cur.emplace_back (llama_token_data{token_id, logits[token_id], 0 .0f });
213- }
214-
215- llama_token_data_array cur_p = { cur.data (), cur.size (), false };
216-
217- // apply penalties
218- const auto & penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
219- const int penalty_tokens_used_size = std::min ((int )penalty_tokens.size (), penalty_last_n);
220- if (penalty_tokens_used_size) {
221- const float nl_logit = logits[llama_token_nl (llama_get_model (ctx_main))];
222-
223- llama_sample_repetition_penalties (ctx_main, &cur_p,
224- penalty_tokens.data () + penalty_tokens.size () - penalty_tokens_used_size,
225- penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
226-
227- if (!penalize_nl) {
228- for (size_t idx = 0 ; idx < cur_p.size ; idx++) {
229- if (cur_p.data [idx].id == llama_token_nl (llama_get_model (ctx_main))) {
230- cur_p.data [idx].logit = nl_logit;
231- break ;
232- }
233- }
234- }
235- }
236-
237- // If we are in the resampling phase, apply grammar checks before sampling logic
238- if (is_resampling && ctx_sampling->grammar != NULL ) {
239- llama_sample_grammar (ctx_main, &cur_p, ctx_sampling->grammar );
179+ GGML_ASSERT (!original_logits.empty ());
240180 }
181+ llama_token id = 0 ;
182+ // Get a pointer to the logits
183+ float * logits = llama_get_logits_ith (ctx_main, idx);
241184
242185 if (temp < 0.0 ) {
243186 // greedy sampling, with probs
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
302245 return id;
303246}
304247
305- static llama_token_data_array llama_sample_probability_distribution_impl (
248+ static llama_token_data_array llama_sampling_prepare_impl (
306249 struct llama_sampling_context * ctx_sampling,
307250 struct llama_context * ctx_main,
308251 struct llama_context * ctx_cfg,
309- const int idx) {
252+ const int idx,
253+ bool apply_grammar,
254+ std::vector<float > * original_logits) {
310255 const llama_sampling_params & params = ctx_sampling->params ;
311256
312257 const int n_vocab = llama_n_vocab (llama_get_model (ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
315260 const float penalty_repeat = params.penalty_repeat ;
316261 const float penalty_freq = params.penalty_freq ;
317262 const float penalty_present = params.penalty_present ;
263+
318264 const bool penalize_nl = params.penalize_nl ;
319265
320266 auto & prev = ctx_sampling->prev ;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
323269 // Get a pointer to the logits
324270 float * logits = llama_get_logits_ith (ctx_main, idx);
325271
326- // Declare original_logits at the beginning of the function scope
327- std::vector<float > original_logits;
272+ if (apply_grammar && original_logits != NULL ) {
273+ // Only make a copy of the original logits when grammar checks are being applied, so the caller can restore them before resampling; not sure if this copy is strictly necessary.
274+ *original_logits = {logits, logits + llama_n_vocab (llama_get_model (ctx_main))};
275+ }
328276
329277 // apply params.logit_bias map
330278 for (auto it = params.logit_bias .begin (); it != params.logit_bias .end (); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
364312 }
365313 }
366314
367- // apply grammar checks
368- if (ctx_sampling->grammar != NULL ) {
315+ // apply grammar checks before sampling logic
316+ if (apply_grammar && ctx_sampling->grammar != NULL ) {
369317 llama_sample_grammar (ctx_main, &cur_p, ctx_sampling->grammar );
370318 }
371319
372- llama_sample_softmax (ctx_main, &cur_p);
373320 return cur_p;
374321}
375322
@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
382329 return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, false );
383330}
384331
385- llama_token_data_array llama_sampling_probability_distribution (
332+ llama_token_data_array llama_sampling_prepare (
386333 struct llama_sampling_context * ctx_sampling,
387334 struct llama_context * ctx_main,
388335 struct llama_context * ctx_cfg,
389- const int idx) {
390- return llama_sample_probability_distribution_impl (ctx_sampling,ctx_main, ctx_cfg, idx);
336+ const int idx,
337+ bool apply_grammar,
338+ std::vector<float > * original_logits) {
339+ return llama_sampling_prepare_impl (ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
391340}
392341
393342void llama_sampling_accept (
0 commit comments