@@ -175,6 +175,30 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
175175 GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
176176 GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
177177 GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
178+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
179+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
180+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
181+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
182+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
183+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
184+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
185+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
186+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
187+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
188+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
189+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
190+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
191+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
192+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
193+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
194+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
195+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
196+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
197+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
198+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
199+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
200+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
201+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
178202 GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
179203 GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
180204 GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -699,6 +723,30 @@ @implementation GGMLMetalClass
699723 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
700724 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
701725 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
726+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
727+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
728+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
729+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
730+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
731+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
732+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
733+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
734+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
735+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
736+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
737+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
738+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
739+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
740+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
741+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
742+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
743+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
744+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
745+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
746+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
747+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
748+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
749+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
702750 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
703751 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
704752 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -1930,28 +1978,128 @@ static void ggml_metal_encode_node(
19301978 // to the matrix-vector kernel
19311979 int ne11_mm_min = 4 ;
19321980
1933- #if 0
1934- // the numbers below are measured on M2 Ultra for 7B and 13B models
1935- // these numbers do not translate to other devices or model sizes
1936- // TODO: need to find a better approach
1937- if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
1938- switch (src0t) {
1939- case GGML_TYPE_F16: ne11_mm_min = 2; break;
1940- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1941- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1942- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1943- case GGML_TYPE_Q4_0:
1944- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1945- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1946- case GGML_TYPE_Q5_0: // not tested yet
1947- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1948- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1949- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1950- default: ne11_mm_min = 1; break;
1951- }
1952- }
1953- #endif
1981+ if ((src0t == GGML_TYPE_F16 || // TODO: helper function
1982+ src0t == GGML_TYPE_Q4_0 ||
1983+ src0t == GGML_TYPE_Q4_1 ||
1984+ src0t == GGML_TYPE_Q5_0 ||
1985+ src0t == GGML_TYPE_Q5_1 ||
1986+ src0t == GGML_TYPE_Q8_0
1987+ ) &&
1988+ src1t == GGML_TYPE_F32 &&
1989+ (ne00%256 == 0 ) && // TODO: this can be relaxed to 128 for nxpsg == 8
1990+ (ne11 >= 2 && ne11 <= 8 )) {
1991+
1992+ // TODO: determine the optimal parameters based on grid utilization
1993+ const int nsg = 2 ; // TODO: or 4?
1994+ const int nxpsg = ne11 < 3 ? 16 : 8 ;
1995+ const int nypsg = 32 /nxpsg;
1996+ const int r0ptg = nypsg*nsg;
1997+ int r1ptg = 4 ;
1998+
1999+ switch (ne11) {
2000+ case 2 :
2001+ r1ptg = 2 ; break ;
2002+ case 3 :
2003+ case 6 :
2004+ r1ptg = 3 ; break ;
2005+ case 4 :
2006+ case 7 :
2007+ case 8 :
2008+ r1ptg = 4 ; break ;
2009+ case 5 :
2010+ r1ptg = 5 ; break ;
2011+ };
2012+
2013+ assert (nxpsg >= 8 );
2014+ assert (nxpsg%8 == 0 );
2015+
2016+ id <MTLComputePipelineState > pipeline = nil ;
2017+
2018+ switch (src0->type ) {
2019+ case GGML_TYPE_F16:
2020+ switch (r1ptg) {
2021+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline ; break ;
2022+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline ; break ;
2023+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline ; break ;
2024+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline ; break ;
2025+ default : GGML_ABORT (" not implemented" );
2026+ } break ;
2027+ case GGML_TYPE_Q4_0:
2028+ switch (r1ptg) {
2029+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline ; break ;
2030+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline ; break ;
2031+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline ; break ;
2032+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline ; break ;
2033+ default : GGML_ABORT (" not implemented" );
2034+ } break ;
2035+ case GGML_TYPE_Q4_1:
2036+ switch (r1ptg) {
2037+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline ; break ;
2038+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline ; break ;
2039+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline ; break ;
2040+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline ; break ;
2041+ default : GGML_ABORT (" not implemented" );
2042+ } break ;
2043+ case GGML_TYPE_Q5_0:
2044+ switch (r1ptg) {
2045+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline ; break ;
2046+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline ; break ;
2047+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline ; break ;
2048+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline ; break ;
2049+ default : GGML_ABORT (" not implemented" );
2050+ } break ;
2051+ case GGML_TYPE_Q5_1:
2052+ switch (r1ptg) {
2053+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline ; break ;
2054+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline ; break ;
2055+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline ; break ;
2056+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline ; break ;
2057+ default : GGML_ABORT (" not implemented" );
2058+ } break ;
2059+ case GGML_TYPE_Q8_0:
2060+ switch (r1ptg) {
2061+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline ; break ;
2062+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline ; break ;
2063+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline ; break ;
2064+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline ; break ;
2065+ default : GGML_ABORT (" not implemented" );
2066+ } break ;
2067+ default : GGML_ABORT (" not implemented" );
2068+ }
2069+
2070+ ggml_metal_kargs_mul_mv_ext args = {
2071+ /* .ne00 =*/ ne00,
2072+ /* .ne01 =*/ ne01,
2073+ /* .ne02 =*/ ne02,
2074+ /* .nb00 =*/ nb00,
2075+ /* .nb01 =*/ nb01,
2076+ /* .nb02 =*/ nb02,
2077+ /* .nb03 =*/ nb03,
2078+ /* .ne10 =*/ ne10,
2079+ /* .ne11 =*/ ne11,
2080+ /* .ne12 =*/ ne12,
2081+ /* .nb10 =*/ nb10,
2082+ /* .nb11 =*/ nb11,
2083+ /* .nb12 =*/ nb12,
2084+ /* .nb13 =*/ nb13,
2085+ /* .ne0 =*/ ne0,
2086+ /* .ne1 =*/ ne1,
2087+ /* .r2 =*/ r2,
2088+ /* .r3 =*/ r3,
2089+ /* .nsg =*/ nsg,
2090+ /* .nxpsg =*/ nxpsg,
2091+ /* .r1ptg =*/ r1ptg,
2092+ };
2093+
2094+ [encoder setComputePipelineState: pipeline];
2095+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2096+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2097+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2098+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
19542099
2100+ // printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
2101+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + r0ptg - 1 )/r0ptg, (ne11 + r1ptg - 1 )/r1ptg, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
2102+ } else
19552103 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
19562104 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
19572105 if ([device supportsFamily: MTLGPUFamilyApple7] &&
0 commit comments