@@ -447,14 +447,13 @@ kernel void kernel_rms_norm(
447447 constant int64_t & ne00,
448448 constant uint64_t & nb01,
449449 constant float & eps,
450- threadgroup float * sum [[threadgroup(0 )]],
450+ threadgroup float * buf [[threadgroup(0 )]],
451451 uint tgpig[[threadgroup_position_in_grid]],
452452 uint tpitg[[thread_position_in_threadgroup]],
453453 uint sgitg[[simdgroup_index_in_threadgroup]],
454454 uint tiisg[[thread_index_in_simdgroup]],
455455 uint ntg[[threads_per_threadgroup]]) {
456- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
457- device const float * x_scalar = (device const float *) x;
456+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
458457
459458 float4 sumf = 0 ;
460459 float all_sum = 0 ;
@@ -465,40 +464,30 @@ kernel void kernel_rms_norm(
465464 }
466465 all_sum = sumf[0 ] + sumf[1 ] + sumf[2 ] + sumf[3 ];
467466 all_sum = simd_sum (all_sum);
468- if (tiisg == 0 ) {
469- sum[sgitg] = all_sum;
470- }
467+ if (ntg > N_SIMDWIDTH) {
468+ if (sgitg == 0 ) {
469+ buf[tiisg] = 0 .0f ;
470+ }
471471
472- threadgroup_barrier (mem_flags::mem_threadgroup);
472+ threadgroup_barrier (mem_flags::mem_threadgroup);
473473
474- // broadcast, simd group number is ntg / 32
475- for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
476- if (tpitg < i) {
477- sum[tpitg] += sum[tpitg + i];
478- }
479- }
480- if (tpitg == 0 ) {
481- for (int i = 4 * (ne00 / 4 ); i < ne00; i++) {
482- sum[0 ] += x_scalar[i];
474+ if (tiisg == 0 ) {
475+ buf[sgitg] = all_sum;
483476 }
484- sum[0 ] /= ne00;
485- }
486477
487- threadgroup_barrier (mem_flags::mem_threadgroup);
478+ threadgroup_barrier (mem_flags::mem_threadgroup);
488479
489- const float mean = sum[0 ];
480+ all_sum = buf[tiisg];
481+ all_sum = simd_sum (all_sum);
482+ }
483+
484+ const float mean = all_sum/ne00;
490485 const float scale = 1 .0f /sqrt (mean + eps);
491486
492487 device float4 * y = (device float4 *) (dst + tgpig*ne00);
493- device float * y_scalar = (device float *) y;
494488 for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
495489 y[i00] = x[i00] * scale;
496490 }
497- if (tpitg == 0 ) {
498- for (int i00 = 4 * (ne00 / 4 ); i00 < ne00; i00++) {
499- y_scalar[i00] = x_scalar[i00] * scale;
500- }
501- }
502491}
503492
504493// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
0 commit comments