Skip to content

Commit 9d83231

Browse files
committed
Publish int2 quant kernel
1 parent e2bc170 commit 9d83231

File tree

2 files changed

+109
-2
lines changed

2 files changed

+109
-2
lines changed

src/kernels/dequantize.inl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@ static auto PIQUANT_HOT dequant_generic(
114114
dequant_uint4_to_bf16<ReduceOp>(static_cast<const uint4_t*>(in), static_cast<bfp16_t*>(out), numel, scale, static_cast<std::int32_t>(zp));
115115
return;
116116
}
117-
117+
if constexpr (std::is_same_v<In, uint2_t> && std::is_same_v<Out, bfp16_t>) {
118+
dequant_uint2_to_bf16<ReduceOp>(static_cast<const uint2_t*>(in), static_cast<bfp16_t*>(out), numel, scale, static_cast<std::int32_t>(zp));
119+
return;
120+
}
118121

119122
if constexpr (std::is_same_v<uint4_t, In>) { // Special case for int4
120123
dequant_uint4<In, Out, ReduceOp>(x, o, numel, scale, zp);

src/kernels/kernels_specialized.inl

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1199,7 +1199,6 @@ static auto PIQUANT_HOT dequant_uint4_to_bf16(
11991199
std::int64_t i {};
12001200

12011201
#if defined(__AVX512F__) && defined(__AVX512BW__)
1202-
__m512i vzp {_mm512_set1_epi32(zp)};
12031202
__m512 vscale {_mm512_set1_ps(scale)};
12041203
__m512i vmaskLo {_mm512_set1_epi8(0x0f)};
12051204
__m512 vbias {_mm512_set1_ps(-static_cast<fp32_t>(zp)*scale)};
@@ -1311,6 +1310,111 @@ static auto PIQUANT_HOT dequant_uint4_to_bf16(
13111310
}
13121311
}
13131312

1313+
template <const reduce_op ReduceOp>
1314+
static auto PIQUANT_HOT dequant_uint2_to_bf16(
1315+
const uint2_t* PIQUANT_RESTRICT x,
1316+
bfp16_t* PIQUANT_RESTRICT o,
1317+
std::int64_t numel,
1318+
fp32_t scale,
1319+
std::int32_t zp
1320+
) noexcept -> void {
1321+
std::int64_t i {};
1322+
#if defined(__AVX512F__) && defined(__AVX512BW__)
1323+
__m512 vscale {_mm512_set1_ps(scale)};
1324+
__m512 vbias {_mm512_set1_ps(-static_cast<fp32_t>(zp) * scale)};
1325+
__m128i LUT0 {_mm_setr_epi8(0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3)};
1326+
__m128i LUT1 {_mm_setr_epi8(0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3)};
1327+
__m128i LO_NIB {_mm_set1_epi8(15)};
1328+
auto load_o16_as_ps {[&](const bfp16_t* ptr) noexcept -> __m512 {
1329+
#ifdef __AVX512BF16__
1330+
return _mm512_cvtpbh_ps(std::bit_cast<__m256bh>(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr))));
1331+
#else
1332+
__m256i raw {_mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr))};
1333+
__m512i u32 {_mm512_cvtepu16_epi32(raw)};
1334+
return _mm512_castsi512_ps(_mm512_slli_epi32(u32, 16));
1335+
#endif
1336+
}};
1337+
auto store_ps_to_o16 {[&](bfp16_t* ptr, __m512 f) noexcept {
1338+
#ifdef __AVX512BF16__
1339+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), std::bit_cast<__m256i>(_mm512_cvtneps_pbh(f)));
1340+
#else
1341+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), cvt_ps_to_bf16(f));
1342+
#endif
1343+
}};
1344+
auto do_16bytes {[&](const uint8_t* src, std::int64_t base_out) {
1345+
__m128i b {_mm_loadu_si128(reinterpret_cast<const __m128i*>(src))};
1346+
__m128i lo {_mm_and_si128(b, LO_NIB)};
1347+
__m128i hi {_mm_and_si128(_mm_srli_epi16(b, 4), LO_NIB)};
1348+
__m128i q0 {_mm_shuffle_epi8(LUT0, lo)};
1349+
__m128i q1 {_mm_shuffle_epi8(LUT1, lo)};
1350+
__m128i q2 {_mm_shuffle_epi8(LUT0, hi)};
1351+
__m128i q3 {_mm_shuffle_epi8(LUT1, hi)};
1352+
__m128i ab_lo {_mm_unpacklo_epi8(q0, q1)};
1353+
__m128i ab_hi {_mm_unpackhi_epi8(q0, q1)};
1354+
__m128i cd_lo {_mm_unpacklo_epi8(q2, q3)};
1355+
__m128i cd_hi {_mm_unpackhi_epi8(q2, q3)};
1356+
__m128i z0 {_mm_unpacklo_epi16(ab_lo, cd_lo)};
1357+
__m128i z1 {_mm_unpackhi_epi16(ab_lo, cd_lo)};
1358+
__m128i z2 {_mm_unpacklo_epi16(ab_hi, cd_hi)};
1359+
__m128i z3 {_mm_unpackhi_epi16(ab_hi, cd_hi)};
1360+
auto to_f {[&](__m128i z) noexcept -> __m512 {
1361+
return _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(z)), vscale, vbias);
1362+
}};
1363+
__m512 f0 {to_f(z0)};
1364+
__m512 f1 {to_f(z1)};
1365+
__m512 f2 {to_f(z2)};
1366+
__m512 f3 {to_f(z3)};
1367+
if constexpr (ReduceOp == reduce_op::add) {
1368+
f0 = _mm512_add_ps(f0, load_o16_as_ps(o+base_out + 0));
1369+
f1 = _mm512_add_ps(f1, load_o16_as_ps(o+base_out + 16));
1370+
f2 = _mm512_add_ps(f2, load_o16_as_ps(o+base_out + 32));
1371+
f3 = _mm512_add_ps(f3, load_o16_as_ps(o+base_out + 48));
1372+
}
1373+
store_ps_to_o16(o+base_out + 0, f0);
1374+
store_ps_to_o16(o+base_out + 16, f1);
1375+
store_ps_to_o16(o+base_out + 32, f2);
1376+
store_ps_to_o16(o+base_out + 48, f3);
1377+
}};
1378+
for (; i+255 < numel; i += 256) {
1379+
const uint8_t* src {reinterpret_cast<const uint8_t*>(x) + (i>>2)};
1380+
do_16bytes(src, i);
1381+
do_16bytes(src+16, i+64);
1382+
do_16bytes(src+32, i+128);
1383+
do_16bytes(src+48, i+192);
1384+
}
1385+
#endif
1386+
1387+
const auto dequant_step = [=](std::int32_t q) noexcept -> fp32_t {
1388+
return (static_cast<fp32_t>(q) - zp) * scale;
1389+
};
1390+
std::int64_t j {i>>2};
1391+
for (; i+3 < numel; i += 4, ++j) {
1392+
auto p = x[j].bits;
1393+
int qa = p & 3;
1394+
int qb = (p >> 2) & 3;
1395+
int qc = (p >> 4) & 3;
1396+
int qd = (p >> 6) & 3;
1397+
if constexpr (ReduceOp == reduce_op::set) {
1398+
o[i+0] = dequant_step(qa);
1399+
o[i+1] = dequant_step(qb);
1400+
o[i+2] = dequant_step(qc);
1401+
o[i+3] = dequant_step(qd);
1402+
} else {
1403+
o[i+0] += dequant_step(qa);
1404+
o[i+1] += dequant_step(qb);
1405+
o[i+2] += dequant_step(qc);
1406+
o[i+3] += dequant_step(qd);
1407+
}
1408+
}
1409+
if (i < numel) {
1410+
auto p = x[i >> 2].bits;
1411+
int rem = int(numel - i);
1412+
if (rem >= 1) { if constexpr (ReduceOp==reduce_op::set) o[i+0] = dequant_step( p & 3); else o[i+0] += dequant_step( p & 3); }
1413+
if (rem >= 2) { if constexpr (ReduceOp==reduce_op::set) o[i+1] = dequant_step((p >> 2) & 3); else o[i+1] += dequant_step((p >> 2) & 3); }
1414+
if (rem >= 3) { if constexpr (ReduceOp==reduce_op::set) o[i+2] = dequant_step((p >> 4) & 3); else o[i+2] += dequant_step((p >> 4) & 3); }
1415+
}
1416+
}
1417+
13141418
static auto PIQUANT_HOT find_min_max_f32(std::span<const fp32_t> in) noexcept -> std::array<fp32_t, 2> {
13151419
const fp32_t* PIQUANT_RESTRICT x {in.data()};
13161420
auto numel {static_cast<std::int64_t>(in.size())};

0 commit comments

Comments
 (0)