@@ -4829,7 +4829,9 @@ static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_tensor * mask,
         float scale,
         bool inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
     if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
         GGML_ASSERT(mask->ne[3] == 1);
         GGML_ASSERT(ggml_can_repeat_rows(mask, a));
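Note: a minimal sketch of how this masked, scaled soft max would be built at the graph level, assuming the public wrapper for the _impl above is ggml_soft_max_ext(ctx, a, mask, scale); that wrapper name and the attention shapes in the comments are assumptions, not confirmed by this hunk. Per the asserts above, both operands must be contiguous and the mask must be a 2D tensor whose rows are repeated across the rows of a:

    // sketch only: dst = soft_max(a*scale + mask), mask broadcast across rows
    // ggml_soft_max_ext() is assumed to be the public wrapper of ggml_soft_max_impl()
    static struct ggml_tensor * build_masked_soft_max(
            struct ggml_context * ctx,
            struct ggml_tensor  * kq,       // e.g. attention scores [n_kv, n_tokens, n_head, 1]
            struct ggml_tensor  * kq_mask,  // 2D mask [n_kv, n_tokens], reused across heads
            float                 scale) {
        // both tensors contiguous; mask->ne[2] == mask->ne[3] == 1
        return ggml_soft_max_ext(ctx, kq, kq_mask, scale);
    }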
@@ -10571,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
 static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
 
@@ -10595,29 +10602,39 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
-        float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+        float * wp = wdata;
+        for (int i = 0; i < nc; i++) {
+            wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
+        }
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(sp[i]));
+            assert(!isnan(wp[i]));
         }
 #endif
 
         float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, sp);
+        ggml_vec_max_f32(nc, &max, wp);
 
         ggml_float sum = 0.0;
 
         uint16_t scvt;
         for (int i = 0; i < nc; i++) {
-            if (sp[i] == -INFINITY) {
+            if (wp[i] == -INFINITY) {
                 dp[i] = 0.0f;
             } else {
-                // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
                 sum += (ggml_float)val;
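In scalar terms, each row i1 of dst now computes softmax(src0_row*scale + mask_row), with the mask row selected as i1 % ne11 so a single 2D mask is reused across all heads/batches. A reference-only sketch of the per-row math, using plain expf instead of the FP16 exp table above (function and parameter names are illustrative):

    #include <math.h>

    // reference per-row computation: dst = softmax(src*scale + mask); mask may be NULL
    static void soft_max_row_ref(float * dst, const float * src, const float * mask,
                                 int nc, float scale) {
        float max = -INFINITY;
        for (int i = 0; i < nc; i++) {
            const float w = src[i]*scale + (mask ? mask[i] : 0.0f);
            if (w > max) max = w;
        }
        double sum = 0.0;
        for (int i = 0; i < nc; i++) {
            const float w = src[i]*scale + (mask ? mask[i] : 0.0f);
            dst[i] = (w == -INFINITY) ? 0.0f : expf(w - max); // -INF entries are fully masked out
            sum += dst[i];
        }
        for (int i = 0; i < nc; i++) {
            dst[i] /= (float) sum; // normalize; assumes at least one unmasked entry per row
        }
    }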
@@ -10642,11 +10659,12 @@ static void ggml_compute_forward_soft_max_f32(
 static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -13883,7 +13901,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -15919,6 +15937,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                 }
             } break;
+            case GGML_OP_SOFT_MAX:
+                {
+                    n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
+
+                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
                 {
                     GGML_ASSERT(node->src[0]->ne[3] == 1);
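Work-buffer note: each of the n_tasks threads stages one row of nc = ne[0] floats in params->wdata before exponentiating, so the planner reserves ggml_type_size(GGML_TYPE_F32) * ne[0] * n_tasks bytes for this op (e.g. ne[0] = 4096 and n_tasks = 8 gives 4*4096*8 = 128 KiB). The extra CACHE_LINE_SIZE_F32 stride per thread used by the kernel is presumably absorbed by the cache-line slack ggml_graph_plan adds to the total work size; that is an assumption, not something visible in this hunk.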