@@ -1007,7 +1007,8 @@ struct llama_layer {
10071007};
10081008
10091009struct llama_kv_cell {
1010- llama_pos pos = -1 ;
1010+ llama_pos pos = -1 ;
1011+ llama_pos delta = 0 ;
10111012
10121013 std::set<llama_seq_id> seq_id;
10131014
@@ -1018,7 +1019,7 @@ struct llama_kv_cell {
10181019
10191020// ring-buffer of cached KV data
10201021struct llama_kv_cache {
1021- bool is_roped = false ;
1022+ bool has_shift = false ;
10221023
10231024 uint32_t head = 0 ;
10241025 uint32_t size = 0 ;
@@ -1333,9 +1334,13 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
13331334 }
13341335}
13351336
1336- void llama_kv_cache_rm_seq (struct llama_kv_cache & cache, llama_seq_id seq_id) {
1337+ void llama_kv_cache_rm_seq (
1338+ struct llama_kv_cache & cache,
1339+ llama_seq_id seq_id,
1340+ llama_pos p0,
1341+ llama_pos p1) {
13371342 for (uint32_t i = 0 ; i < cache.size ; ++i) {
1338- if (cache.cells [i].has_seq_id (seq_id)) {
1343+ if (cache.cells [i].has_seq_id (seq_id) && cache. cells [i]. pos >= p0 && cache. cells [i]. pos < p1 ) {
13391344 cache.cells [i].seq_id .erase (seq_id);
13401345 if (cache.cells [i].seq_id .empty ()) {
13411346 cache.cells [i].pos = -1 ;
@@ -1353,18 +1358,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
13531358 }
13541359}
13551360
1356- void llama_kv_cache_shift (
1357- struct llama_context & ctx ,
1361+ void llama_kv_cache_shift_seq (
1362+ struct llama_kv_cache & cache ,
13581363 llama_seq_id seq_id,
13591364 llama_pos p0,
13601365 llama_pos p1,
13611366 llama_pos delta) {
1362- auto & hparams = ctx.model .hparams ;
1363- auto & cache = ctx.kv_self ;
1364-
13651367 for (uint32_t i = 0 ; i < cache.size ; ++i) {
13661368 if (cache.cells [i].has_seq_id (seq_id) && cache.cells [i].pos >= p0 && cache.cells [i].pos < p1) {
13671369 cache.cells [i].pos += delta;
1370+ if (cache.cells [i].pos < 0 ) {
1371+ cache.cells [i].pos = -1 ;
1372+ cache.cells [i].seq_id .clear ();
1373+ } else {
1374+ cache.has_shift = true ;
1375+ cache.cells [i].delta = delta;
1376+ }
13681377 }
13691378 }
13701379}
@@ -2595,6 +2604,8 @@ static struct ggml_cgraph * llm_build_llama(
25952604 const int32_t n_tokens = batch.n_tokens ;
25962605 const int32_t n_kv = llama_kv_cache_cell_max (kv_self);
25972606
2607+ const bool do_rope_shift = kv_self.has_shift ;
2608+
25982609 auto & buf_compute = lctx.buf_compute ;
25992610
26002611 struct ggml_init_params params = {
@@ -2698,6 +2709,16 @@ static struct ggml_cgraph * llm_build_llama(
26982709 }
26992710 }
27002711
2712+ // K_shift
2713+ struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_ctx);
2714+ ggml_allocr_alloc (lctx.alloc , K_shift);
2715+ if (!ggml_allocr_is_measure (lctx.alloc ) && do_rope_shift) {
2716+ int * data = (int *) K_shift->data ;
2717+ for (int i = 0 ; i < n_ctx; ++i) {
2718+ data[i] = kv_self.cells [i].delta ;
2719+ }
2720+ }
2721+
27012722 for (int il = 0 ; il < n_layer; ++il) {
27022723 ggml_format_name (inpL, " layer_inp_%d" , il);
27032724
@@ -2723,6 +2744,17 @@ static struct ggml_cgraph * llm_build_llama(
27232744 ggml_set_name (cur, " attention_norm_0" );
27242745 }
27252746
2747+ if (do_rope_shift) {
2748+ ggml_build_forward_expand (gf,
2749+ ggml_rope_custom_inplace (ctx0,
2750+ ggml_view_3d (ctx0, kv_self.k ,
2751+ n_embd_head, n_head_kv, n_ctx,
2752+ ggml_element_size (kv_self.k )*n_embd_head,
2753+ ggml_element_size (kv_self.k )*n_embd_gqa,
2754+ ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il),
2755+ K_shift, n_embd_head, 0 , 0 , freq_base, freq_scale));
2756+ }
2757+
27262758 // self-attention
27272759 {
27282760 // compute Q and K and RoPE them
@@ -4033,7 +4065,8 @@ static bool llama_eval_internal(
40334065#endif
40344066
40354067 // update the kv ring buffer
4036- lctx.kv_self .head += n_tokens;
4068+ lctx.kv_self .head += n_tokens;
4069+ lctx.kv_self .has_shift = false ;
40374070
40384071#ifdef GGML_PERF
40394072 // print timing information per ggml operation (for debugging purposes)
@@ -6562,10 +6595,6 @@ struct llama_context * llama_new_context_with_model(
65626595 return nullptr ;
65636596 }
65646597
6565- if (model->arch == LLM_ARCH_LLAMA) {
6566- ctx->kv_self .is_roped = true ;
6567- }
6568-
65696598 {
65706599 const size_t memory_size = ggml_nbytes (ctx->kv_self .k ) + ggml_nbytes (ctx->kv_self .v );
65716600 LLAMA_LOG_INFO (" %s: kv self size = %7.2f MB\n " , __func__, memory_size / 1024.0 / 1024.0 );
@@ -6803,16 +6832,16 @@ void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1
68036832 llama_kv_cache_rm_tokens (ctx->kv_self , c0, c1);
68046833}
68056834
6806- void llama_kv_cache_rm_seq (struct llama_context * ctx, llama_seq_id seq_id) {
6807- llama_kv_cache_rm_seq (ctx->kv_self , seq_id);
6835+ void llama_kv_cache_rm_seq (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1 ) {
6836+ llama_kv_cache_rm_seq (ctx->kv_self , seq_id, p0, p1 );
68086837}
68096838
68106839void llama_kv_cache_keep_seq (struct llama_context * ctx, llama_seq_id seq_id) {
68116840 llama_kv_cache_keep_seq (ctx->kv_self , seq_id);
68126841}
68136842
6814- void llama_kv_cache_shift (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6815- llama_kv_cache_shift (* ctx, seq_id, p0, p1, delta);
6843+ void llama_kv_cache_shift_seq (struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6844+ llama_kv_cache_shift_seq ( ctx-> kv_self , seq_id, p0, p1, delta);
68166845}
68176846
68186847// Returns the *maximum* size of the state
0 commit comments