77#include < stdio.h>
88#include < atomic>
99#include < assert.h>
10+ #include < float.h>
1011
1112#if defined(GGML_USE_HIPBLAS)
1213#include < hip/hip_runtime.h>
@@ -4587,20 +4588,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
45874588 block_q4_0 * dsti = (block_q4_0 *) cdsti;
45884589
45894590 float amax = 0 .0f ;
4590- float max = 0 .0f ;
4591+ float vmax = 0 .0f ;
45914592
45924593 for (int j = 0 ; j < QK4_0; ++j) {
45934594 const float v = xi[j];
45944595 if (amax < fabsf (v)) {
45954596 amax = fabsf (v);
4596- max = v;
4597+ vmax = v;
45974598 }
45984599 }
45994600
4600- const float d = max / -8 ;
4601+ const float d = vmax / -8 ;
46014602 const float id = d ? 1 .0f /d : 0 .0f ;
46024603
4603- y[i]. d = d;
4604+ dsti-> d = d;
46044605
46054606 for (int j = 0 ; j < QK4_0/2 ; ++j) {
46064607 const float x0 = xi[0 + j]*id;
@@ -4614,6 +4615,38 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
46144615 }
46154616}
46164617
// Quantizes one block of QK4_1 consecutive floats read from `cxi` into a
// single block_q4_1 written at `cdsti`.
//
//   cxi   - raw pointer to the source floats (reinterpreted as const float *)
//   cdsti - raw pointer to the destination block (reinterpreted as block_q4_1 *)
//
// The block stores the step size in dm.x, the block minimum in dm.y, and two
// 4-bit codes per byte in qs[]: element j goes to the low nibble and element
// j + QK4_1/2 goes to the high nibble of qs[j].
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
    const float * src = (const float *) cxi;
    block_q4_1  * dst = (block_q4_1 *) cdsti;

    // Scan the block once to find its value range.
    float lo =  FLT_MAX;
    float hi = -FLT_MAX;

    for (int k = 0; k < QK4_1; ++k) {
        const float val = src[k];

        lo = (val < lo) ? val : lo;
        hi = (val > hi) ? val : hi;
    }

    // 4-bit codes span 0..15, so the range is divided into 15 steps.
    const int   n_steps  = (1 << 4) - 1;
    const float step     = (hi - lo) / n_steps;
    // A zero step (all inputs equal) maps every element to code 0.
    const float inv_step = (step != 0.0f) ? 1.0f/step : 0.0f;

    dst->dm.x = step;
    dst->dm.y = lo;

    const int half_blk = QK4_1/2;

    for (int k = 0; k < half_blk; ++k) {
        const float scaled_lo = (src[k]            - lo)*inv_step;
        const float scaled_hi = (src[half_blk + k] - lo)*inv_step;

        // Round to nearest by adding 0.5 before truncation; clamp at 15 so a
        // value that rounds just past the top still fits in a nibble.
        const uint8_t code_lo = min(15, (int8_t)(scaled_lo + 0.5f));
        const uint8_t code_hi = min(15, (int8_t)(scaled_hi + 0.5f));

        dst->qs[k] = code_lo | (code_hi << 4);
    }
}
4649+
46174650template <cpy_kernel_t cpy_blck, int qk>
46184651static __global__ void cpy_f32_q (const char * cx, char * cdst, const int ne,
46194652 const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
0 commit comments