server : add SWA checkpoints #15293
Conversation
@slaren Let me know if this works on your end. I'll look to clean this up and prepare for merge.
src/llama-kv-cache-unified.cpp (outdated)

    if (!hparams.is_swa(il)) {
        continue;
    }
Temporary hack to store just the SWA data
I have been trying this for a while with the 20B and 120B models, and it seems to work as expected. It definitely helps a lot: instead of several minutes reprocessing the entire context before every interaction, it now takes only a few seconds before it starts generating the response. This dramatically improves the usability of the 120B model on systems with limited VRAM.
Suggestions on how to update the following are welcome: Lines 835 to 857 in 0b64ee5
I'll wrap this up tomorrow.
I can't think of anything better than just adding a flag parameter to use only the SWA layers; this use case is too specific to generalize. It could be a generic bit-flags parameter that can be extended with additional flags in the future if necessary. A rough illustration of what such a parameter might look like follows below.
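For illustration only, a bit-flags parameter along these lines could be declared roughly as follows. The type and flag names here (`llama_state_seq_flags`, `LLAMA_STATE_SEQ_FLAGS_SWA_ONLY`) are assumptions for the sketch, not necessarily what was merged:

```cpp
#include <cstdint>

// Hypothetical bit-flags type for the per-sequence state API.
typedef uint32_t llama_state_seq_flags;

enum llama_state_seq_flag : uint32_t {
    LLAMA_STATE_SEQ_FLAGS_SWA_ONLY = 1u << 0, // save/load only the SWA layers
    // room for additional flags (1u << 1, 1u << 2, ...) if needed later
};
```

Under this scheme, callers that want the full sequence state would pass 0, while the server's checkpoint path would pass the SWA-only flag.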
This is ready for review and testing.
Could the changes in this PR also be applied to fix #14625? (Jamba)
I think so. Likely the change is as simple as respecting the SWA flag in the hybrid cache implementation.
ref #15082 (comment)
The server now makes checkpoints of the SWA memory in order to minimize the amount of context reprocessing. An SWA checkpoint captures the state of the cache (both the KV cells and the KV data). Only the SWA part is stored in the checkpoint, so its size is relatively small (proportional to the SWA window that the model uses).
The number of checkpoints per slot is 3 by default and can be configured with --swa-checkpoints N. A checkpoint is created upon finishing the processing of a prompt:

llama.cpp/tools/server/server.cpp
Lines 3579 to 3604 in e7d2ecd
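The referenced server code is not reproduced here; the following is a minimal sketch of the idea. It assumes the new `*_ext` state functions mirror the existing `llama_state_seq_*` calls with an extra flags argument, and the `swa_checkpoint` struct, `save_swa_checkpoint` helper, and `LLAMA_STATE_SEQ_FLAGS_SWA_ONLY` flag are illustrative names, not the actual server implementation:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

#include "llama.h"

// Hypothetical per-slot checkpoint: the SWA-only state plus the position it covers.
struct swa_checkpoint {
    llama_pos            pos;
    std::vector<uint8_t> data;
};

// Sketch: save an SWA-only checkpoint for sequence `seq_id`, keeping at most
// `n_max` checkpoints per slot (the value of --swa-checkpoints), oldest dropped first.
static void save_swa_checkpoint(llama_context * ctx, llama_seq_id seq_id, llama_pos pos,
                                std::vector<swa_checkpoint> & checkpoints, size_t n_max) {
    if (n_max == 0) {
        return;
    }
    if (checkpoints.size() >= n_max) {
        checkpoints.erase(checkpoints.begin()); // evict the oldest checkpoint
    }

    swa_checkpoint cp;
    cp.pos = pos; // position up to which the saved SWA state is valid

    // query the size of the SWA-only state and copy it out (assumed signatures)
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
    cp.data.resize(size);
    llama_state_seq_get_data_ext(ctx, cp.data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

    checkpoints.push_back(std::move(cp));
}
```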
Checkpoints are created only if the --swa-full argument is not specified. If that argument is used, we can branch from any past position of the context (so there is no need for checkpoints), but the drawback is that the SWA memory size is much larger in that case.

libllama API changes:
- llama_state_seq_get_size_ext()
- llama_state_seq_get_data_ext()
- llama_state_seq_set_data_ext()
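To show how these functions fit together on the restore side, here is a hedged sketch of picking and loading a checkpoint when the server needs to rewind to an earlier position. The exact signature of `llama_state_seq_set_data_ext()` is assumed (same as `llama_state_seq_set_data()` plus a flags argument), and `swa_checkpoint`, `restore_swa_checkpoint`, and `LLAMA_STATE_SEQ_FLAGS_SWA_ONLY` are illustrative names carried over from the save sketch above:

```cpp
#include <cstdint>
#include <vector>

#include "llama.h"

// Same hypothetical checkpoint struct as in the save sketch.
struct swa_checkpoint {
    llama_pos            pos;
    std::vector<uint8_t> data;
};

// Sketch: before continuing from `target_pos`, restore the most recent checkpoint
// whose position does not exceed it, so only the tokens between cp.pos and
// target_pos have to be reprocessed.
// Returns the restored position, or -1 if no checkpoint is usable.
static llama_pos restore_swa_checkpoint(llama_context * ctx, llama_seq_id seq_id,
                                        const std::vector<swa_checkpoint> & checkpoints,
                                        llama_pos target_pos) {
    const swa_checkpoint * best = nullptr;
    for (const auto & cp : checkpoints) {
        if (cp.pos <= target_pos && (best == nullptr || cp.pos > best->pos)) {
            best = &cp;
        }
    }
    if (best == nullptr) {
        return -1; // nothing to restore - the context has to be reprocessed from the start
    }

    // assumed signature: (ctx, src, size, seq_id, flags)
    llama_state_seq_set_data_ext(ctx, best->data.data(), best->data.size(), seq_id,
                                 LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

    return best->pos;
}
```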
TODO:
- libllama interface to specify SWA and non-SWA state saving
- llama-server