
Conversation

ggerganov (Member) commented Aug 13, 2025

ref #15082 (comment)

The server now creates checkpoints of the SWA memory in order to minimize the amount of context reprocessing. A SWA checkpoint captures the state of the cache (both the KV cells and the KV data). Only the SWA part is stored in the checkpoint, so its size is relatively small (proportional to the SWA window that the model uses).

By default, up to 3 checkpoints are kept per slot; this can be configured with --swa-checkpoints N.

A checkpoint is created upon finishing the processing of a prompt:

// make a checkpoint with the SWA memory
// checkpoints are needed only if we are not using "--swa-full"
if (llama_model_n_swa(model) > 0 && !params_base.swa_full) {
    if (slot.swa_checkpoints.size() >= SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT) {
        {
            // log the oldest checkpoint before erasing it
            const auto & cur = slot.swa_checkpoints.front();

            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
        }

        slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
    }

    const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

    auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
        /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
        /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
        /*.data    = */ std::vector<uint8_t>(swa_size),
    });

    llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

    SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
}

Checkpoints are created only if the --swa-full argument is not specified. With that argument, we can branch from any past position of the context (so checkpoints are unnecessary), but the drawback is that the SWA memory is much larger in that case.
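
To illustrate the reuse path, here is a small self-contained sketch of how a server might choose which checkpoint to restore for a new request: pick the one with the largest pos_max that does not extend past the common prefix between the cached tokens and the new prompt. The struct and function names below are hypothetical, not the actual llama-server code:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical checkpoint record, mirroring the fields stored by the server.
struct swa_checkpoint {
    int32_t pos_min;
    int32_t pos_max;
    std::vector<uint8_t> data; // serialized SWA state
};

// Return the index of the best checkpoint to restore: the one with the
// largest pos_max that does not exceed the common prompt prefix.
// Returns -1 when no checkpoint is usable (full reprocessing is needed).
static int pick_checkpoint(const std::vector<swa_checkpoint> & cps, int32_t common_prefix) {
    int best = -1;
    for (int i = 0; i < (int) cps.size(); ++i) {
        if (cps[i].pos_max > common_prefix) {
            continue; // checkpoint contains tokens that diverge from the new prompt
        }
        if (best == -1 || cps[i].pos_max > cps[best].pos_max) {
            best = i;
        }
    }
    return best;
}
```

Restoring the selected checkpoint still requires reprocessing the tokens between its pos_max and the end of the new prompt, but that is typically far cheaper than starting from position 0.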

libllama API changes

  • Add llama_state_seq_get_size_ext()
  • Add llama_state_seq_get_data_ext()
  • Add llama_state_seq_set_data_ext()
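
The new functions follow the existing two-call convention (query the size, then copy into a caller-provided buffer), with an extra flags argument such as LLAMA_STATE_SEQ_FLAGS_SWA_ONLY. Below is a self-contained sketch of that pattern using stub stand-ins for the real functions; the real ones take a struct llama_context * and serialize the actual KV cache:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Stub "context": just the serialized SWA state of one sequence.
// Stand-in for a real struct llama_context.
struct stub_ctx {
    std::vector<uint8_t> swa_state;
};

const uint32_t STATE_SEQ_FLAGS_SWA_ONLY = 1; // mirrors LLAMA_STATE_SEQ_FLAGS_SWA_ONLY

// Stand-in for llama_state_seq_get_size_ext: exact size needed for the copy.
size_t state_seq_get_size_ext(const stub_ctx & ctx, int /*seq_id*/, uint32_t /*flags*/) {
    return ctx.swa_state.size();
}

// Stand-in for llama_state_seq_get_data_ext: copy state into dst.
size_t state_seq_get_data_ext(const stub_ctx & ctx, uint8_t * dst, size_t size, int /*seq_id*/, uint32_t /*flags*/) {
    if (size < ctx.swa_state.size()) {
        return 0; // buffer too small
    }
    std::memcpy(dst, ctx.swa_state.data(), ctx.swa_state.size());
    return ctx.swa_state.size();
}

// Stand-in for llama_state_seq_set_data_ext: restore state from src.
size_t state_seq_set_data_ext(stub_ctx & ctx, const uint8_t * src, size_t size, int /*seq_id*/, uint32_t /*flags*/) {
    ctx.swa_state.assign(src, src + size);
    return size; // positive return means success
}

// The save half of the round trip, as the server does when creating a checkpoint.
std::vector<uint8_t> save_checkpoint(const stub_ctx & ctx, int seq_id) {
    const size_t n = state_seq_get_size_ext(ctx, seq_id, STATE_SEQ_FLAGS_SWA_ONLY);
    std::vector<uint8_t> buf(n);
    state_seq_get_data_ext(ctx, buf.data(), n, seq_id, STATE_SEQ_FLAGS_SWA_ONLY);
    return buf;
}
```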

TODO:

  • Update libllama interface to specify SWA and non-SWA state saving
  • Sanity-checks that the SWA checkpoint is valid
  • Clean-up llama-server

ggerganov (Member, Author)

@slaren Let me know if this works on your end. I'll look to clean this up and prepare for merge.

Comment on lines 1960 to 1963
if (!hparams.is_swa(il)) {
continue;
}

ggerganov (Member, Author)

Temporary hack to store just the SWA data

slaren (Member) commented Aug 13, 2025

I have been trying this for a while with the 20B and 120B models, and it seems to work as expected. It definitely helps a lot: instead of several minutes reprocessing the entire context before every interaction, it takes only a few seconds before it starts generating the response. This dramatically improves the usability of the 120B model on systems with limited VRAM.

ggerganov (Member, Author)

Suggestions how to update the llama_state_seq_... API to support this use case are welcome:

llama.cpp/include/llama.h

Lines 835 to 857 in 0b64ee5

    // Get the exact size needed to copy the state of a single sequence
    LLAMA_API size_t llama_state_seq_get_size(
            struct llama_context * ctx,
                    llama_seq_id   seq_id);

    // Copy the state of a single sequence into the specified buffer
    LLAMA_API size_t llama_state_seq_get_data(
            struct llama_context * ctx,
                         uint8_t * dst,
                          size_t   size,
                    llama_seq_id   seq_id);

    // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
    // Returns:
    //  - Positive: Ok
    //  - Zero: Failed to load
    LLAMA_API size_t llama_state_seq_set_data(
            struct llama_context * ctx,
                   const uint8_t * src,
                          size_t   size,
                    llama_seq_id   dest_seq_id);

I'll wrap this up tomorrow.

slaren (Member) commented Aug 13, 2025

> Suggestions on how to update the llama_state_seq_... API to support this use case are welcome

I can't think of anything better than just adding a flag parameter to use only the SWA layers; this use case is too specific to generalize. It could be a generic bit-flags parameter that can be extended with additional flags in the future if necessary.
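
A generic bit-flags parameter along these lines might look like the following sketch. The typedef and the reserved future bits are illustrative; only the SWA-only flag appears in this PR:

```cpp
#include <cassert>
#include <cstdint>

// Extensible bit-flags type for the llama_state_seq_*_ext functions (illustrative).
typedef uint32_t llama_state_seq_flags;

enum : uint32_t {
    LLAMA_STATE_SEQ_FLAGS_SWA_ONLY = 1u << 0, // save/load only the SWA layers
    // future options would claim the next bits: 1u << 1, 1u << 2, ...
};

// Callers combine flags with bitwise-or; implementations test them with bitwise-and.
inline bool state_flag_set(llama_state_seq_flags flags, llama_state_seq_flags f) {
    return (flags & f) != 0;
}
```

Passing 0 would keep the previous behavior (full per-sequence state), so existing callers only need to add one argument.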

@ggerganov ggerganov marked this pull request as ready for review August 14, 2025 08:10
@ggerganov ggerganov requested a review from ngxson as a code owner August 14, 2025 08:10
ggerganov (Member, Author)

This is ready for review and testing

@ggerganov ggerganov merged commit d32e03f into master Aug 14, 2025
45 of 47 checks passed
@ggerganov ggerganov deleted the gg/server-swa-checkpoints branch August 14, 2025 11:59
ddh0 (Contributor) commented Aug 15, 2025

Could the changes in this PR also be applied to fix #14625? (Jamba)

ggerganov (Member, Author)

I think so. Likely the change is as simple as respecting the SWA flag in the hybrid cache implementation.
