@@ -6436,7 +6436,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * k,
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
-        float                 scale) {
+        float                 scale,
+        float                 max_bias) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
     if (mask) {
@@ -6458,7 +6459,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

-    float params[] = { scale };
+    float params[] = { scale, max_bias };
     ggml_set_op_params(result, params, sizeof(params));

     result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -6478,7 +6479,7 @@ void ggml_flash_attn_ext_set_prec(

     const int32_t prec_i32 = (int32_t) prec;

-    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }

 // ggml_flash_ff
@@ -13308,8 +13309,8 @@ static void ggml_compute_forward_soft_max_f32(

     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
-    const uint32_t n_head_kv   = ne02;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -15524,8 +15525,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);

-    float scale = 1.0f;
-    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -15534,6 +15544,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);

+        const int h = iq2; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
         float S = 0.0f;
         float M = -INFINITY;

@@ -15557,7 +15570,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
             }
@@ -15628,7 +15641,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[1]) {
+    switch (dst->op_params[2]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
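
For reference, the per-head ALiBi slope that this patch threads into ggml_compute_forward_flash_attn_ext_f16 can be checked in isolation. The standalone C sketch below is not part of the commit: the helper name alibi_slope and the n_head/max_bias values in main are illustrative, but the arithmetic mirrors the diff (m0 and m1 derived from max_bias and n_head_log2, heads below n_head_log2 using powf(m0, h + 1), the rest using powf(m1, 2*(h - n_head_log2) + 1), and max_bias == 0.0f yielding a slope of 1.0f so the mask is applied unscaled).

// Standalone sketch of the ALiBi slope formula wired into the kernel above.
// Not part of the commit; alibi_slope is a hypothetical helper for illustration.
#include <math.h>
#include <stdint.h>
#include <stdio.h>

static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
    if (max_bias <= 0.0f) {
        return 1.0f; // max_bias == 0.0f disables ALiBi: mask values pass through unscaled
    }

    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    // same branch structure as the slope expression in the diff
    return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
}

int main(void) {
    const uint32_t n_head   = 12;   // illustrative head count
    const float    max_bias = 8.0f; // illustrative ALiBi max_bias

    // Each KV position ic then contributes slope*mask[ic] to the score,
    // matching "mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f" in the kernel.
    for (uint32_t h = 0; h < n_head; ++h) {
        printf("head %2u: slope = %.6f\n", (unsigned) h, alibi_slope(h, n_head, max_bias));
    }
    return 0;
}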