From 96db966b1e1ca6c5bf39404890954fc54ec165e9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 13 Aug 2025 16:47:25 +0300 Subject: [PATCH 1/8] server : add SWA checkpoints ggml-ci --- src/llama-kv-cache-unified.cpp | 24 ++++++++++++ tools/server/server.cpp | 68 ++++++++++++++++++++++++++++++++-- 2 files changed, 88 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 88c88552aaad0..ffddfbfcf9cac 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -1957,6 +1957,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ for (const auto & layer : layers) { const uint32_t il = layer.il; + if (!hparams.is_swa(il)) { + continue; + } + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); auto * k = layer.k_stream[cr.strm]; @@ -1981,6 +1985,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ for (const auto & layer : layers) { const uint32_t il = layer.il; + if (!hparams.is_swa(il)) { + continue; + } + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; @@ -2007,6 +2015,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ for (const auto & layer : layers) { const uint32_t il = layer.il; + if (!hparams.is_swa(il)) { + continue; + } + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; @@ -2162,6 +2174,10 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm for (const auto & layer : layers) { const uint32_t il = layer.il; + if (!hparams.is_swa(il)) { + continue; + } + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); auto * k = layer.k_stream[strm]; @@ -2194,6 +2210,10 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm for (const auto & layer : layers) { const uint32_t il = layer.il; + if (!hparams.is_swa(il)) { + continue; + } + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; @@ -2226,6 +2246,10 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm for (const auto & layer : layers) { const uint32_t il = layer.il; + if (!hparams.is_swa(il)) { + continue; + } + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index f549cda476657..fea363dad8101 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -692,6 +692,13 @@ struct completion_token_output { } }; +struct swa_checkpoint { + std::vector data; + + llama_pos pos_min; + llama_pos pos_max; +}; + struct server_task_result_cmpl_final : server_task_result { int index = 0; @@ -1336,6 +1343,8 @@ struct server_slot { std::vector generated_token_probs; + std::vector swa_checkpoints; + bool has_next_token = true; bool has_new_line = false; bool truncated = false; @@ -3300,10 +3309,42 @@ struct server_context { const auto n_swa = llama_model_n_swa(model); if (pos_min > std::max(0, slot.n_past - n_swa)) { - SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); - SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", - "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - slot.n_past = 0; + // search for a SWA checkpoint + 
int ic = -1; + int np = std::numeric_limits::max(); + for (int i = 0; i < (int) slot.swa_checkpoints.size(); i++) { + const auto & cur = slot.swa_checkpoints[i]; + if (cur.pos_min <= std::max(0, slot.n_past - n_swa)) { + const int p = std::max(0, slot.n_past - cur.pos_max); + + if (p < np) { + ic = i; + np = p; + } + } + } + + if (ic == -1) { + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", + "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + slot.n_past = 0; + + slot.swa_checkpoints.clear(); + } else { + // erase all checkpoints after the one we are using + slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + ic + 1, slot.swa_checkpoints.end()); + + // restore the checkpoint + const auto & cur = slot.swa_checkpoints[ic]; + + const size_t swa_size = cur.data.size(); + llama_state_seq_set_data(ctx, cur.data.data(), swa_size, slot.id); + + slot.n_past = std::min(slot.n_past, cur.pos_max); + + SLT_WRN(slot, "prompt swa checkpoint restored, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); + } } } } @@ -3517,6 +3558,25 @@ struct server_context { // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; + + // make a checkpoint + if (llama_model_n_swa(model) > 0) { + if (slot.swa_checkpoints.size() > 8) { + slot.swa_checkpoints.erase(slot.swa_checkpoints.begin()); + } + + auto & cur = slot.swa_checkpoints.emplace_back(); + + cur.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + cur.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + + const size_t swa_size = llama_state_seq_get_size(ctx, slot.id); + cur.data.resize(swa_size); + + llama_state_seq_get_data(ctx, cur.data.data(), swa_size, slot.id); + + SLT_WRN(slot, "prompt swa checkpoint, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); + } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots } From 487b9223a48286013129acb51a9b912bbb837a7c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 13 Aug 2025 20:39:41 +0300 Subject: [PATCH 2/8] cont : server clean-up --- tools/server/server.cpp | 84 ++++++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 35 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index fea363dad8101..950b8088457a0 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -31,6 +31,8 @@ #include #include +#define SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT 3 + using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; @@ -693,10 +695,10 @@ struct completion_token_output { }; struct swa_checkpoint { - std::vector data; - llama_pos pos_min; llama_pos pos_max; + + std::vector data; }; struct server_task_result_cmpl_final : server_task_result { @@ -3300,6 +3302,8 @@ struct server_context { slot.n_past = 0; } + const auto n_swa = llama_model_n_swa(model); + if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { @@ -3307,43 +3311,47 @@ struct server_context { GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: 
https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } - const auto n_swa = llama_model_n_swa(model); - if (pos_min > std::max(0, slot.n_past - n_swa)) { + const auto pos_min_thold = std::max(0, slot.n_past - n_swa); + + if (pos_min > pos_min_thold) { // search for a SWA checkpoint - int ic = -1; - int np = std::numeric_limits::max(); - for (int i = 0; i < (int) slot.swa_checkpoints.size(); i++) { - const auto & cur = slot.swa_checkpoints[i]; - if (cur.pos_min <= std::max(0, slot.n_past - n_swa)) { - const int p = std::max(0, slot.n_past - cur.pos_max); - - if (p < np) { - ic = i; - np = p; - } + auto it = std::find_if( + slot.swa_checkpoints.rbegin(), + slot.swa_checkpoints.rend(), + [&](const auto & cur) { + return cur.pos_min <= pos_min_thold; } - } + ); - if (ic == -1) { + if (it == slot.swa_checkpoints.rend()) { SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - slot.n_past = 0; + slot.n_past = 0; slot.swa_checkpoints.clear(); } else { - // erase all checkpoints after the one we are using - slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + ic + 1, slot.swa_checkpoints.end()); - // restore the checkpoint - const auto & cur = slot.swa_checkpoints[ic]; + const size_t swa_size = it->data.size(); + llama_state_seq_set_data(ctx, it->data.data(), swa_size, slot.id); - const size_t swa_size = cur.data.size(); - llama_state_seq_set_data(ctx, cur.data.data(), swa_size, slot.id); + slot.n_past = std::min(slot.n_past, it->pos_max); - slot.n_past = std::min(slot.n_past, cur.pos_max); + SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024); + } + } + } - SLT_WRN(slot, "prompt swa checkpoint restored, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); + if (n_swa > 0) { + const auto pos_min_thold = std::max(0, slot.n_past - n_swa); + + // erase any checkpoints with pos_min > pos_min_thold + for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) { + const auto & cur = slot.swa_checkpoints[i]; + if (cur.pos_min > pos_min_thold) { + slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i); + + SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } } } @@ -3559,23 +3567,29 @@ struct server_context { // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; - // make a checkpoint + // make a checkpoint with the SWA memory if (llama_model_n_swa(model) > 0) { - if (slot.swa_checkpoints.size() > 8) { - slot.swa_checkpoints.erase(slot.swa_checkpoints.begin()); - } + if (slot.swa_checkpoints.size() >= SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT) { + { + const auto & cur = slot.swa_checkpoints.back(); - auto & cur = slot.swa_checkpoints.emplace_back(); + SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + } - cur.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - cur.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + 
slot.swa_checkpoints.erase(slot.swa_checkpoints.begin()); + } const size_t swa_size = llama_state_seq_get_size(ctx, slot.id); - cur.data.resize(swa_size); + + auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{ + /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id), + /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id), + /*.data = */ std::vector(swa_size), + }); llama_state_seq_get_data(ctx, cur.data.data(), swa_size, slot.id); - SLT_WRN(slot, "prompt swa checkpoint, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); + SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %f MiB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots From 5b0d207daf3768975dd2aefdf38cfa96b5c0c905 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 Aug 2025 09:31:49 +0300 Subject: [PATCH 3/8] server : handle state restore fails --- tools/server/server.cpp | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 950b8088457a0..8513a3431cd59 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3314,8 +3314,10 @@ struct server_context { const auto pos_min_thold = std::max(0, slot.n_past - n_swa); if (pos_min > pos_min_thold) { + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); + // search for a SWA checkpoint - auto it = std::find_if( + const auto it = std::find_if( slot.swa_checkpoints.rbegin(), slot.swa_checkpoints.rend(), [&](const auto & cur) { @@ -3323,21 +3325,29 @@ struct server_context { } ); - if (it == slot.swa_checkpoints.rend()) { - SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); - SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", - "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + bool do_reset = it == slot.swa_checkpoints.rend(); - slot.n_past = 0; - slot.swa_checkpoints.clear(); - } else { + if (!do_reset) { // restore the checkpoint const size_t swa_size = it->data.size(); - llama_state_seq_set_data(ctx, it->data.data(), swa_size, slot.id); + const size_t n = llama_state_seq_set_data(ctx, it->data.data(), swa_size, slot.id); + + if (n != swa_size) { + SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024); + do_reset = true; + } else { + slot.n_past = std::min(slot.n_past, it->pos_max); - slot.n_past = std::min(slot.n_past, it->pos_max); + SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024); + } + } - SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024); + if (do_reset) { + SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", + "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + + slot.n_past = 0; + slot.swa_checkpoints.clear(); } } } From 
025af1541c9101230a1b9e3b3a2d8a36d7ee1e81 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 Aug 2025 10:15:39 +0300 Subject: [PATCH 4/8] llama : add extended llama_state_seq_ API --- include/llama.h | 24 ++++++++++++++++ src/llama-context.cpp | 44 ++++++++++++++++++----------- src/llama-context.h | 10 +++---- src/llama-kv-cache-unified-iswa.cpp | 18 ++++++++---- src/llama-kv-cache-unified-iswa.h | 4 +-- src/llama-kv-cache-unified.cpp | 32 ++++----------------- src/llama-kv-cache-unified.h | 4 +-- src/llama-memory-hybrid.cpp | 8 ++++-- src/llama-memory-hybrid.h | 4 +-- src/llama-memory-recurrent.cpp | 8 ++++-- src/llama-memory-recurrent.h | 4 +-- src/llama-memory.h | 4 +-- 12 files changed, 97 insertions(+), 67 deletions(-) diff --git a/include/llama.h b/include/llama.h index 545e957e5f52b..f5df9f920f77e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -870,6 +870,30 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +#define LLAMA_STATE_SEQ_FLAGS_NONE 0 +#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 + + typedef uint32_t llama_state_seq_flags; + + LLAMA_API size_t llama_state_seq_get_size_ext( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + LLAMA_API size_t llama_state_seq_get_data_ext( + struct llama_context * ctx, + uint8_t * dst, + size_t size, + llama_seq_id seq_id, + llama_state_seq_flags flags); + + LLAMA_API size_t llama_state_seq_set_data_ext( + struct llama_context * ctx, + const uint8_t * src, + size_t size, + llama_seq_id dest_seq_id, + llama_state_seq_flags flags); + // // Decoding // diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 26a5cf9c3f8db..ca7938549ac04 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) { } } -size_t llama_context::state_seq_get_size(llama_seq_id seq_id) { +size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { llama_io_write_dummy io; try { - return state_seq_write_data(io, seq_id); + return state_seq_write_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); return 0; } } -size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { +size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { llama_io_write_buffer io(dst, size); try { - return state_seq_write_data(io, seq_id); + return state_seq_write_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; } } -size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { +size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { llama_io_read_buffer io(src, size); try { - return state_seq_read_data(io, seq_id); + return state_seq_read_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file { const size_t state_size = file.size() - file.tell(); llama_io_read_file io(&file); - const size_t nread = state_seq_read_data(io, seq_id); + const size_t nread = state_seq_read_data(io, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); if 
(!nread) { LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); return 0; @@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file // save the context state using stream saving llama_io_write_file io(&file); - state_seq_write_data(io, seq_id); + state_seq_write_data(io, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); const size_t res = file.tell(); GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); @@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { return io.n_bytes(); } -size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { +size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { GGML_UNUSED(seq_id); if (memory) { - memory->state_write(io, seq_id); + memory->state_write(io, seq_id, flags); } return io.n_bytes(); } -size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { +size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { GGML_UNUSED(seq_id); if (memory) { - memory->state_read(io, seq_id); + memory->state_read(io, seq_id, flags); } return io.n_bytes(); @@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const } size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { - return ctx->state_seq_get_size(seq_id); + return llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); } size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { + return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); +} + +size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { + return llama_state_seq_set_data_ext(ctx, src, size, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); +} + +size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { + return ctx->state_seq_get_size(seq_id, flags); +} + +size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { ctx->synchronize(); - return ctx->state_seq_get_data(seq_id, dst, size); + return ctx->state_seq_get_data(seq_id, dst, size, flags); } -size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { +size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { ctx->synchronize(); - return ctx->state_seq_set_data(seq_id, src, size); + return ctx->state_seq_set_data(seq_id, src, size, flags); } size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { diff --git a/src/llama-context.h b/src/llama-context.h index 25c143d56dfb2..cf3078d2bff8a 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -111,9 +111,9 @@ struct llama_context { size_t state_get_data( uint8_t * dst, size_t size); size_t state_set_data(const uint8_t * src, size_t size); - size_t state_seq_get_size(llama_seq_id seq_id); - size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size); - size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size); + size_t state_seq_get_size(llama_seq_id seq_id, 
llama_state_seq_flags flags); + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags); + size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags); bool state_load_file( const char * filepath, @@ -212,8 +212,8 @@ struct llama_context { size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); - size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id); - size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id); + size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags); + size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags); // // members diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index 01d27fb4db9b1..1e363fff2a554 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -194,14 +194,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const { return kv_base->get_size() == kv_swa->get_size(); } -void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - kv_base->state_write(io, seq_id); - kv_swa ->state_write(io, seq_id); +void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + kv_base->state_write(io, seq_id, flags); + } + + kv_swa->state_write(io, seq_id, flags); } -void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - kv_base->state_read(io, seq_id); - kv_swa ->state_read(io, seq_id); +void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + kv_base->state_read(io, seq_id, flags); + } + + kv_swa->state_read(io, seq_id, flags); } llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index d2650dadd3595..9faa36f189d3f 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -56,8 +56,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) override; // // llama_kv_cache_unified_iswa specific API diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index ffddfbfcf9cac..478ebffac0f63 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -1828,7 +1828,9 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { return false; } -void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { +void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + GGML_UNUSED(flags); + io.write(&n_stream, sizeof(n_stream)); for (uint32_t s = 0; s < n_stream; ++s) { @@ -1879,7 +1881,9 @@ void 
llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq } } -void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { +void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + GGML_UNUSED(flags); + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); uint32_t n_stream_cur; @@ -1957,10 +1961,6 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ for (const auto & layer : layers) { const uint32_t il = layer.il; - if (!hparams.is_swa(il)) { - continue; - } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); auto * k = layer.k_stream[cr.strm]; @@ -1985,10 +1985,6 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ for (const auto & layer : layers) { const uint32_t il = layer.il; - if (!hparams.is_swa(il)) { - continue; - } - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; @@ -2015,10 +2011,6 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ for (const auto & layer : layers) { const uint32_t il = layer.il; - if (!hparams.is_swa(il)) { - continue; - } - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; @@ -2174,10 +2166,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm for (const auto & layer : layers) { const uint32_t il = layer.il; - if (!hparams.is_swa(il)) { - continue; - } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); auto * k = layer.k_stream[strm]; @@ -2210,10 +2198,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm for (const auto & layer : layers) { const uint32_t il = layer.il; - if (!hparams.is_swa(il)) { - continue; - } - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; @@ -2246,10 +2230,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm for (const auto & layer : layers) { const uint32_t il = layer.il; - if (!hparams.is_swa(il)) { - continue; - } - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 342a675962e2a..15020b76dd966 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -136,8 +136,8 @@ class llama_kv_cache_unified : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) override; // // llama_kv_cache_unified specific API diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index e98b4e3546959..cbeeb21344ece 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -165,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const { return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); } -void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { +void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, 
llama_state_seq_flags flags) const { + GGML_UNUSED(flags); + mem_attn->state_write(io, seq_id); mem_recr->state_write(io, seq_id); } -void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) { +void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + GGML_UNUSED(flags); + mem_attn->state_read(io, seq_id); mem_recr->state_read(io, seq_id); } diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index c2d56cd541594..9a02a9d73f04f 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -74,8 +74,8 @@ class llama_memory_hybrid : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) override; // // llama_memory_hybrid specific API diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index c0c2ec084dc14..849675c418891 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -680,7 +680,9 @@ size_t llama_memory_recurrent::size_s_bytes() const { return size_s_bytes; } -void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { +void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + GGML_UNUSED(flags); + std::vector> cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -718,7 +720,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq state_write_data(io, cell_ranges); } -void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { +void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + GGML_UNUSED(flags); + uint32_t cell_count; io.read_to(&cell_count, sizeof(cell_count)); diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 4d094f9a05788..7dae07e67b48c 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -63,8 +63,8 @@ class llama_memory_recurrent : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) override; uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) uint32_t size = 0; // total number of cells, shared across all sequences diff --git a/src/llama-memory.h b/src/llama-memory.h index e8ba336e8525d..37865b3058f3d 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -104,8 +104,8 @@ struct llama_memory_i { // state write/read // - virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; - virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, 
llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) = 0; }; using llama_memory_ptr = std::unique_ptr; From e7d2ecdf2ae86c660aca445a1b69aa9dc21fa833 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 Aug 2025 10:16:28 +0300 Subject: [PATCH 5/8] server : do not make checkpoints if --swa-full ggml-ci --- tools/server/server.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 8513a3431cd59..c5f5a943f4db3 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3330,7 +3330,7 @@ struct server_context { if (!do_reset) { // restore the checkpoint const size_t swa_size = it->data.size(); - const size_t n = llama_state_seq_set_data(ctx, it->data.data(), swa_size, slot.id); + const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY); if (n != swa_size) { SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024); @@ -3361,7 +3361,7 @@ struct server_context { if (cur.pos_min > pos_min_thold) { slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i); - SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } } } @@ -3578,18 +3578,19 @@ struct server_context { slot.state = SLOT_STATE_GENERATING; // make a checkpoint with the SWA memory - if (llama_model_n_swa(model) > 0) { + // checkpoints are needed only if we are not using "--swa-full" + if (llama_model_n_swa(model) > 0 && !params_base.swa_full) { if (slot.swa_checkpoints.size() >= SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT) { { const auto & cur = slot.swa_checkpoints.back(); - SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } slot.swa_checkpoints.erase(slot.swa_checkpoints.begin()); } - const size_t swa_size = llama_state_seq_get_size(ctx, slot.id); + const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY); auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{ /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id), @@ -3597,9 +3598,9 @@ struct server_context { /*.data = */ std::vector(swa_size), }); - llama_state_seq_get_data(ctx, cur.data.data(), swa_size, slot.id); + llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY); - SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %f MiB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); + SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots From c2b5cfb7739879c509c373359b877a3b7584a2a1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 Aug 2025 13:35:25 +0300 
Subject: [PATCH 6/8] llama : remove flags value for NONE --- include/llama.h | 1 - src/llama-context.cpp | 10 +++++----- src/llama-kv-cache-unified-iswa.h | 4 ++-- src/llama-kv-cache-unified.h | 4 ++-- src/llama-memory-hybrid.h | 4 ++-- src/llama-memory-recurrent.h | 4 ++-- src/llama-memory.h | 4 ++-- 7 files changed, 15 insertions(+), 16 deletions(-) diff --git a/include/llama.h b/include/llama.h index f5df9f920f77e..571a25df70b71 100644 --- a/include/llama.h +++ b/include/llama.h @@ -870,7 +870,6 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); -#define LLAMA_STATE_SEQ_FLAGS_NONE 0 #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 typedef uint32_t llama_state_seq_flags; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ca7938549ac04..d3122501d0969 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file { const size_t state_size = file.size() - file.tell(); llama_io_read_file io(&file); - const size_t nread = state_seq_read_data(io, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); + const size_t nread = state_seq_read_data(io, seq_id, 0); if (!nread) { LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); return 0; @@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file // save the context state using stream saving llama_io_write_file io(&file); - state_seq_write_data(io, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); + state_seq_write_data(io, seq_id, 0); const size_t res = file.tell(); GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); @@ -2801,15 +2801,15 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const } size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { - return llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); + return llama_state_seq_get_size_ext(ctx, seq_id, 0); } size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { - return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); + return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0); } size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { - return llama_state_seq_set_data_ext(ctx, src, size, seq_id, LLAMA_STATE_SEQ_FLAGS_NONE); + return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0); } size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 9faa36f189d3f..7bc4df718d342 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -56,8 +56,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // // llama_kv_cache_unified_iswa specific API diff --git a/src/llama-kv-cache-unified.h 
b/src/llama-kv-cache-unified.h index 15020b76dd966..07a7c9e4e46a1 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -136,8 +136,8 @@ class llama_kv_cache_unified : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // // llama_kv_cache_unified specific API diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index 9a02a9d73f04f..acdbc26bfb624 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -74,8 +74,8 @@ class llama_memory_hybrid : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // // llama_memory_hybrid specific API diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 7dae07e67b48c..95c617b2c94bd 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -63,8 +63,8 @@ class llama_memory_recurrent : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) uint32_t size = 0; // total number of cells, shared across all sequences diff --git a/src/llama-memory.h b/src/llama-memory.h index 37865b3058f3d..171d312cc99d9 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -104,8 +104,8 @@ struct llama_memory_i { // state write/read // - virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) const = 0; - virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_NONE) = 0; + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0; }; using llama_memory_ptr = std::unique_ptr; From 52b775edf30106dca06216f612681cef682b9248 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 Aug 2025 13:50:20 +0300 Subject: [PATCH 7/8] server : configure number of SWA checkpoints with CLI arg ggml-ci --- common/arg.cpp | 
8 ++++++++ common/common.h | 11 ++++++----- tools/server/server.cpp | 17 +++++++++++------ 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 4e4c52b5f8748..776050e9f58e8 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1506,6 +1506,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.swa_full = true; } ).set_env("LLAMA_ARG_SWA_FULL")); + add_opt(common_arg( + {"--swa-checkpoints"}, "N", + string_format("max number of SWA checkpoints per slot to create (default: %d)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints), + [](common_params & params, int value) { + params.n_swa_checkpoints = value; + } + ).set_env("LLAMA_ARG_SWA_CHECKPOINTS")); add_opt(common_arg( {"--kv-unified", "-kvu"}, string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n" diff --git a/common/common.h b/common/common.h index c09509b669e54..111644b71108f 100644 --- a/common/common.h +++ b/common/common.h @@ -385,11 +385,12 @@ struct common_params { std::string cls_sep = "\t"; // separator of classification sequences // server params - int32_t port = 8080; // server listens on this network port - int32_t timeout_read = 600; // http read timeout in seconds - int32_t timeout_write = timeout_read; // http write timeout in seconds - int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) - int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting + int32_t port = 8080; // server listens on this network port + int32_t timeout_read = 600; // http read timeout in seconds + int32_t timeout_write = timeout_read; // http write timeout in seconds + int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) + int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting + int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT diff --git a/tools/server/server.cpp b/tools/server/server.cpp index c5f5a943f4db3..9979828dace3e 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -31,8 +31,6 @@ #include #include -#define SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT 3 - using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; @@ -3579,12 +3577,13 @@ struct server_context { // make a checkpoint with the SWA memory // checkpoints are needed only if we are not using "--swa-full" - if (llama_model_n_swa(model) > 0 && !params_base.swa_full) { - if (slot.swa_checkpoints.size() >= SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT) { + if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) { + if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) { { const auto & cur = slot.swa_checkpoints.back(); - SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } slot.swa_checkpoints.erase(slot.swa_checkpoints.begin()); @@ -3600,7 +3599,13 @@ struct server_context { llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY); - SLT_WRN(slot, "SWA checkpoint create, 
pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); + float size_total = 0.0f; + for (const auto & checkpoint : slot.swa_checkpoints) { + size_total += (float) checkpoint.data.size() / 1024 / 1024; + } + + SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n", + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots From 3d08a6562480a9e7073d8de26bebc528c0581c02 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 Aug 2025 14:08:05 +0300 Subject: [PATCH 8/8] args : fix scope of new argument --- common/arg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index 776050e9f58e8..da44e79d53ce3 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1513,7 +1513,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.n_swa_checkpoints = value; } - ).set_env("LLAMA_ARG_SWA_CHECKPOINTS")); + ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--kv-unified", "-kvu"}, string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
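
Usage sketch (illustrative, not part of the patches above): the extended llama_state_seq_*_ext API from PATCH 4, combined with LLAMA_STATE_SEQ_FLAGS_SWA_ONLY, lets client code snapshot and restore just the SWA portion of a sequence's KV state, the same way server.cpp does above. The helper struct and function names below (seq_swa_snapshot, seq_swa_snapshot_take, seq_swa_snapshot_restore) are hypothetical and assume an already-initialized llama_context * ctx.

// Sketch only: the seq_swa_snapshot helpers are illustrative, not part of llama.cpp.
#include <cstdint>
#include <vector>

#include "llama.h"

struct seq_swa_snapshot {
    llama_pos pos_min = -1;
    llama_pos pos_max = -1;

    std::vector<uint8_t> data;
};

// Capture only the SWA portion of the KV state for one sequence.
static bool seq_swa_snapshot_take(llama_context * ctx, llama_seq_id seq_id, seq_swa_snapshot & out) {
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
    if (size == 0) {
        return false;
    }

    out.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), seq_id);
    out.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id);
    out.data.resize(size);

    return llama_state_seq_get_data_ext(ctx, out.data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == size;
}

// Restore a previously captured snapshot into a sequence.
// A short read indicates failure; the caller should fall back to full
// re-processing of the prompt, as server.cpp does.
static bool seq_swa_snapshot_restore(llama_context * ctx, llama_seq_id seq_id, const seq_swa_snapshot & snap) {
    const size_t n = llama_state_seq_set_data_ext(ctx, snap.data.data(), snap.data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
    return n == snap.data.size();
}

As in server.cpp, a short read on restore is treated as a failure, and the caller is expected to reset n_past to 0 and drop its stale checkpoints before re-processing the prompt.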