@@ -12832,60 +12832,86 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
1283212832 }
1283312833}
1283412834
12835- void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
12836- // loop through each candidate
12837- for (size_t i = 0; i < candidates->size; ++i) {
12835+ void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
12836+ // sanity check
12837+ GGML_ASSERT(last_tokens_size > 0);
12838+
12839+ // get the last token
12840+ auto last_token = last_tokens[last_tokens_size - 1];
12841+
12842+ // if last token is part of the sequence breakers, skip whole sampler
12843+ if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
12844+ return;
12845+ }
1283812846
12839- // if our candidate itself is part of the sequence breakers, we don't apply the dry penalty
12840- if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) {
12847+ // create an unordered map of "next tokens" <-> max match length
12848+ std::unordered_map<llama_token, size_t> match_lengths;
12849+
12850+ // loop through each previous token (exclude the last token)
12851+ for (size_t i = 0; i < last_tokens_size - 1; ++i) {
12852+ // skip the compare token if it's not the same as the last token
12853+ if(last_tokens[i] != last_token) {
1284112854 continue;
1284212855 }
1284312856
12844- int max_match_length = 0;
12857+ // get the next token (i + 1 is always less than last_tokens_size)
12858+ auto next_token = last_tokens[i + 1];
1284512859
12846- // loop through each previous token
12847- for (size_t j = 0; j < last_token_size; ++j) {
12848- // if the current candidate is the same as the previous token
12849- if (candidates->data[i].id == last_tokens[j]) {
12850- // greedily match sequence backwards starting from the current position with the end of prev
12851- int match_length = 1;
12860+ // try to extend the match backwards (match length starts at 1 because last token is already matched)
12861+ size_t match_length = 1;
1285212862
12853- // loop through the previous tokens
12854- for(;; match_length++) {
12855- // if we have reached the start of our stored prev , break
12856- if(j - match_length > 0 ) break;
12863+ // loop through the previous tokens
12864+ for(;; match_length++) {
12865+ // if we have reached the start of our last tokens , break
12866+ if(i < match_length) break;
1285712867
12858- // this shouldn't happen because (j - match_length) should always be smaller than (size - match_length)
12859- // but let's check here to avoid the unexpected
12860- if(last_token_size - match_length < 0) break;
12868+ // compare token starts at our prev index, going backwards by match length
12869+ auto compare_token = last_tokens[i - match_length];
1286112870
12862- // compare token starts at our prev index , going backwards by match length
12863- auto compare_token = last_tokens[j - match_length];
12871+ // head token starts at the end of last tokens , going backwards by match length, minus 1 because we start at the last token itself
12872+ auto head_token = last_tokens[last_tokens_size - 1 - match_length];
1286412873
12865- // head token starts at the end of prev, going backwards by match length
12866- auto head_token = last_tokens[last_token_size - match_length];
12874+ // if compare token is part of the sequence breakers, break out of the match
12875+ if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
12876+ break;
1286712877
12868- // if compare token is part of the sequence breakers, break out of the match
12869- if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
12870- break;
12878+ // break out of the match if any tokens don't match
12879+ if(compare_token != head_token)
12880+ break;
12881+ }
1287112882
12872- // break out of the match if any tokens don't match
12873- if(compare_token != head_token)
12874- break;
12875- }
12883+ // Check if the next token exists in the map
12884+ auto it = match_lengths.find(next_token);
1287612885
12877- // update our max match length
12878- max_match_length = std::max(max_match_length, match_length);
12879- }
12886+ if (it == match_lengths.end()) {
12887+ // Key does not exist, insert the new value
12888+ match_lengths[next_token] = match_length;
12889+ } else {
12890+ // Key exists, update it with the max of the new value or the existing value
12891+ it->second = std::max(it->second, match_length);
1288012892 }
12893+ }
1288112894
12882- // apply penalties
12883- if(max_match_length > dry_allowed_length ) {
12884- // calculate the penalty
12885- float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length) ;
12895+ // apply penalties
12896+ for (const auto& pair : match_lengths ) {
12897+ auto next_token = pair.first;
12898+ auto match_length = pair.second ;
1288612899
12887- // apply the dry penalty
12888- candidates->data[i].logit -= penalty;
12900+ // if the match length is greater than our allowed length in config, we apply penalties
12901+ if(match_length > dry_allowed_length) {
12902+
12903+ // find our next token in the candidates->data
12904+ size_t i = 0;
12905+ for (; i < candidates->size; ++i) {
12906+ if (candidates->data[i].id == next_token) {
12907+ // calculate the penalty
12908+ float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
12909+
12910+ // apply the dry penalty
12911+ candidates->data[i].logit -= penalty;
12912+ break;
12913+ }
12914+ }
1288912915 }
1289012916 }
1289112917}
0 commit comments