88#include " ggml-cpu-fp8.h"
99
1010namespace fp8 {
// Overlay of a 32-bit float and its raw bit pattern.
// Shared by the FP8 conversion routines so they can manipulate the sign,
// exponent and mantissa fields of a float with integer operations.
union fp32_int32 {
    float f;        // value viewed as a float
    uint32_t bits;  // the same 32 bits viewed as an unsigned integer
};
15+
1116#ifdef GGML_USE_OPENMP_SIMD
1217#pragma omp declare simd
1318#endif
1419template <int E>
1520inline uint8_t from_float (float value) {
1621 FP8<E> out;
17- union {
18- float f;
19- uint32_t bits;
20- } in = {value};
22+ fp32_int32 in = {value};
2123 out.bits = (in.bits >> 24 ) & 0x80 ;
2224 in.bits &= 0x7fffffff ;
23- if (in.f >= FP8<E>::MAX () ) {
25+ if (in.f >= FP8<E>::MAX) {
2426 out.bits |= 0x7E ;
25- } else if (in.f < FP8<E>::MIN () ) { // => 0.
27+ } else if (in.f < FP8<E>::MIN) { // => 0.
2628 } else {
27- in.f *= exp_m2 <FP8<E>::E_BIAS () -127 >();
28- uint32_t eps = (0x3fffff >>FP8<E>::M ()) + ((in.bits >> (23 -FP8<E>::M () )) & 0x1 );
29+ in.f *= exp_f2 <FP8<E>::E_BIAS-127 >();
30+ uint32_t eps = (0x3fffff >>FP8<E>::M) + ((in.bits >> (23 -FP8<E>::M)) & 0x1 );
2931 in.bits += eps;
30- out.bits |= (in.bits >> (23 -FP8<E>::M () )) & 0x7F ;
32+ out.bits |= (in.bits >> (23 -FP8<E>::M)) & 0x7F ;
3133 }
3234 return out.bits ;
3335}
@@ -37,16 +39,13 @@ inline uint8_t from_float(float value) {
3739#endif
3840template <int E>
3941inline float to_float (const FP8<E>& in) {
40- union {
41- float f;
42- uint32_t bits;
43- } out = {0 };
42+ fp32_int32 out = {0 };
4443 out.bits = in.bits & 0x80 ;
4544 out.bits <<= 24 ;
4645 uint32_t _bits = in.bits & 0x7F ;
47- _bits <<= (23 -FP8<E>::M () );
46+ _bits <<= (23 -FP8<E>::M);
4847 out.bits |= _bits;
49- out.f *= exp_p2 <127 -FP8<E>::E_BIAS () >();
48+ out.f *= exp_f2 <127 -FP8<E>::E_BIAS>();
5049 return out.f ;
5150}
5251} // namespace fp8
@@ -91,8 +90,8 @@ static inline void conv(const float* x, bloc_fp8<E, QK>* y, int64_t size) {
9190 for (int64_t i=0 ; i<QK; i++) {
9291 m = std::max (std::abs (x[q*QK+i]),m);
9392 }
94- const float D = FP8<E>::MAX () /m;
95- y[q].d = m/FP8<E>::MAX () ;
93+ const float D = FP8<E>::MAX/m;
94+ y[q].d = m/FP8<E>::MAX;
9695#ifdef GGML_USE_OPENMP_SIMD
9796 #pragma omp simd
9897#endif
@@ -154,22 +153,22 @@ float dot_reg(const bloc_fp8<E, QK>* x, const _Y* y, int64_t size) {
154153 for (int64_t v=0 ; v<VECT_SIZE; ++v) { mantice_16bits[v] = mantice_8bits[v]; }
155154
156155 for (int64_t v=0 ; v<VECT_SIZE; ++v) { sign_16bits[v] <<= 8 ; }
157- for (int64_t v=0 ; v<VECT_SIZE; ++v) { mantice_16bits[v] <<= (7 -fp8_t::M () ); }
156+ for (int64_t v=0 ; v<VECT_SIZE; ++v) { mantice_16bits[v] <<= (7 -fp8_t ::M); }
158157
159158 for (int64_t v=0 ; v<VECT_SIZE; ++v) { x_bf16[v] = sign_16bits[v] | mantice_16bits[v]; }
160159
161160 for (int64_t v=0 ; v<VECT_SIZE; ++v) { ux[v].bits = x_bf16[v]; }
162161 for (int64_t v=0 ; v<VECT_SIZE; ++v) { ux[v].bits <<= 16 ; }
163162
164- for (int64_t v=0 ; v<VECT_SIZE; ++v) { X[v] = ux[v].f ; } // * exp_p2 <127-fp8_t::E_BIAS() >(); }
163+ for (int64_t v=0 ; v<VECT_SIZE; ++v) { X[v] = ux[v].f ; } // * exp_f2 <127-fp8_t::E_BIAS>(); }
165164 for (int64_t v=0 ; v<VECT_SIZE; ++v) { Y[v] = (float )y[q*QK+i+r*VECT_SIZE+v]; }
166165 for (int64_t v=0 ; v<VECT_SIZE; ++v) { Z0[r][v] += X[v]*Y[v]; }
167166 }
168167 }
169168 // apply scale
170169 for (int64_t r=0 ; r<NB_REG; ++r) {
171170 for (int64_t v=0 ; v<VECT_SIZE; ++v) {
172- Z[r][v] += Z0[r][v]*(x[q]).d * exp_p2 <127 -fp8_t::E_BIAS () >();
171+ Z[r][v] += Z0[r][v]*(x[q]).d * exp_f2 <127 -fp8_t ::E_BIAS>();
173172 }
174173 }
175174 }
0 commit comments