sampling : optimize samplers by reusing bucket sort #15665
Conversation
Logically this seems correct but I think it would be better to integrate the limit for top_p
into the bucket sort function. You could keep track of the probability content per bucket and only process as many as would be needed to reach the required threshold. Though quite honestly I think the top 256 tokens are going to be enough for all practical applications so this probably doesn't matter much.
```diff
 if (post_sampling) {
-    const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+    const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
```
In situations where the application requires the candidates to be sorted, using `common_sampler_get_candidates(smpl, true)` will perform the sorting for convenience.
Force-pushed from 37d3dd4 to 2efc7e4
Keeping the buffers to avoid allocations is probably overkill. If there is a significant overhead from creating new vectors on every call, I think it is more likely that the issue is the memory being initialized in the `resize` call.
src/llama-sampling.cpp (outdated)
```cpp
static void llama_token_data_array_sort_inplace(llama_token_data_array * cur_p, int k, llama_sort_data & buf) {
    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
        return a.logit > b.logit;
    };

    if (k <= 128) {
        std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
        return;
    }

    llama_token_data_array_sort(*cur_p, k, buf);

    std::memcpy(cur_p->data, buf.data.data(), k*sizeof(llama_token_data));
}
```
`cur_p->sorted = true` could be set here; currently it is done after every call to this function.
The last `memcpy` copying only the first `k` elements does not seem correct. If you do a partial sort, you still need to copy all the elements. If the intention is to reduce the size, then `cur_p->size` should be updated here.
Also, in C++ code `std::copy` may be preferred to `memcpy`, since it is type-safe.
One more thing: it was not obvious to me that the `k` parameter is used to do a partial sort. The name of the argument is not clear enough by itself, so this should be documented. Renaming the function to "partial_sort" could also help.
Force-pushed from d74a6ab to 6d2a38c
Thanks. I removed the helper buffers and updated the code as recommended.
Co-authored-by: Johannes Gäßler <[email protected]>
The major change here is that the `dist` sampler no longer sorts. A lot of the tests in `test-sampling` assume that the outputs are sorted, hence the many reorderings of the values there. If we don't make this change, then using a sampler chain such as the one recommended by OpenAI for `gpt-oss` (`top-p=1 + min-p=0 + top-k=0 (i.e. disabled)`) would result in very slow sorting of the full vocabulary in the `dist` sampler, because there is nothing to cut the low-probability tokens.

- `top-p` and `min-p` samplers
- `dist` sampler (255b070)
- `common_sampler_get_candidates()` can now explicitly return sorted candidates if requested via `bool do_sort`

`libllama` API changes:

- `llama_sampler_init_softmax()`
- the `dist` sampler created with `llama_sampler_init_dist()` will no longer sort the candidates. The old behaviour of implicitly sorting the candidates was not documented, so technically this is not a breaking change, but it's possible that user code assumed that the results will be sorted - therefore making a note.