Skip to content

Commit ec95c0e

Browse files
committed
ggml : add mxfp4
ggml : use e8m0 conversion instead of powf (Co-authored-by: Diego Devesa <[email protected]>)
* change kvalues_mxfp4 table to match e2m1 (#6)
* metal : remove quantization for now (not used)
* cuda : fix disabled CUDA graphs due to ffn moe bias
* vulkan : add support for mxfp4
* cont : add cm2 dequant
1 parent 4cf69df commit ec95c0e

37 files changed

+1014
-60
lines changed

ggml/include/ggml.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,8 @@ extern "C" {
395395
// GGML_TYPE_IQ4_NL_4_4 = 36,
396396
// GGML_TYPE_IQ4_NL_4_8 = 37,
397397
// GGML_TYPE_IQ4_NL_8_8 = 38,
398-
GGML_TYPE_COUNT = 39,
398+
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
399+
GGML_TYPE_COUNT = 40,
399400
};
400401

401402
// precision
@@ -430,6 +431,7 @@ extern "C" {
430431
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
431432
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
432433
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
434+
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
433435
};
434436

435437
// available tensor operations:

ggml/src/ggml-common.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ typedef sycl::half2 ggml_half2;
9999
#define QI4_1 (QK4_1 / (4 * QR4_1))
100100
#define QR4_1 2
101101

102+
// mxfp4 packing constants, mirroring the other 4-bit quant types:
// QR = quantization ratio, QI = number of 32-bit words of packed data per block
#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
#define QR_MXFP4 2
104+
102105
#define QI5_0 (QK5_0 / (4 * QR5_0))
103106
#define QR5_0 2
104107

@@ -184,6 +187,13 @@ typedef struct {
184187
} block_q4_1;
185188
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
186189

190+
// MXFP4: blocks of 32 4-bit (e2m1) values sharing a single E8M0 scale byte
#define QK_MXFP4 32
typedef struct {
    uint8_t e;              // E8M0 shared scale (8-bit exponent, no sign/mantissa)
    uint8_t qs[QK_MXFP4/2]; // 32 e2m1 codes, two 4-bit nibbles per byte
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
196+
187197
#define QK5_0 32
188198
typedef struct {
189199
ggml_half d; // delta
@@ -1074,10 +1084,17 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
10741084
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
10751085
GGML_TABLE_END()
10761086

1087+
// TODO: fix name to kvalues_iq4_nl
10771088
GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
10781089
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
10791090
GGML_TABLE_END()
10801091

1092+
// e2m1 values (doubled)
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
// the values are doubled so the table stays integral; the call sites use the
// GGML_E8M0_TO_FP32_HALF scale conversion to compensate
// index layout matches the e2m1 encoding: low 3 bits magnitude, bit 3 sign
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
1097+
10811098
#define NGRID_IQ1S 2048
10821099
#define IQ1S_DELTA 0.125f
10831100
#define IQ1M_DELTA 0.125f

ggml/src/ggml-cpu/arch-fallback.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
1414
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
1515
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
16+
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
1617
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
1718
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
1819
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -168,6 +169,7 @@
168169
#elif defined(__wasm__)
169170
// quants.c
170171
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
172+
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
171173
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
172174
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
173175
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K

ggml/src/ggml-cpu/arch/arm/quants.c

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,67 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
589589
*s = sumf;
590590
}
591591

592+
// dot product of an mxfp4 row (vx) with a q8_0 row (vy), n values long;
// the result is stored in *s. bs/bx/by are unused strides; nrc must be 1.
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_MXFP4 == 0);
    // both formats cover the same number of values per block, so blocks pair 1:1
    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

    const block_mxfp4 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    const int nb = n / QK_MXFP4;

    int ib = 0;
    float sumf = 0;

#if defined __ARM_NEON
    // the 16-entry e2m1 value table, used as a lookup table via tbl
    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    uint8x16x2_t q4bits;
    int8x16x4_t q4b;
    int8x16x4_t q8b;
    int32x4_t prod_1;
    int32x4_t prod_2;

    // main loop: two blocks (2 x 32 values) per iteration
    for (; ib + 1 < nb; ib += 2) {
        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
        q8b.val[0] = vld1q_s8(y[ib + 0].qs);
        q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16);
        q8b.val[2] = vld1q_s8(y[ib + 1].qs);
        q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);

        // expand nibbles through the value table: low nibbles first, then high —
        // matching the q8 ordering (first 16 values, then the last 16)
        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));

        // 8-bit dot products accumulated into 32-bit lanes, one per block
        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);

        // scale each block's integer sum by its combined E8M0 x fp16 delta
        // (_HALF compensates for the doubled kvalues_mxfp4 table)
        sumf +=
            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
    }

#endif
    // scalar tail (also the full path when NEON is unavailable)
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
        int sumi1 = 0;
        int sumi2 = 0;
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
        }
        sumf += d * (sumi1 + sumi2);
    }
    *s = sumf;
}
652+
592653
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
593654
const int qk = QK8_0;
594655
const int nb = n / qk;

