@@ -1199,7 +1199,6 @@ static auto PIQUANT_HOT dequant_uint4_to_bf16(
1199
1199
std::int64_t i {};
1200
1200
1201
1201
#if defined(__AVX512F__) && defined(__AVX512BW__)
1202
- __m512i vzp {_mm512_set1_epi32 (zp)};
1203
1202
__m512 vscale {_mm512_set1_ps (scale)};
1204
1203
__m512i vmaskLo {_mm512_set1_epi8 (0x0f )};
1205
1204
__m512 vbias {_mm512_set1_ps (-static_cast <fp32_t >(zp)*scale)};
@@ -1311,6 +1310,111 @@ static auto PIQUANT_HOT dequant_uint4_to_bf16(
1311
1310
}
1312
1311
}
1313
1312
1313
// Dequantizes a stream of packed 2-bit unsigned quants into bf16:
//   out[i] = (q[i] - zp) * scale            (ReduceOp == set)
//   out[i] += (q[i] - zp) * scale           (ReduceOp == add)
// Each uint2_t byte packs four quants little-endian: bits [1:0], [3:2], [5:4], [7:6].
//
// Params:
//   x     - packed 2-bit input (numel quants, i.e. ceil(numel/4) bytes)
//   o     - bf16 output, numel elements; read as well when ReduceOp == add
//   numel - number of quantized values
//   scale - dequantization scale
//   zp    - zero point
template <const reduce_op ReduceOp>
static auto PIQUANT_HOT dequant_uint2_to_bf16(
    const uint2_t* PIQUANT_RESTRICT x,
    bfp16_t* PIQUANT_RESTRICT o,
    std::int64_t numel,
    fp32_t scale,
    std::int32_t zp
) noexcept -> void {
    std::int64_t i {};
    #if defined(__AVX512F__) && defined(__AVX512BW__)
        __m512 vscale {_mm512_set1_ps(scale)};
        // Bias folds the zero point into a single FMA: q*scale + (-zp*scale) == (q - zp)*scale.
        __m512 vbias {_mm512_set1_ps(-static_cast<fp32_t>(zp) * scale)};
        // pshufb LUTs indexed by a nibble (0..15): LUT0 yields the nibble's low 2-bit
        // field (n & 3), LUT1 yields its high 2-bit field (n >> 2).
        __m128i LUT0 {_mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3)};
        __m128i LUT1 {_mm_setr_epi8(0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3)};
        __m128i LO_NIB {_mm_set1_epi8(15)};
        // Load 16 bf16 output values and widen to fp32 (used only for ReduceOp == add).
        auto load_o16_as_ps {[&](const bfp16_t* ptr) noexcept -> __m512 {
            #ifdef __AVX512BF16__
                return _mm512_cvtpbh_ps(std::bit_cast<__m256bh>(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr))));
            #else
                // Fallback: a bf16 is the upper 16 bits of an fp32, so widen and shift left.
                __m256i raw {_mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr))};
                __m512i u32 {_mm512_cvtepu16_epi32(raw)};
                return _mm512_castsi512_ps(_mm512_slli_epi32(u32, 16));
            #endif
        }};
        // Narrow 16 fp32 values to bf16 and store; without AVX512BF16 defers to the
        // file-local cvt_ps_to_bf16 helper (defined elsewhere in this file).
        auto store_ps_to_o16 {[&](bfp16_t* ptr, __m512 f) noexcept {
            #ifdef __AVX512BF16__
                _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), std::bit_cast<__m256i>(_mm512_cvtneps_pbh(f)));
            #else
                _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), cvt_ps_to_bf16(f));
            #endif
        }};
        // Decode 16 packed bytes (= 64 quants) and write o[base_out .. base_out+63].
        auto do_16bytes {[&](const uint8_t* src, std::int64_t base_out) {
            __m128i b {_mm_loadu_si128(reinterpret_cast<const __m128i*>(src))};
            // Split each byte into nibbles; each nibble holds two 2-bit quants.
            __m128i lo {_mm_and_si128(b, LO_NIB)};
            __m128i hi {_mm_and_si128(_mm_srli_epi16(b, 4), LO_NIB)};
            // Per byte k: q0[k]=bits[1:0], q1[k]=bits[3:2], q2[k]=bits[5:4], q3[k]=bits[7:6].
            __m128i q0 {_mm_shuffle_epi8(LUT0, lo)};
            __m128i q1 {_mm_shuffle_epi8(LUT1, lo)};
            __m128i q2 {_mm_shuffle_epi8(LUT0, hi)};
            __m128i q3 {_mm_shuffle_epi8(LUT1, hi)};
            // Two unpack stages interleave the four streams back into source order:
            // z0..z3 hold bytes a0,b0,c0,d0, a1,b1,c1,d1, ... covering input bytes
            // 0-3, 4-7, 8-11, 12-15 respectively (16 quants each).
            __m128i ab_lo {_mm_unpacklo_epi8(q0, q1)};
            __m128i ab_hi {_mm_unpackhi_epi8(q0, q1)};
            __m128i cd_lo {_mm_unpacklo_epi8(q2, q3)};
            __m128i cd_hi {_mm_unpackhi_epi8(q2, q3)};
            __m128i z0 {_mm_unpacklo_epi16(ab_lo, cd_lo)};
            __m128i z1 {_mm_unpackhi_epi16(ab_lo, cd_lo)};
            __m128i z2 {_mm_unpacklo_epi16(ab_hi, cd_hi)};
            __m128i z3 {_mm_unpackhi_epi16(ab_hi, cd_hi)};
            // Widen 16 quant bytes to fp32 and dequantize in one FMA.
            auto to_f {[&](__m128i z) noexcept -> __m512 {
                return _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(z)), vscale, vbias);
            }};
            __m512 f0 {to_f(z0)};
            __m512 f1 {to_f(z1)};
            __m512 f2 {to_f(z2)};
            __m512 f3 {to_f(z3)};
            if constexpr (ReduceOp == reduce_op::add) {
                // Accumulate onto the existing output (bf16 widened to fp32 first).
                f0 = _mm512_add_ps(f0, load_o16_as_ps(o+base_out + 0));
                f1 = _mm512_add_ps(f1, load_o16_as_ps(o+base_out + 16));
                f2 = _mm512_add_ps(f2, load_o16_as_ps(o+base_out + 32));
                f3 = _mm512_add_ps(f3, load_o16_as_ps(o+base_out + 48));
            }
            store_ps_to_o16(o+base_out + 0, f0);
            store_ps_to_o16(o+base_out + 16, f1);
            store_ps_to_o16(o+base_out + 32, f2);
            store_ps_to_o16(o+base_out + 48, f3);
        }};
        // Main vector loop: 4x unrolled, 64 input bytes -> 256 output values per iteration.
        for (; i+255 < numel; i += 256) {
            const uint8_t* src {reinterpret_cast<const uint8_t*>(x) + (i>>2)};  // i>>2: 4 quants per byte
            do_16bytes(src, i);
            do_16bytes(src+16, i+64);
            do_16bytes(src+32, i+128);
            do_16bytes(src+48, i+192);
        }
    #endif

    // Scalar path: handles the tail after the vector loop (or everything without AVX-512).
    // NOTE(review): this computes (q - zp)*scale with a separate subtract, while the SIMD
    // path uses an FMA with pre-folded bias — results may differ by at most one rounding.
    const auto dequant_step = [=](std::int32_t q) noexcept -> fp32_t {
        return (static_cast<fp32_t>(q) - zp) * scale;
    };
    std::int64_t j {i>>2};  // byte index; i is a multiple of 4 here, so i>>2 is exact
    for (; i+3 < numel; i += 4, ++j) {
        auto p = x[j].bits;
        // Unpack the four 2-bit fields, lowest bits first (matches the SIMD ordering).
        int qa = p & 3;
        int qb = (p >> 2) & 3;
        int qc = (p >> 4) & 3;
        int qd = (p >> 6) & 3;
        if constexpr (ReduceOp == reduce_op::set) {
            // assumes bfp16_t converts/assigns from fp32_t — provided by the project type
            o[i+0] = dequant_step(qa);
            o[i+1] = dequant_step(qb);
            o[i+2] = dequant_step(qc);
            o[i+3] = dequant_step(qd);
        } else {
            o[i+0] += dequant_step(qa);
            o[i+1] += dequant_step(qb);
            o[i+2] += dequant_step(qc);
            o[i+3] += dequant_step(qd);
        }
    }
    // Final partial byte: at most 3 remaining quants.
    if (i < numel) {
        auto p = x[i >> 2].bits;
        int rem = int(numel - i);
        if (rem >= 1) { if constexpr (ReduceOp==reduce_op::set) o[i+0] = dequant_step( p & 3); else o[i+0] += dequant_step( p & 3); }
        if (rem >= 2) { if constexpr (ReduceOp==reduce_op::set) o[i+1] = dequant_step((p >> 2) & 3); else o[i+1] += dequant_step((p >> 2) & 3); }
        if (rem >= 3) { if constexpr (ReduceOp==reduce_op::set) o[i+2] = dequant_step((p >> 4) & 3); else o[i+2] += dequant_step((p >> 4) & 3); }
    }
}
1417
+
1314
1418
static auto PIQUANT_HOT find_min_max_f32 (std::span<const fp32_t > in) noexcept -> std::array<fp32_t, 2> {
1315
1419
const fp32_t * PIQUANT_RESTRICT x {in.data ()};
1316
1420
auto numel {static_cast <std::int64_t >(in.size ())};
0 commit comments