@@ -12832,6 +12832,64 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
1283212832 }
1283312833}
1283412834
12835+ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
12836+ // loop through each candidate
12837+ for (size_t i = 0; i < candidates->size; ++i) {
12838+
12839+ // if our candidate itself is part of the sequence breakers, we don't apply the dry penalty
12840+ if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) {
12841+ continue;
12842+ }
12843+
12844+ int max_match_length = 0;
12845+
12846+ // loop through each previous token
12847+ for (size_t j = 0; j < last_token_size; ++j) {
12848+ // if the current candidate is the same as the previous token
12849+ if (candidates->data[i].id == last_tokens[j]) {
12850+ // greedily match sequence backwards starting from the current position with the end of prev
12851+ int match_length = 1;
12852+
12853+ // loop through the previous tokens
12854+ for(;; match_length++) {
12855+ // if we have reached the start of our stored prev, break
12856+ if(j - match_length > 0) break;
12857+
12858+ // this shouldn't happen because (j - match_length) should always be smaller than (size - match_length)
12859+ // but let's check here to avoid the unexpected
12860+ if(last_token_size - match_length < 0) break;
12861+
12862+ // compare token starts at our prev index, going backwards by match length
12863+ auto compare_token = last_tokens[j - match_length];
12864+
12865+ // head token starts at the end of prev, going backwards by match length
12866+ auto head_token = last_tokens[last_token_size - match_length];
12867+
12868+ // if compare token is part of the sequence breakers, break out of the match
12869+ if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
12870+ break;
12871+
12872+ // break out of the match if any tokens don't match
12873+ if(compare_token != head_token)
12874+ break;
12875+ }
12876+
12877+ // update our max match length
12878+ max_match_length = std::max(max_match_length, match_length);
12879+ }
12880+ }
12881+
12882+ // apply penalties
12883+ if(max_match_length > dry_allowed_length) {
12884+ // calculate the penalty
12885+ float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length);
12886+
12887+ // apply the dry penalty
12888+ candidates->data[i].logit -= penalty;
12889+ }
12890+ }
12891+ }
12892+
1283512893void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
1283612894 if (z >= 1.0f || candidates->size <= 2) {
1283712895 return;
0 commit comments