@@ -4814,92 +4814,34 @@ struct llm_build_context {
48144814 // self-attention
48154815 {
48164816 // compute Q and K and RoPE them
4817- struct ggml_tensor * tmpq = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
4818- cb (tmpq , " tmpq " , il);
4817+ struct ggml_tensor * Qcur = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
4818+ cb (Qcur , " Qcur " , il);
48194819
4820- struct ggml_tensor * tmpk = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
4821- cb (tmpk , " tmpk " , il);
4820+ struct ggml_tensor * Kcur = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
4821+ cb (Kcur , " Kcur " , il);
48224822
48234823 struct ggml_tensor * Vcur = ggml_mul_mat (ctx0, model.layers [il].wv , cur);
48244824 cb (Vcur, " Vcur" , il);
48254825
4826- // RoPE the first n_rot of q/k, pass the other half, and concat.
4827- struct ggml_tensor * qrot = ggml_cont (ctx0, ggml_view_3d (
4828- ctx0, tmpq, hparams.n_rot , n_head, n_tokens,
4829- ggml_element_size (tmpq) * n_embd_head,
4830- ggml_element_size (tmpq) * n_embd_head * n_head,
4831- 0
4832- ));
4833- cb (qrot, " qrot" , il);
4834-
4835- struct ggml_tensor * krot = ggml_cont (ctx0, ggml_view_3d (
4836- ctx0, tmpk, hparams.n_rot , n_head, n_tokens,
4837- ggml_element_size (tmpk) * n_embd_head,
4838- ggml_element_size (tmpk) * n_embd_head * n_head_kv,
4839- 0
4840- ));
4841- cb (krot, " krot" , il);
4842-
4843- // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
4844- struct ggml_tensor * qpass = ggml_view_3d (
4845- ctx0, tmpq, (n_embd_head - hparams.n_rot ), n_head, n_tokens,
4846- ggml_element_size (tmpq) * n_embd_head,
4847- ggml_element_size (tmpq) * n_embd_head * n_head,
4848- ggml_element_size (tmpq) * hparams.n_rot
4849- );
4850- cb (qpass, " qpass" , il);
4851-
4852- struct ggml_tensor * kpass = ggml_view_3d (
4853- ctx0, tmpk, (n_embd_head - hparams.n_rot ), n_head_kv, n_tokens,
4854- ggml_element_size (tmpk) * (n_embd_head),
4855- ggml_element_size (tmpk) * (n_embd_head) * n_head_kv,
4856- ggml_element_size (tmpk) * hparams.n_rot
4857- );
4858- cb (kpass, " kpass" , il);
4859-
4860- struct ggml_tensor * qrotated = ggml_rope_custom (
4861- ctx0, qrot, inp_pos, hparams.n_rot , 2 , 0 , n_orig_ctx,
4862- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4863- );
4864- cb (qrotated, " qrotated" , il);
4865-
4866- struct ggml_tensor * krotated = ggml_rope_custom (
4867- ctx0, krot, inp_pos, hparams.n_rot , 2 , 0 , n_orig_ctx,
4868- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4826+ Qcur = ggml_rope_custom (
4827+ ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4828+ hparams.n_rot , 2 , 0 , n_orig_ctx, freq_base, freq_scale,
4829+ ext_factor, attn_factor, beta_fast, beta_slow
48694830 );
4870- cb (krotated, " krotated" , il);
4871-
4872- // ggml currently only supports concatenation on dim=2
4873- // so we need to permute qrot, qpass, concat, then permute back.
4874- qrotated = ggml_cont (ctx0, ggml_permute (ctx0, qrotated, 2 , 1 , 0 , 3 ));
4875- cb (qrotated, " qrotated" , il);
4876-
4877- krotated = ggml_cont (ctx0, ggml_permute (ctx0, krotated, 2 , 1 , 0 , 3 ));
4878- cb (krotated, " krotated" , il);
4879-
4880- qpass = ggml_cont (ctx0, ggml_permute (ctx0, qpass, 2 , 1 , 0 , 3 ));
4881- cb (qpass, " qpass" , il);
4882-
4883- kpass = ggml_cont (ctx0, ggml_permute (ctx0, kpass, 2 , 1 , 0 , 3 ));
4884- cb (kpass, " kpass" , il);
4885-
4886- struct ggml_tensor * Qcur = ggml_concat (ctx0, qrotated, qpass);
48874831 cb (Qcur, " Qcur" , il);
48884832
4889- struct ggml_tensor * Kcur = ggml_concat (ctx0, krotated, kpass);
4890- cb (Kcur, " Kcur" , il);
4891-
4892- struct ggml_tensor * Q = ggml_cont (ctx0, ggml_permute (ctx0, Qcur, 2 , 1 , 0 , 3 ));
4893- cb (Q, " Q" , il);
4894-
4895- Kcur = ggml_cont (ctx0, ggml_permute (ctx0, Kcur, 2 , 1 , 0 , 3 ));
4833+ Kcur = ggml_rope_custom (
4834+ ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4835+ hparams.n_rot , 2 , 0 , n_orig_ctx, freq_base, freq_scale,
4836+ ext_factor, attn_factor, beta_fast, beta_slow
4837+ );
48964838 cb (Kcur, " Kcur" , il);
48974839
48984840 llm_build_kv_store (ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
48994841
49004842 cur = llm_build_kqv (ctx0, hparams, kv_self,
49014843 model.layers [il].wo , NULL ,
4902- Q , KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1 .0f , cb, il);
4844+ Qcur , KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1 .0f , cb, il);
49034845 cb (cur, " kqv_out" , il);
49044846 }
49054847
0 commit comments