@@ -4323,7 +4323,7 @@ struct llm_build_context {
43234323 struct ggml_tensor * Kcur = ggml_concat (ctx0, krotated, kpass);
43244324 cb (Kcur, " Kcur" , il);
43254325
4326- struct ggml_tensor * Q = ggml_cont (ctx0, ggml_permute (ctx0, Qcur, 1 , 2 , 0 , 3 ));
4326+ struct ggml_tensor * Q = ggml_cont (ctx0, ggml_permute (ctx0, Qcur, 2 , 1 , 0 , 3 ));
43274327 cb (Q, " Q" , il);
43284328
43294329 Kcur = ggml_cont (ctx0, ggml_permute (ctx0, Kcur, 2 , 1 , 0 , 3 ));
@@ -4710,34 +4710,106 @@ struct llm_build_context {
47104710 // self-attention
47114711 {
47124712 // compute Q and K and RoPE them
4713- struct ggml_tensor * Qcur = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
4714- cb (Qcur , " Qcur " , il);
4713+ struct ggml_tensor * tmpq = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
4714+ cb (tmpq , " tmpq " , il);
47154715
4716- struct ggml_tensor * Kcur = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
4717- cb (Kcur , " Kcur " , il);
4716+ struct ggml_tensor * tmpk = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
4717+ cb (tmpk , " tmpk " , il);
47184718
47194719 struct ggml_tensor * Vcur = ggml_mul_mat (ctx0, model.layers [il].wv , cur);
47204720 cb (Vcur, " Vcur" , il);
47214721
4722- Qcur = ggml_rope_custom (
4723- ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4724- hparams.n_rot , 2 , 0 , n_orig_ctx, freq_base, freq_scale,
4725- ext_factor, attn_factor, beta_fast, beta_slow
4722+ // RoPE the first n_rot of q/k, pass the other half, and concat.
4723+ struct ggml_tensor * qrot = ggml_view_3d (
4724+ ctx0, tmpq, hparams.n_rot , n_head, n_tokens,
4725+ ggml_element_size (tmpq) * n_embd_head,
4726+ ggml_element_size (tmpq) * n_embd_head * n_head,
4727+ 0
4728+ );
4729+ cb (qrot, " qrot" , il);
4730+
4731+ struct ggml_tensor * krot = ggml_view_3d (
4732+ ctx0, tmpk, hparams.n_rot , n_head, n_tokens,
4733+ ggml_element_size (tmpk) * n_embd_head,
4734+ ggml_element_size (tmpk) * n_embd_head * n_head_kv,
4735+ 0
4736+ );
4737+ cb (krot, " krot" , il);
4738+
4739+ // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
4740+ struct ggml_tensor * qpass = ggml_view_3d (
4741+ ctx0, tmpq, (n_embd_head - hparams.n_rot ), n_head, n_tokens,
4742+ ggml_element_size (tmpq) * n_embd_head,
4743+ ggml_element_size (tmpq) * n_embd_head * n_head,
4744+ ggml_element_size (tmpq) * hparams.n_rot
4745+ );
4746+ cb (qpass, " qpass" , il);
4747+
4748+ struct ggml_tensor * kpass = ggml_view_3d (
4749+ ctx0, tmpk, (n_embd_head - hparams.n_rot ), n_head_kv, n_tokens,
4750+ ggml_element_size (tmpk) * (n_embd_head),
4751+ ggml_element_size (tmpk) * (n_embd_head) * n_head_kv,
4752+ ggml_element_size (tmpk) * hparams.n_rot
4753+ );
4754+ cb (kpass, " kpass" , il);
4755+
4756+ struct ggml_tensor * qrotated = ggml_rope_custom (
4757+ ctx0, qrot, inp_pos, hparams.n_rot , 2 , 0 , n_orig_ctx,
4758+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
47264759 );
4727- cb (Qcur , " Qcur " , il);
4760+ cb (qrotated , " qrotated " , il);
47284761
4729- Kcur = ggml_rope_custom (
4730- ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4731- hparams.n_rot , 2 , 0 , n_orig_ctx, freq_base, freq_scale,
4732- ext_factor, attn_factor, beta_fast, beta_slow
4762+ struct ggml_tensor * krotated = ggml_rope_custom (
4763+ ctx0, krot, inp_pos, hparams.n_rot , 2 , 0 , n_orig_ctx,
4764+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
47334765 );
4766+ cb (krotated, " krotated" , il);
4767+
4768+ // ggml currently only supports concatenation on dim=2
4769+ // so we need to permute qrot, qpass, concat, then permute back.
4770+ qrotated = ggml_cont (ctx0, ggml_permute (ctx0, qrotated, 2 , 1 , 0 , 3 ));
4771+ cb (qrotated, " qrotated" , il);
4772+
4773+ krotated = ggml_cont (ctx0, ggml_permute (ctx0, krotated, 2 , 1 , 0 , 3 ));
4774+ cb (krotated, " krotated" , il);
4775+
4776+ qpass = ggml_cont (ctx0, ggml_permute (ctx0, qpass, 2 , 1 , 0 , 3 ));
4777+ cb (qpass, " qpass" , il);
4778+
4779+ kpass = ggml_cont (ctx0, ggml_permute (ctx0, kpass, 2 , 1 , 0 , 3 ));
4780+ cb (kpass, " kpass" , il);
4781+
4782+ struct ggml_tensor * Qcur = ggml_concat (ctx0, qrotated, qpass);
4783+ cb (Qcur, " Qcur" , il);
4784+
4785+ struct ggml_tensor * Kcur = ggml_concat (ctx0, krotated, kpass);
47344786 cb (Kcur, " Kcur" , il);
47354787
4788+ struct ggml_tensor * Q = ggml_cont (ctx0, ggml_permute (ctx0, Qcur, 2 , 1 , 0 , 3 ));
4789+ cb (Q, " Q" , il);
4790+
4791+ Kcur = ggml_cont (ctx0, ggml_permute (ctx0, Kcur, 2 , 1 , 0 , 3 ));
4792+ cb (Kcur, " Kcur" , il);
4793+
4794+ // Qcur = ggml_rope_custom(
4795+ // ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4796+ // hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4797+ // ext_factor, attn_factor, beta_fast, beta_slow
4798+ // );
4799+ // cb(Qcur, "Qcur", il);
4800+
4801+ // Kcur = ggml_rope_custom(
4802+ // ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4803+ // hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4804+ // ext_factor, attn_factor, beta_fast, beta_slow
4805+ // );
4806+ // cb(Kcur, "Kcur", il);
4807+
47364808 llm_build_kv_store (ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
47374809
47384810 cur = llm_build_kqv (ctx0, hparams, kv_self,
47394811 model.layers [il].wo , NULL ,
4740- Qcur , KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1 .0f , cb, il);
4812+ Q , KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1 .0f , cb, il);
47414813 cb (cur, " kqv_out" , il);
47424814 }
47434815
0 commit comments