ggml/src/ggml-cpu/arch/x86/quants.c

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) {
6666
}
6767

6868
#if defined(__AVX2__) || defined(__AVX512F__)
69+
// multiply signed 8-bit lanes of x and y, horizontally adding adjacent pairs
// into 16-bit lanes. _mm256_maddubs_epi16 needs an unsigned first operand, so
// the sign of x is transferred onto y first: |x| * (sign(x)*y) == x*y.
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
    const __m256i ax = _mm256_sign_epi8(x, x);
    const __m256i sy = _mm256_sign_epi8(y, x);
    return _mm256_maddubs_epi16(ax, sy);
}
74+
6975
// spread 32 bits to 32 bytes { 0x00, 0xFF }
7076
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
7177
uint32_t x32;
@@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
261267
return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
262268
_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
263269
}
270+
271+
// combined scale factors for two mxfp4/q8_0 block pairs: the low 128 bits
// broadcast the (x0, y0) delta, the high 128 bits the (x1, y1) delta.
// mirrors quad_fp16_delta_float above, with the mxfp4 block's E8M0 scale in
// place of an fp16 delta; GGML_E8M0_TO_FP32_HALF (not GGML_E8M0) is required
// for consistency with every other call site, since the kvalues_mxfp4 table
// stores doubled e2m1 values.
static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
}
264275
#endif
265276
#elif defined(__SSSE3__)
266277
// horizontally add 4x4 floats
@@ -746,6 +757,91 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
746757
#endif
747758
}
748759

760+
// dot product of an mxfp4 row (vx) with a q8_0 row (vy), n values long;
// the result is stored in *s. bs/bx/by are unused strides; nrc must be 1.
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_MXFP4 == 0);
    // both formats cover the same number of values per block, so blocks pair 1:1
    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

    const block_mxfp4 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    const int nb = n / QK_MXFP4;

    int ib = 0;
    float sumf = 0;

#if defined __AVX2__

    // 16-entry e2m1 value table used as a pshufb lookup table
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
    const __m128i m4b = _mm_set1_epi8(0x0f);
    const __m256i mone = _mm256_set1_epi16(1);

    __m256 accum1 = _mm256_setzero_ps();
    __m256 accum2 = _mm256_setzero_ps();
    // main loop: two blocks (2 x 32 values) per iteration
    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
        // expand nibbles through the value table: low nibbles in the low lane,
        // high nibbles in the high lane, matching the q8 value ordering
        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
        // 8-bit products summed pairwise to 16-bit, then to 32-bit via madd(1)
        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
        // scale each block's integer sums by its combined fp16 x E8M0 delta
        // (_HALF compensates for the doubled kvalues_mxfp4 table)
        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
                                 _mm256_cvtepi32_ps(p_1), accum1);
        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
                                 _mm256_cvtepi32_ps(p_2), accum2);
    }

    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));

#elif defined __AVX__
    // AVX (no AVX2): same scheme with 128-bit integer ops
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
    const __m128i m4b = _mm_set1_epi8(0x0f);

    __m256 accum = _mm256_setzero_ps();
    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);

        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));

        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
        // NOTE(review): x[ib].e is uint8_t but quad_mx_delta_float takes int8_t —
        // presumably the E8M0 conversion macro masks the byte; confirm for e >= 128
        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
    }

    sumf = hsum_float_8(accum);

#endif
    // scalar tail (also the full path when neither AVX2 nor AVX is available)
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
        int sumi1 = 0;
        int sumi2 = 0;
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            sumi1 += y[ib].qs[j + 0]          * kvalues_mxfp4[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
        }
        sumf += d * (sumi1 + sumi2);
    }
    *s = sumf;
}
844+
749845
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
750846
const int qk = QK8_0;
751847
const int nb = n / qk;
@@ -3206,14 +3302,6 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
32063302
#endif
32073303
}
32083304

