@@ -3446,7 +3446,9 @@ static struct ggml_cgraph * llm_build_starcoder(
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_ctx       = hparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
+    const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
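Note (not part of the commit): the two new constants are what enable grouped-query attention (GQA) here. With GQA, `n_head` query heads share a smaller set of `n_head_kv` key/value heads, so each token's KV-cache row shrinks from `n_embd` to `n_embd_gqa`. A minimal sketch of the relation, with illustrative sizes (the accessor definitions are assumed to match `hparams` elsewhere in the file):

    #include <assert.h>
    #include <stdint.h>

    // sketch: how the GQA dimensions used by this diff relate to each other
    int main(void) {
        const int64_t n_embd      = 4096;                    // illustrative
        const int64_t n_head      = 32;                      // query heads
        const int64_t n_head_kv   = 8;                       // shared KV heads
        const int64_t n_embd_head = n_embd / n_head;         // 128, width of one head
        const int64_t n_embd_gqa  = n_embd_head * n_head_kv; // 1024, KV row per token

        // for a non-GQA model, n_head_kv == n_head and n_embd_gqa == n_embd
        assert(n_embd_gqa == n_embd / (n_head / n_head_kv));
        return 0;
    }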
@@ -3508,28 +3510,44 @@ static struct ggml_cgraph * llm_build_starcoder(
         position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions);
     }
 
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(lctx.alloc, KQ_scale);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+
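Aside (not from the diff): the `ggml_allocr_is_measure` guard exists because the graph allocator first runs a measure pass that only computes buffer sizes; during that pass `KQ_scale` has no real backing memory, so writing to it would be invalid. A sketch of the idiom, assuming only the `ggml.h`/`ggml-alloc.h` APIs already used above:

    #include "ggml.h"
    #include "ggml-alloc.h"

    // sketch: allocate a 1-element scale tensor, writing it only when the
    // allocator is past its measure pass and real memory exists
    static struct ggml_tensor * make_scale(struct ggml_context * ctx,
                                           struct ggml_allocr  * alloc,
                                           float value) {
        struct ggml_tensor * scale = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
        ggml_allocr_alloc(alloc, scale);     // reserves (or merely measures) space
        if (!ggml_allocr_is_measure(alloc)) {
            ggml_set_f32(scale, value);      // safe: backing memory exists
        }
        return scale;
    }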
     inpL = ggml_add(ctx0, token, position);
+    ggml_set_name(inpL, "inpL");
 
     for (int il = 0; il < n_layer; ++il) {
         {
             // Norm
             cur = ggml_norm(ctx0, inpL, norm_eps);
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
-
         }
 
         {
             // Self Attention
             cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
 
-            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
-            struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
-            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
+            struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
+            struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
+            struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
 
-            // store key and value to memory
-            if (N >= 1) {
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
+            struct ggml_tensor * Qcur = tmpq;
+            struct ggml_tensor * Kcur = tmpk;
+
+            {
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
+                ggml_set_name(Vcur, "Vcur");
+
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                ggml_set_name(k, "k");
+
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
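Aside (not from the diff): `k` is a 1-D view because the K cache stores each token's `n_embd_gqa` values contiguously, while `v` is a 2-D view because the V cache is stored transposed, one `n_ctx`-slot row per channel; that is why `Vcur` is transposed before the copy, and it lets the attention code below view `V` without a permute. A small sketch of the offset arithmetic, with illustrative sizes:

    #include <stdint.h>
    #include <stdio.h>

    // sketch: byte offsets into the K and V caches for layer `il` at
    // position `n_past`, assuming 2-byte (f16) cache elements
    int main(void) {
        const int64_t n_ctx = 512, n_embd_gqa = 1024, il = 3, n_past = 17;
        const int64_t elt = 2; // stand-in for ggml_element_size(kv_self.k)

        // K cache: token-major -- token t of layer il begins at
        // (il*n_ctx + t) * n_embd_gqa elements
        const int64_t k_off = elt*n_embd_gqa*(il*n_ctx + n_past);

        // V cache: channel-major -- skip the per-layer block, then n_past
        // slots into each row of n_ctx; the row stride is n_ctx elements
        const int64_t v_off    = il*n_ctx*elt*n_embd_gqa + n_past*elt;
        const int64_t v_stride = n_ctx*elt;

        printf("k_off=%lld v_off=%lld v_row_stride=%lld\n",
               (long long) k_off, (long long) v_off, (long long) v_stride);
        return 0;
    }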
@@ -3541,56 +3559,62 @@ static struct ggml_cgraph * llm_build_starcoder(
                         Qcur,
                         ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
                     0, 2, 1, 3);
+            ggml_set_name(Q, "Q");
 
             struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                    ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
-                        n_embd/n_head, n_head, n_past + N),
-                    0, 2, 1, 3); // TODO: need to be tiled
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_embd_head, n_past + N, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            ggml_set_name(K, "K");
 
             // K * Q
-            // [n_past + N, N, 12]
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            ggml_set_name(KQ, "KQ");
 
-            // KQ_scaled = KQ / sqrt(n_embd/n_head)
-            // [n_past + N, N, 12]
-            struct ggml_tensor * KQ_scaled =
-                ggml_scale_inplace(ctx0,
-                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
-                        );
+            // KQ_scaled = KQ / sqrt(n_embd_head)
+            // KQ_scaled shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            ggml_set_name(KQ_scaled, "KQ_scaled");
 
             // KQ_masked = mask_past(KQ_scaled)
-            // [n_past + N, N, 12]
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            ggml_set_name(KQ_masked, "KQ_masked");
 
             // KQ = soft_max(KQ_masked)
-            // [n_past + N, N, 12]
             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
-            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-            // [n_past + N, 64, 12]
-            struct ggml_tensor * V_trans =
-                ggml_cpy(ctx0,
-                    ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
-                            n_embd/n_head, n_head, n_past + N),
-                        1, 2, 0, 3),
-                    ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
-
-            // KQV = transpose(V) * KQ_soft_max
-            // [64, N, 12]
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+            // split cached V into n_head heads
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_embd_head, n_head_kv,
+                        ggml_element_size(kv_self.v)*n_ctx,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+            ggml_set_name(V, "V");
+
+#if 1
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            ggml_set_name(KQV, "KQV");
+#else
+            // make V contiguous in memory to speed up the matmul, however we waste time on the copy
+            // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
+            // is there a better way?
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)
-            // [64, 12, N]
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            ggml_set_name(KQV_merged, "KQV_merged");
 
+            // cur = KQV_merged.contiguous().view(n_embd, N)
             cur = ggml_cpy(ctx0,
                 KQV_merged,
                 ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            ggml_set_name(cur, "KQV_merged_contiguous");
         }
 
         // Projection
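For reference (a summary, not part of the diff): the shapes flowing through the attention block, in ggml's convention where `ne[0]` is the innermost dimension; `ggml_mul_mat` broadcasts the `n_head_kv` K/V heads across the `n_head` query heads. A sketch with illustrative sizes:

    #include <stdio.h>

    // sketch: tensor shapes through the attention block above
    int main(void) {
        const int n_embd_head = 128, n_head = 32, n_head_kv = 8, N = 4, n_past = 16;

        printf("Q          : [%d, %d, %d]\n", n_embd_head, N, n_head);
        printf("K (view)   : [%d, %d, %d]\n", n_embd_head, n_past + N, n_head_kv);
        printf("KQ         : [%d, %d, %d]\n", n_past + N, N, n_head);
        printf("V (view)   : [%d, %d, %d]\n", n_past + N, n_embd_head, n_head_kv);
        printf("KQV        : [%d, %d, %d]\n", n_embd_head, N, n_head);
        printf("KQV_merged : [%d, %d, %d]\n", n_embd_head, n_head, N);  // permute(0, 2, 1, 3)
        printf("cur        : [%d, %d]\n",     n_embd_head * n_head, N); // contiguous [n_embd, N]
        return 0;
    }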