@@ -5058,39 +5058,32 @@ kernel void kernel_mul_mv_q6_K_f32(
50585058
50595059// ======================= Ternary
50605060
5061+ template <typename args_t >
50615062void kernel_mul_mv_tq2_0_f32_impl (
5062- device const void * src0,
5063- device const float * src1,
5064- device float * dst,
5065- int64_t ne00,
5066- int64_t ne01,
5067- int64_t ne02,
5068- int64_t ne10,
5069- int64_t ne12,
5070- int64_t ne0,
5071- int64_t ne1,
5072- uint r2,
5073- uint r3,
5074- threadgroup int8_t * shared_values,
5075- uint3 tgpig,
5076- uint tiisg,
5077- uint sgitg) {
5078-
5079- const int nb = ne00/QK_K;
5063+ args_t args,
5064+ device const char * src0,
5065+ device const char * src1,
5066+ device char * dst,
5067+ threadgroup char * shmem,
5068+ uint3 tgpig,
5069+ ushort tiisg,
5070+ ushort sgitg) {
5071+
5072+ const int nb = args.ne00 /QK_K;
50805073 const int r0 = tgpig.x ;
50815074 const int r1 = tgpig.y ;
50825075 const int im = tgpig.z ;
50835076
50845077 const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
50855078 const int ib_row = first_row * nb;
50865079
5087- const uint i12 = im%ne12;
5088- const uint i13 = im/ne12;
5080+ const uint i12 = im%args. ne12 ;
5081+ const uint i13 = im/args. ne12 ;
50895082
5090- const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5083+ const uint offset0 = (i12/args. r2 )*(nb*args. ne01 ) + (i13/args. r3 )*(nb*args. ne01 *args. ne02 );
50915084
50925085 device const block_tq2_0 * x = (device const block_tq2_0 *) src0 + ib_row + offset0;
5093- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
5086+ device const float * y = (device const float *) src1 + r1*args. ne10 + im*args. ne00 *args. ne1 ;
50945087
50955088 float yl[32 ];
50965089 float sumf[N_DST]={0 .f }, all_sum;
@@ -5144,40 +5137,27 @@ void kernel_mul_mv_tq2_0_f32_impl(
51445137 y4 += 4 * QK_K;
51455138 }
51465139
5140+ device float * dst_f32 = (device float *) dst + (uint64_t )im*args.ne0 *args.ne1 + (uint64_t )r1*args.ne0 ;
5141+
51475142 for (int row = 0 ; row < N_DST; ++row) {
51485143 all_sum = simd_sum (sumf[row]);
51495144 if (tiisg == 0 ) {
5150- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
5145+ dst_f32[ first_row + row] = all_sum;
51515146 }
51525147 }
51535148}
51545149
51555150[[host_name(" kernel_mul_mv_tq2_0_f32" )]]
51565151kernel void kernel_mul_mv_tq2_0_f32 (
5157- device const void * src0,
5158- device const float * src1,
5159- device float * dst,
5160- constant int64_t & ne00,
5161- constant int64_t & ne01,
5162- constant int64_t & ne02,
5163- constant uint64_t & nb00,
5164- constant uint64_t & nb01,
5165- constant uint64_t & nb02,
5166- constant int64_t & ne10,
5167- constant int64_t & ne11,
5168- constant int64_t & ne12,
5169- constant uint64_t & nb10,
5170- constant uint64_t & nb11,
5171- constant uint64_t & nb12,
5172- constant int64_t & ne0,
5173- constant int64_t & ne1,
5174- constant uint & r2,
5175- constant uint & r3,
5176- uint3 tgpig[[threadgroup_position_in_grid]],
5177- uint tiisg[[thread_index_in_simdgroup]],
5178- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5152+ constant ggml_metal_kargs_mul_mv & args,
5153+ device const char * src0,
5154+ device const char * src1,
5155+ device char * dst,
5156+ uint3 tgpig[[threadgroup_position_in_grid]],
5157+ ushort tiisg[[thread_index_in_simdgroup]],
5158+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
51795159
5180- kernel_mul_mv_tq2_0_f32_impl (src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3 , nullptr , tgpig, tiisg, sgitg);
5160+ kernel_mul_mv_tq2_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst , nullptr , tgpig, tiisg, sgitg);
51815161}
51825162
51835163// ======================= "True" 2-bit
0 commit comments