3209-
#if defined(__AVX2__)
3210-
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
3211-
const __m256i ax = _mm256_sign_epi8(x, x);
3212-
const __m256i sy = _mm256_sign_epi8(y, x);
3213-
return _mm256_maddubs_epi16(ax, sy);
3214-
}
3215-
#endif
3216-
32173305
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
32183306
assert(n % QK_K == 0);
32193307
assert(nrc == 1);

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
253253
.vec_dot_type = GGML_TYPE_Q8_1,
254254
.nrows = 1,
255255
},
256+
[GGML_TYPE_MXFP4] = {
257+
.from_float = quantize_row_mxfp4,
258+
.vec_dot = ggml_vec_dot_mxfp4_q8_0,
259+
.vec_dot_type = GGML_TYPE_Q8_0,
260+
.nrows = 1,
261+
},
256262
[GGML_TYPE_Q2_K] = {
257263
.from_float = quantize_row_q2_K,
258264
.vec_dot = ggml_vec_dot_q2_K_q8_K,

ggml/src/ggml-cpu/ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,6 +1284,7 @@ void ggml_compute_forward_add(
12841284
case GGML_TYPE_Q5_0:
12851285
case GGML_TYPE_Q5_1:
12861286
case GGML_TYPE_Q8_0:
1287+
case GGML_TYPE_MXFP4:
12871288
case GGML_TYPE_Q2_K:
12881289
case GGML_TYPE_Q3_K:
12891290
case GGML_TYPE_Q4_K:
@@ -1661,6 +1662,7 @@ void ggml_compute_forward_add1(
16611662
case GGML_TYPE_Q5_1:
16621663
case GGML_TYPE_Q8_0:
16631664
case GGML_TYPE_Q8_1:
1665+
case GGML_TYPE_MXFP4:
16641666
case GGML_TYPE_Q2_K:
16651667
case GGML_TYPE_Q3_K:
16661668
case GGML_TYPE_Q4_K:
@@ -1788,6 +1790,7 @@ void ggml_compute_forward_acc(
17881790
case GGML_TYPE_Q5_1:
17891791
case GGML_TYPE_Q8_0:
17901792
case GGML_TYPE_Q8_1:
1793+
case GGML_TYPE_MXFP4:
17911794
case GGML_TYPE_Q2_K:
17921795
case GGML_TYPE_Q3_K:
17931796
case GGML_TYPE_Q4_K:
@@ -4687,6 +4690,7 @@ void ggml_compute_forward_out_prod(
46874690
case GGML_TYPE_Q5_0:
46884691
case GGML_TYPE_Q5_1:
46894692
case GGML_TYPE_Q8_0:
4693+
case GGML_TYPE_MXFP4:
46904694
case GGML_TYPE_Q2_K:
46914695
case GGML_TYPE_Q3_K:
46924696
case GGML_TYPE_Q4_K:
@@ -4961,6 +4965,7 @@ void ggml_compute_forward_set(
49614965
case GGML_TYPE_Q5_1:
49624966
case GGML_TYPE_Q8_0:
49634967
case GGML_TYPE_Q8_1:
4968+
case GGML_TYPE_MXFP4:
49644969
case GGML_TYPE_Q2_K:
49654970
case GGML_TYPE_Q3_K:
49664971
case GGML_TYPE_Q4_K:
@@ -5222,6 +5227,7 @@ void ggml_compute_forward_get_rows(
52225227
case GGML_TYPE_Q5_1:
52235228
case GGML_TYPE_Q8_0:
52245229
case GGML_TYPE_Q8_1:
5230+
case GGML_TYPE_MXFP4:
52255231
case GGML_TYPE_Q2_K:
52265232
case GGML_TYPE_Q3_K:
52275233
case GGML_TYPE_Q4_K:
@@ -5937,6 +5943,7 @@ void ggml_compute_forward_clamp(
59375943
case GGML_TYPE_Q5_1:
59385944
case GGML_TYPE_Q8_0:
59395945
case GGML_TYPE_Q8_1:
5946+
case GGML_TYPE_MXFP4:
59405947
case GGML_TYPE_Q2_K:
59415948
case GGML_TYPE_Q3_K:
59425949
case GGML_TYPE_Q4_K:

ggml/src/ggml-cpu/quants.c

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI
4646
quantize_row_q8_1_ref(x, y, k);
4747
}
4848

49+
// quantize k float values from x into mxfp4 blocks at y; thin dispatch
// wrapper over the reference implementation (no SIMD variant here).
// presumably k must be a multiple of QK_MXFP4 — enforced in the _ref impl
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
    quantize_row_mxfp4_ref(x, y, k);
}
52+
4953
//
5054
// 2-6 bit quantization in super-blocks
5155
//
@@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
181185
*s = sumf;
182186
}
183187

188+
// portable scalar dot product of an mxfp4 row (vx) with a q8_0 row (vy),
// n values long; the result is stored in *s. bs/bx/by are unused strides
// and nrc must be 1.
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_MXFP4 == 0);
    // the two formats pack the same number of values per block, so they pair 1:1
    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

    const block_mxfp4 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    const int nb = n / QK_MXFP4;

    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        // combined per-block scale: q8_0 fp16 delta times the mxfp4 E8M0 scale
        // (the _HALF variant compensates for the doubled kvalues_mxfp4 table)
        const float scale = GGML_CPU_FP16_TO_FP32(y[i].d)*GGML_E8M0_TO_FP32_HALF(x[i].e);

        // low nibbles map to the first half of the q8 values, high nibbles
        // to the second half; accumulate each half separately in integers
        int sumi_lo = 0;
        int sumi_hi = 0;
        for (int j = 0; j < QK_MXFP4/2; ++j) {
            const uint8_t packed = x[i].qs[j];
            sumi_lo += y[i].qs[j]              * kvalues_mxfp4[packed & 0xf];
            sumi_hi += y[i].qs[j + QK_MXFP4/2] * kvalues_mxfp4[packed >> 4];
        }

        sumf += scale * (sumi_lo + sumi_hi);
    }

    *s = sumf;
}
218+
184219
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
185220
const int qk = QK8_0;
186221
const int nb = n / qk;

0 commit comments

Comments
 (0)