@@ -7794,7 +7794,9 @@ static void llm_build_kv_store(
     cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+    ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
+    ggml_build_forward_expand(graph, tmp);
 
     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
 
@@ -7812,8 +7814,9 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+    tmp = ggml_cpy(ctx, v_cur, v_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
+    ggml_build_forward_expand(graph, tmp);
 }
 
 static struct ggml_tensor * llm_build_norm(
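
Note: the two hunks above rely on a kv_cache_flag field on ggml_tensor and on GGML_KV_CACHE_FLAG_* values that are not shown in this excerpt. A minimal sketch of the supporting declarations presumably added to ggml.h (the names are taken from the diff; the exact enumerator values and the default are assumptions):

    // sketch only, not the actual patch
    enum ggml_kv_cache_flag {
        GGML_KV_CACHE_FLAG_NONE = 0, // assumed default, set when a tensor is created
        GGML_KV_CACHE_FLAG_K,        // this CPY node writes K into the KV cache
        GGML_KV_CACHE_FLAG_V,        // this CPY node writes V into the KV cache
    };
    // ...plus an "enum ggml_kv_cache_flag kv_cache_flag;" member on struct ggml_tensor,
    // which the two ggml_cpy() nodes above set to _K and _V respectively.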
@@ -14606,48 +14609,42 @@ static int llama_decode_internal(
 
     if(ggml_use_cached_graph(lctx.sched)) {
 
-        // If using flash attention, find mask node so it can be skipped when updating
-        // KV cache paramaters in cached graph nodes below
-        void * flash_attn_mask_node = nullptr;
-        if(cparams.flash_attn) {
-            for (int i = 0; i < gf->n_nodes; i++) {
-                ggml_tensor * node = gf->nodes[i];
-                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
-                    flash_attn_mask_node = node->src[3];
-                    break;
-                }
-            }
-        }
-
         // Temporarily store KV cache parameters that will need updated in cached graph.
         const struct llama_hparams & hparams = model.hparams;
         const int64_t n_layer = hparams.n_layer;
         const int64_t kv_head = kv_self.head;
         std::vector<void *> kv_cache_ptrs;
+        std::vector<void *> k_cache_ptrs;
+        std::vector<void *> v_cache_ptrs;
         for (int il = 0; il < n_layer; ++il) {
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
             ggml_tensor * tmp_tensor = kv_self.k_l[il];
             size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
             kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
             tmp_tensor = kv_self.v_l[il];
             if (cparams.flash_attn) {
                 tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
             } else {
                 tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
             }
             kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
         }
 
         // Update KV cache parameters in cached graph.
-        int copy_op_count = 0;
+        int k_count = 0;
+        int v_count = 0;
         if(gf != nullptr && gf->nodes != nullptr){
             for (int i = 0; i < gf->n_nodes; i++) {
                 ggml_tensor * node = gf->nodes[i];
                 if (node->op == GGML_OP_CPY) {
-                    if (node != flash_attn_mask_node) {
-                        node->src[1]->data = kv_cache_ptrs[copy_op_count];
-                        copy_op_count++;
+                    if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
+                        node->src[1]->data = k_cache_ptrs[k_count++];
+                    }
+                    if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
+                        node->src[1]->data = v_cache_ptrs[v_count++];
                     }
                 }
             }
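
As a sanity check on the cache offsets computed in the loop above, here is a small standalone sketch of the same arithmetic, assuming an F16 KV cache (2 bytes per element) and illustrative values for n_embd_k_gqa, n_embd_v_gqa and kv_head (none of these numbers come from the patch):

    #include <stdio.h>

    int main(void) {
        const size_t elt_size     = 2;    // ggml_element_size() for F16
        const size_t n_embd_k_gqa = 4096; // illustrative value
        const size_t n_embd_v_gqa = 4096; // illustrative value
        const size_t kv_head      = 32;   // illustrative write position in the cache

        // K view: skip kv_head whole rows, i.e. ggml_row_size(type, n_embd_k_gqa) * kv_head.
        const size_t k_offset = (elt_size * n_embd_k_gqa) * kv_head;

        // V view: laid out like K when flash attention is used; otherwise the V cache is
        // transposed, so the view starts kv_head elements into each row.
        const size_t v_offset_fa    = (elt_size * n_embd_v_gqa) * kv_head;
        const size_t v_offset_no_fa = kv_head * elt_size;

        printf("k_offset=%zu v_offset(fa)=%zu v_offset(no fa)=%zu\n",
               k_offset, v_offset_fa, v_offset_no_fa);
        return 0;
    }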