@@ -803,6 +803,9 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }
 
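+    // s0/s1 describe the range of virtual sequences covered by the slot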
+    assert(res.s1 >= res.s0);
+
     return res;
 }
 
@@ -908,13 +911,8 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    assert(sinfo.s1 >= sinfo.s0);
-
     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
 
-    assert(ns > 0);
-    assert(ns <= n_seq_virt);
-
     const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
 
     return ggml_view_4d(ctx, k,
@@ -932,9 +930,6 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
 
     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
 
-    assert(ns > 0);
-    assert(ns <= n_seq_virt);
-
     const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
 
     if (!v_trans) {
@@ -967,9 +962,22 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (kv_idxs && supports_set_rows) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+        const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
+
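+        // view of the K cache covering virtual sequences [s0, s1] - each sequence occupies a slice of size_virt bytes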
+        ggml_tensor * k_view = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], ns,
+                ggml_row_size(k->type, k->ne[0]),
+                size_virt,
+                size_virt*sinfo.s0);
+
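+        // group the new rows and their destination indices per virtual sequence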
+        k_cur = ggml_reshape_3d(ctx, k_cur, k_cur->ne[0], k_cur->ne[1]/ns, ns);
+
+        kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
 
-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+        return ggml_set_rows(ctx, k_view, k_cur, kv_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -995,27 +1003,48 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (kv_idxs && supports_set_rows) {
+        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+        const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
+
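+        // when not transposed, the V cache is written the same way as the K cache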
         if (!v_trans) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            ggml_tensor * v_view = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], ns,
+                    ggml_row_size(v->type, v->ne[0]),
+                    size_virt,
+                    size_virt*sinfo.s0);
+
+            v_cur = ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], v_cur->ne[1]/ns, ns);
 
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+            kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
         }
 
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);
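+        // per-virtual-sequence view of the transposed cache: [1, n_kv, n_embd_v_gqa, ns] (see the shape comments below)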
+        ggml_tensor * v_view = ggml_view_4d(ctx, v, 1, v->ne[1], v->ne[0], ns,
+                ggml_row_size(v->type, 1),
+                ggml_row_size(v->type, v->ne[1]),
+                size_virt,
+                size_virt*sinfo.s0);
 
         // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+        v_cur = ggml_permute(ctx, ggml_reshape_4d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]/ns, ns), 2, 0, 1, 3);
 
         // note: we can be more explicit here at the cost of extra cont
         // however, above we take advantage that a row of single element is always contiguous regardless of the row stride
+        // v_cur = ggml_reshape_3d(ctx, v_cur, n_embd_v_gqa, v_cur->ne[1]/ns, ns);
         // v_cur = ggml_transpose(ctx, v_cur);
-        // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        // v_cur = ggml_cont_4d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1], v_cur->ne[2]);
 
         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1, n_kv,     n_embd_v_gqa]
-        // v_cur   [1, n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1, 1]
+        // v       [1, n_kv,        n_embd_v_gqa, ns]
+        // v_cur   [1, n_tokens/ns, n_embd_v_gqa, ns]
+        // kv_idxs [n_tokens/ns, 1, ns]
+
+        kv_idxs = ggml_reshape_3d(ctx, kv_idxs, n_tokens/ns, 1, ns);
+
         return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
     }
 
@@ -1053,10 +1082,9 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     int64_t * data = (int64_t *) dst->data;
 
     for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
-        const int64_t offs = sinfo.seq_id_virt[s]*get_size();
-
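+        // the indices are now local to each sequence slice - the per-sequence offset is applied by the strided views in cpy_k()/cpy_v()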
         for (uint32_t i = 0; i < sinfo.size(); ++i) {
-            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+            data[s*sinfo.size() + i] = sinfo.idxs[s][i];
         }
     }
 }