@@ -413,10 +413,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
     int i = 0;
 #if defined(__AVX512BF16__)
     for (; i + 32 <= n; i += 32) {
-        _mm512_storeu_ps(
-            (__m512 *)(y + i),
-            (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
-                                        _mm512_loadu_ps(x + i)));
+        _mm512_storeu_si512(
+            (__m512i *)(y + i),
+            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+                                      _mm512_loadu_ps(x + i))));
     }
 #endif
     for (; i < n; i++) {
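Note on the + lines above: the m512i()/m512bh() helpers are defined elsewhere in ggml.c and are not part of this diff. As an assumption (not the verbatim upstream definition), they are plausibly thin cast wrappers that keep this code compiling on MSVC, which rejects C-style casts between AVX-512 vector types:

    /* sketch only (assumed helpers): identity where the cast is disallowed,
       explicit vector casts on GCC/Clang */
    #if defined(_MSC_VER)
    #define m512i(p)  p
    #define m512bh(p) p
    #else
    #define m512i(p)  (__m512i)(p)
    #define m512bh(p) (__m512bh)(p)
    #endif

This would also explain why the store switches from _mm512_storeu_ps with a (__m512) cast to _mm512_storeu_si512: the result of _mm512_cvtne2ps_pbh is a __m512bh, and presumably going through the integer vector type avoids the problematic float-vector cast.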
@@ -1618,10 +1618,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
     __m512 c1 = _mm512_setzero_ps();
     __m512 c2 = _mm512_setzero_ps();
     for (; i + 64 <= n; i += 64) {
-        c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
-                                  (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
-        c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
-                                  (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                                  m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                                  m512bh(_mm512_loadu_si512((y + i + 32))));
     }
     sumf += (ggml_float)_mm512_reduce_add_ps(c1);
     sumf += (ggml_float)_mm512_reduce_add_ps(c2);
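For reference, a scalar sketch of what one 64-element step of this AVX512-BF16 dot product accumulates (same value up to floating-point rounding and accumulation order; GGML_BF16_TO_FP32 is assumed to be ggml's existing bf16-to-float conversion helper):

    /* scalar equivalent of one loop iteration above (illustrative only) */
    float step = 0.0f;
    for (int k = 0; k < 64; k++) {
        step += GGML_BF16_TO_FP32(x[i + k]) * GGML_BF16_TO_FP32(y[i + k]);
    }
    sumf += (ggml_float)step;

Each _mm512_dpbf16_ps call takes 32 bf16 values from x and 32 from y, multiplies them elementwise in float, and adds each adjacent pair of products into one of the 16 float lanes of its accumulator; the _mm512_reduce_add_ps calls after the loop then fold those lanes into sumf.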
@@ -22873,6 +22873,14 @@ int ggml_cpu_has_avx512_vnni(void) {
 #endif
 }
 
+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
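The new ggml_cpu_has_avx512_bf16() probe mirrors the surrounding ggml_cpu_has_* feature checks. Not shown in this diff, but presumably a matching declaration is added to ggml.h and the flag is reported alongside the other CPU features, along the lines of this hypothetical call site:

    /* hypothetical: printing the flag next to the other ggml_cpu_has_* flags */
    printf("AVX512_BF16 = %d | ", ggml_cpu_has_avx512_bf16());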