@@ -1007,7 +1007,8 @@ struct llama_layer {
10071007};
10081008
10091009struct llama_kv_cell {
1010- llama_pos pos = -1 ;
1010+ llama_pos pos = -1 ;
1011+ llama_pos delta = 0 ;
10111012
10121013 std::set<llama_seq_id> seq_id;
10131014
@@ -1018,7 +1019,7 @@ struct llama_kv_cell {
10181019
10191020// ring-buffer of cached KV data
10201021struct llama_kv_cache {
1021- bool is_roped = false ;
1022+ bool has_shift = false ;
10221023
10231024 uint32_t head = 0 ;
10241025 uint32_t size = 0 ;
@@ -1223,6 +1224,8 @@ static bool llama_kv_cache_init(
12231224 const int64_t n_mem = n_layer*n_ctx;
12241225 const int64_t n_elements = n_embd*n_mem;
12251226
1227+ cache.has_shift = false ;
1228+
12261229 cache.head = 0 ;
12271230 cache.size = n_ctx;
12281231
@@ -1333,9 +1336,13 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
13331336 }
13341337}
13351338
1336- void llama_kv_cache_rm_seq (struct llama_kv_cache & cache, llama_seq_id seq_id) {
1339+ void llama_kv_cache_rm_seq (
1340+ struct llama_kv_cache & cache,
1341+ llama_seq_id seq_id,
1342+ llama_pos p0,
1343+ llama_pos p1) {
13371344 for (uint32_t i = 0 ; i < cache.size ; ++i) {
1338- if (cache.cells [i].has_seq_id (seq_id)) {
1345+ if (cache.cells [i].has_seq_id (seq_id) && cache. cells [i]. pos >= p0 && cache. cells [i]. pos < p1 ) {
13391346 cache.cells [i].seq_id .erase (seq_id);
13401347 if (cache.cells [i].seq_id .empty ()) {
13411348 cache.cells [i].pos = -1 ;
@@ -1353,18 +1360,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
13531360 }
13541361}
13551362
1356- void llama_kv_cache_shift (
1357- struct llama_context & ctx ,
1363+ void llama_kv_cache_shift_seq (
1364+ struct llama_kv_cache & cache ,
13581365 llama_seq_id seq_id,
13591366 llama_pos p0,
13601367 llama_pos p1,
13611368 llama_pos delta) {
1362- auto & hparams = ctx.model .hparams ;
1363- auto & cache = ctx.kv_self ;
1364-
13651369 for (uint32_t i = 0 ; i < cache.size ; ++i) {
13661370 if (cache.cells [i].has_seq_id (seq_id) && cache.cells [i].pos >= p0 && cache.cells [i].pos < p1) {
13671371 cache.cells [i].pos += delta;
1372+ if (cache.cells [i].pos < 0 ) {
1373+ cache.cells [i].pos = -1 ;
1374+ cache.cells [i].seq_id .clear ();
1375+ } else {
1376+ cache.has_shift = true ;
1377+ cache.cells [i].delta = delta;
1378+ }
13681379 }
13691380 }
13701381}
@@ -2595,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama(
25952606 const int32_t n_tokens = batch.n_tokens ;
25962607 const int32_t n_kv = llama_kv_cache_cell_max (kv_self);
25972608
2609+ const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure (lctx.alloc );
2610+
25982611 auto & buf_compute = lctx.buf_compute ;
25992612
26002613 struct ggml_init_params params = {
@@ -2698,6 +2711,16 @@ static struct ggml_cgraph * llm_build_llama(
26982711 }
26992712 }
27002713
2714+ // K_shift
2715+ struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_ctx);
2716+ ggml_allocr_alloc (lctx.alloc , K_shift);
2717+ if (!ggml_allocr_is_measure (lctx.alloc )) {
2718+ int * data = (int *) K_shift->data ;
2719+ for (int i = 0 ; i < n_ctx; ++i) {
2720+ data[i] = kv_self.cells [i].delta ;
2721+ }
2722+ }
2723+
27012724 for (int il = 0 ; il < n_layer; ++il) {
27022725 ggml_format_name (inpL, " layer_inp_%d" , il);
27032726
@@ -2723,6 +2746,17 @@ static struct ggml_cgraph * llm_build_llama(
27232746 ggml_set_name (cur, " attention_norm_0" );
27242747 }
27252748
2749+ if (do_rope_shift) {
2750+ ggml_build_forward_expand (gf,
2751+ ggml_rope_custom_inplace (ctx0,
2752+ ggml_view_3d (ctx0, kv_self.k ,
2753+ n_embd_head, n_head_kv, n_ctx,
2754+ ggml_element_size (kv_self.k )*n_embd_head,
2755+ ggml_element_size (kv_self.k )*n_embd_gqa,
2756+ ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il),
2757+ K_shift, n_embd_head, 0 , 0 , freq_base, freq_scale));
2758+ }
2759+
27262760 // self-attention
27272761 {
27282762 // compute Q and K and RoPE them
@@ -4033,7 +4067,8 @@ static bool llama_eval_internal(
40334067#endif
40344068
40354069 // update the kv ring buffer
4036- lctx.kv_self .head += n_tokens;
4070+ lctx.kv_self .head += n_tokens;
4071+ lctx.kv_self .has_shift = false ;
40374072
40384073#ifdef GGML_PERF
40394074 // print timing information per ggml operation (for debugging purposes)
@@ -6562,10 +6597,6 @@ struct llama_context * llama_new_context_with_model(
65626597 return nullptr ;
65636598 }
65646599
6565- if (model->arch == LLM_ARCH_LLAMA) {
6566- ctx->kv_self .is_roped = true ;
6567- }
6568-
65696600 {
65706601 const size_t memory_size = ggml_nbytes (ctx->kv_self .k ) + ggml_nbytes (ctx->kv_self .v );
65716602 LLAMA_LOG_INFO (" %s: kv self size = %7.2f MB\n " , __func__, memory_size / 1024.0 / 1024.0 );
@@ -6803,16 +6834,16 @@ void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1
68036834 llama_kv_cache_rm_tokens (ctx->kv_self , c0, c1);
68046835}
68056836
6806- void llama_kv_cache_rm_seq (struct llama_context * ctx, llama_seq_id seq_id) {
6807- llama_kv_cache_rm_seq (ctx->kv_self , seq_id);
6837+ void llama_kv_cache_rm_seq (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1 ) {
6838+ llama_kv_cache_rm_seq (ctx->kv_self , seq_id, p0, p1 );
68086839}
68096840
68106841void llama_kv_cache_keep_seq (struct llama_context * ctx, llama_seq_id seq_id) {
68116842 llama_kv_cache_keep_seq (ctx->kv_self , seq_id);
68126843}
68136844
6814- void llama_kv_cache_shift (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6815- llama_kv_cache_shift (* ctx, seq_id, p0, p1, delta);
6845+ void llama_kv_cache_shift_seq (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6846+ llama_kv_cache_shift_seq ( ctx-> kv_self , seq_id, p0, p1, delta);
68166847}
68176848
68186849// Returns the *maximum* size of the state
0 commit comments