@@ -232,94 +232,230 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra
232232 }
233233}
234234
235- void llama_sample_dry_impl (llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
236- // skip dry sampler if we don't have a previous token
237- if (last_tokens_size < 1 ) return ;
235+ std::vector<llama_token> llama_tokenize (
236+ const struct llama_context * ctx,
237+ const std::string & text,
238+ bool add_special,
239+ bool parse_special) {
240+ return llama_tokenize (llama_get_model (ctx), text, add_special, parse_special);
241+ }
242+
243+ std::vector<llama_token> llama_tokenize (
244+ const struct llama_model * model,
245+ const std::string & text,
246+ bool add_special,
247+ bool parse_special) {
248+ // upper limit for the number of tokens
249+ int n_tokens = text.length () + 2 * add_special;
250+ std::vector<llama_token> result (n_tokens);
251+ n_tokens = llama_tokenize (model, text.data (), text.length (), result.data (), result.size (), add_special, parse_special);
252+ if (n_tokens < 0 ) {
253+ result.resize (-n_tokens);
254+ int check = llama_tokenize (model, text.data (), text.length (), result.data (), result.size (), add_special, parse_special);
255+ GGML_ASSERT (check == -n_tokens);
256+ } else {
257+ result.resize (n_tokens);
258+ }
259+ return result;
260+ }
261+
262+ std::string llama_detokenize (llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
263+ std::string text;
264+ text.resize (std::max (text.capacity (), tokens.size ()));
265+ int32_t n_chars = llama_detokenize (llama_get_model (ctx), tokens.data (), (int32_t )tokens.size (), &text[0 ], (int32_t )text.size (), false , special);
266+ if (n_chars < 0 ) {
267+ text.resize (-n_chars);
268+ n_chars = llama_detokenize (llama_get_model (ctx), tokens.data (), (int32_t )tokens.size (), &text[0 ], (int32_t )text.size (), false , special);
269+ GGML_ASSERT (n_chars <= (int32_t )text.size ()); // whitespace trimming is performed after per-token detokenization
270+ }
271+
272+ text.resize (n_chars);
273+
274+ // NOTE: the original tokenizer decodes bytes after collecting the pieces.
275+ return text;
276+ }
277+
278+ std::string llama_detokenize_single (llama_context * ctx, llama_token token, bool special) {
279+ std::vector<llama_token> tokens = {token};
280+ return llama_detokenize (ctx, tokens, special);
281+ }
238282
239- // get the last token
240- auto last_token = last_tokens[last_tokens_size - 1 ];
283+ // Constants for preventing overflow
284+ const float FLOAT_MAX_LOG = 88.7228391f;
285+ const int MAX_CHAR_LEN = 40 ;
286+ const int MAX_SEQ_LEN = 20 ;
241287
242- // if last token is part of the sequence breakers, skip whole sampler
243- if (std::find (dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
288+
289+ void llama_sample_dry_impl (struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
290+ if (last_tokens_size < 1 ) {
244291 return ;
245292 }
246293
247- // create an unordered map of "next tokens" <-> max match length
294+ // Cache for token-to-string conversions
295+ std::unordered_map<llama_token, std::string> token_to_string_cache;
296+ // Store sequence breakers for more efficient lookup
297+ std::unordered_multimap<std::string, std::vector<std::string>> restart_sequences;
298+
299+ auto detokenize_with_cache = [&](llama_token token) -> std::string {
300+ auto it = token_to_string_cache.find (token);
301+ if (it != token_to_string_cache.end ()) {
302+ return it->second ;
303+ }
304+ std::string token_str = llama_detokenize_single (ctx, token, false );
305+ token_to_string_cache[token] = token_str;
306+ return token_str;
307+ };
308+
309+ // Pre-process dry_seq_breakers
310+ for (const auto & breaker : dry_seq_breakers) {
311+ std::string breaker_trimmed = breaker.substr (0 , MAX_CHAR_LEN);
312+ std::vector<llama_token> tokens = llama_tokenize (ctx, breaker_trimmed, false , false );
313+
314+ if (!tokens.empty ()) {
315+ std::string head = detokenize_with_cache (tokens[0 ]);
316+ std::vector<std::string> tail;
317+
318+ for (size_t i = 1 ; i < tokens.size () && i <= MAX_SEQ_LEN; ++i) {
319+ tail.push_back (detokenize_with_cache (tokens[i]));
320+ }
321+ restart_sequences.emplace (head, tail);
322+ }
323+ }
324+
325+ // Find max repetition length considering restart sequences
326+ int rep_limit = last_tokens_size;
327+
328+ for (size_t i = 0 ; i < last_tokens_size; ++i) {
329+ size_t ix = last_tokens_size - 1 - i;
330+ std::string token_str = detokenize_with_cache (last_tokens[ix]);
331+
332+ // Check if the token is a potential sequence breaker
333+ auto its = restart_sequences.equal_range (token_str);
334+ if (its.first == restart_sequences.end ()) continue ;
335+
336+ int longest_match = -1 ;
337+ // Check all potential sequence breakers starting with this token
338+ for (auto it = its.first ; it != its.second ; ++it) {
339+ int seq_len = (int )it->second .size ();
340+ if (seq_len > longest_match && seq_len <= i) {
341+ bool match = true ;
342+ // Check if the following tokens match the sequence breaker
343+ for (size_t offset = 0 ; offset < seq_len; ++offset) {
344+ if (it->second [offset] != detokenize_with_cache (last_tokens[ix + 1 + offset])) {
345+ match = false ;
346+ break ;
347+ }
348+ }
349+ if (match) {
350+ longest_match = seq_len;
351+ }
352+ }
353+ }
354+
355+ if (longest_match >= 0 ) {
356+ rep_limit = static_cast <int >(i) - longest_match;
357+ break ;
358+ }
359+ }
360+
361+ if (rep_limit <= dry_allowed_length) {
362+ return ;
363+ }
364+
365+ // Store max match length for each token
248366 std::unordered_map<llama_token, size_t > match_lengths;
249367
250- // loop through each previous token (exclude the last token)
368+ // Find repeated sequences
251369 for (size_t i = 0 ; i < last_tokens_size - 1 ; ++i) {
252- // skip if the compare token is not the same as the last token
253- if (last_tokens[i] != last_token) {
370+ if (last_tokens[i] != last_tokens[last_tokens_size - 1 ]) {
254371 continue ;
255372 }
256373
257- // get the next token (i + 1 is always less than last_tokens_size)
258374 auto next_token = last_tokens[i + 1 ];
375+ std::string next_token_str = detokenize_with_cache (next_token);
259376
260- // if next token is part of the sequence breakers, skip
261- if (std::find (dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
377+ // Skip if next token is a sequence breaker
378+ auto its = restart_sequences.equal_range (next_token_str);
379+ if (its.first != restart_sequences.end ()) {
262380 continue ;
263381 }
264382
265- // try to extend the match backwards (match length starts at 1 because last token is already matched)
266383 size_t match_length = 1 ;
267384
268- // loop through the previous tokens
385+ // Extend match as far as possible
269386 for (;; match_length++) {
270- // if we have reached the start of our last tokens, break
271- if (i < match_length) break ;
387+ if (i < match_length || match_length > rep_limit) {
388+ break ;
389+ }
272390
273- // compare token starts at our prev index, going backwards by match length
274391 auto compare_token = last_tokens[i - match_length];
392+ std::string compare_token_str = detokenize_with_cache (compare_token);
275393
276- // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
277394 auto head_token = last_tokens[last_tokens_size - 1 - match_length];
395+ std::string head_token_str = detokenize_with_cache (head_token);
278396
279- // break out of the match if any tokens don't match
280- if (compare_token != head_token) {
397+ if (compare_token_str != head_token_str) {
281398 break ;
282399 }
283400
284- // if compare token is part of the sequence breakers, break out of the match
285- if (std::find (dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
401+ // Check if we've hit a sequence breaker
402+ its = restart_sequences.equal_range (compare_token_str);
403+ if (its.first != restart_sequences.end ()) {
286404 break ;
287405 }
288406 }
289407
290- // Check if the next token exists in the map
408+ // Update max match length for this token
291409 auto it = match_lengths.find (next_token);
292-
293410 if (it == match_lengths.end ()) {
294- // Key does not exist, insert the new value
295411 match_lengths[next_token] = match_length;
296412 } else {
297- // Key exists, update it with the max of the new value or the existing value
298413 it->second = std::max (it->second , match_length);
299414 }
300415 }
301416
302- // apply penalties
417+ // Calculate max safe exponent
418+ int max_exponent = 0 ;
419+ if (dry_base > 1.000001f) {
420+ max_exponent = static_cast <int >(FLOAT_MAX_LOG / log (dry_base));
421+ }
422+
423+ #ifdef DEBUG
424+ LLAMA_LOG_INFO("DRY Sampling parameters:\n");
425+ LLAMA_LOG_INFO("dry_base: %f\n", dry_base);
426+ LLAMA_LOG_INFO("dry_multiplier: %f\n", dry_multiplier);
427+ LLAMA_LOG_INFO("dry_allowed_length: %d\n", dry_allowed_length);
428+ LLAMA_LOG_INFO("max_exponent: %d\n", max_exponent);
429+ LLAMA_LOG_INFO("DRY penalties [");
430+ #endif
431+
432+ // Apply penalties
303433 for (const auto & pair : match_lengths) {
304434 auto next_token = pair.first ;
305435 auto match_length = pair.second ;
306436
307- // if the match length is greater than or equal to our allowed length in config, we apply penalities
308- if (match_length >= (size_t )dry_allowed_length) {
309-
310- // find our next token in the candidates->data
437+ if (match_length >= static_cast <size_t >(dry_allowed_length)) {
311438 for (size_t i = 0 ; i < candidates->size ; ++i) {
312439 if (candidates->data [i].id == next_token) {
313- // calculate the penalty
314- float penalty = dry_multiplier * pow (dry_base, match_length - dry_allowed_length);
315-
316- // apply the dry penalty
440+ int repeat_exp = static_cast <int >(match_length - dry_allowed_length);
441+ if (max_exponent > 0 && repeat_exp > max_exponent) {
442+ repeat_exp = max_exponent;
443+ }
444+ float penalty = dry_multiplier * pow (dry_base, static_cast <float >(repeat_exp));
317445 candidates->data [i].logit -= penalty;
446+
447+ #ifdef DEBUG
448+ LLAMA_LOG_INFO (" Token %d: %s (Penalty: %.2f)" , next_token, detokenize_with_cache (next_token).c_str (), penalty);
449+ #endif
318450 break ;
319451 }
320452 }
321453 }
322454 }
455+
456+ #ifdef DEBUG
457+ LLAMA_LOG_INFO (" ]\n " );
458+ #endif
323459}
324460
325461void llama_sample_tail_free_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {