@@ -808,7 +808,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -818,8 +818,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
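For reference, ggml_set_rows(ctx, a, b, c) returns a tensor in which the rows of b have been scattered into a at the row positions given by the I64 index tensor c, so the branch above replaces the old view-plus-ggml_cpy() update with a single scatter op. A minimal standalone sketch of the semantics (shapes and names are illustrative, not taken from this commit):

// A minimal sketch of ggml_set_rows() semantics, assuming a plain CPU ggml
// context; the shapes here are illustrative and unrelated to the KV cache above.
#include "ggml.h"

static void set_rows_demo(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * dst  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 16); // 16 rows of width 8
    struct ggml_tensor * src  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8,  4); //  4 rows of width 8
    struct ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 4);     //  one dst row per src row

    // after graph evaluation, row idxs[r] of dst holds row r of src;
    // all other rows of dst are left untouched
    struct ggml_tensor * out = ggml_set_rows(ctx, dst, src, idxs);
    (void) out;

    ggml_free(ctx);
}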
@@ -832,7 +832,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -842,9 +842,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
+    if (v_idxs && supports_set_rows) {
         if (!v_trans) {
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
@@ -859,10 +859,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1, n_kv,     n_embd_v_gqa]
-        // v_cur   [1, n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1, 1]
-        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        // v      [1, n_kv,     n_embd_v_gqa]
+        // v_cur  [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
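When the V cache is stored transposed (v_trans), a cache row no longer corresponds to one token, so the code above views the cache as rows of a single element. Note the comment describes broadcasting the n_tokens indices across the n_embd_v_gqa dimension, while build_input_v_idxs() below reserves n_tokens*n_embd_v_gqa entries; one plausible fully expanded index table consistent with that allocation is sketched here (the helper name and the channel*kv_size + slot layout are assumptions, not taken from this commit):

// Hypothetical index expansion for the transposed V cache, assuming element
// (slot, channel) of the conceptual [n_embd_v_gqa, kv_size] cache is stored
// as single-element row channel*kv_size + slot of the flattened view.
#include <cstdint>

static void fill_v_idxs_trans(int64_t * data,             // dst->data of the v_idxs input
                              const uint32_t * slot_idxs, // per-token cache slots (sinfo.idxs)
                              int64_t n_tokens,
                              int64_t n_embd_v_gqa,
                              int64_t kv_size) {
    for (int64_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            // every channel of token i lands in the same slot, one stride apart
            data[i*n_embd_v_gqa + j] = j*kv_size + slot_idxs[i];
        }
    }
}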
@@ -885,7 +885,49 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        // TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+    }
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs[i];
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
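Taken together, the new API splits the old single kv_idxs input into separate K and V index tensors: each is created as a graph input during build and filled from the slot assignment right before evaluation. A hedged sketch of the intended call order (the wrapper name and the plumbing variables mctx/gf/k_cur/v_cur/il are stand-ins, not from this commit):

// Hypothetical graph-build wiring, assuming the declarations from
// llama-kv-cache-unified.h and the usual ggml graph headers.
static void build_kv_store(ggml_context * ctx,
                           ggml_cgraph * gf,
                           const llama_kv_cache_unified_context * mctx,
                           const llama_ubatch & ubatch,
                           ggml_tensor * k_cur,
                           ggml_tensor * v_cur,
                           int32_t il) {
    // graph build: create the I64 index inputs and scatter K/V through them
    ggml_tensor * k_idxs = mctx->build_input_k_idxs(ctx, ubatch); // [n_tokens]
    ggml_tensor * v_idxs = mctx->build_input_v_idxs(ctx, ubatch); // [n_tokens] or [n_tokens*n_embd_v_gqa]

    ggml_build_forward_expand(gf, mctx->cpy_k(ctx, k_cur, k_idxs, il));
    ggml_build_forward_expand(gf, mctx->cpy_v(ctx, v_cur, v_idxs, il));

    // before evaluation the runtime binds the slot assignment:
    //   mctx->set_input_k_idxs(k_idxs, &ubatch);
    //   mctx->set_input_v_idxs(v_idxs, &ubatch);
}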
@@ -1906,20 +1948,32 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
-void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {