@@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) {
 }
 
 #if defined(__AVX2__) || defined(__AVX512F__)
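+// multiply-add on signed int8: _mm256_maddubs_epi16 requires an unsigned first
+// operand, so take |x| and transfer x's sign onto y before the multiply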
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+
 // spread 32 bits to 32 bytes { 0x00, 0xFF }
 static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     uint32_t x32;
@@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
     return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                            _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
+
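+// broadcast the combined block scales of two mxfp4/q8_0 block pairs,
+// one product per 128-bit lane (e8m0 scale for x, fp16 delta for y)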
+static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
+    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
 #endif
 #elif defined(__SSSE3__)
 // horizontally add 4x4 floats
@@ -746,6 +757,91 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }
 
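+// dot product of a row of mxfp4 blocks (one e8m0 scale byte + 32 packed
+// 4-bit values per block) with a row of q8_0 blocks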
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __AVX2__
+
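+    // values128 maps each 4-bit mxfp4 code to its int8 value via pshufb;
+    // kvalues_mxfp4 stores the fp4 values doubled, which the halved scale
+    // from GGML_E8M0_TO_FP32_HALF compensates for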
+    const __m128i values128 = _mm_loadu_si128((const __m128i *)kvalues_mxfp4);
+    const __m128i m4b = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
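+        // int8 dot products: widen to int16 pairs, then sum to int32 lanes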
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
+                                 _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
+                                 _mm256_cvtepi32_ps(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i *)kvalues_mxfp4);
+    const __m128i m4b = _mm_set1_epi8(0x0f);
+
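+    // AVX lacks most 256-bit integer ops, so work on 128-bit halves and
+    // combine the four sub-products in mul_sum_i8_quad_float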
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+
+#endif
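+    // scalar tail: any block left over from the vector loops, or the whole
+    // range when neither AVX path is compiled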
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -3206,14 +3302,6 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
 }
 
-#if defined(__AVX2__)
-static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    const __m256i sy = _mm256_sign_epi8(y, x);
-    return _mm256_maddubs_epi16(ax, sy);
-}
-#endif
-
 void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);