66layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
77
88layout (constant_id = 0) const uint BLOCK_SIZE = 32;
9+ layout (constant_id = 1) const uint NUM_ROWS = 1;
910
10- shared FLOAT_TYPE tmp[BLOCK_SIZE];
11-
12- void main() {
13- const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
14-
15- if (row >= p.stride_d) {
16- return;
17- }
11+ shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
1812
13+ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
1914 uint a_offset, b_offset, d_offset;
2015 get_offsets(a_offset, b_offset, d_offset);
2116
2217 const uint num_blocks_per_row = p.ncols / QUANT_K;
23- const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
2418
2519 // 16 threads are used to process each block
2620 const uint it_size = gl_WorkGroupSize.x/16;
@@ -38,15 +32,15 @@ void main() {
3832 const uint s_offset = 8*v_im;
3933 const uint y_offset = 128*v_im + l0;
4034
41- FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
35+ FLOAT_TYPE temp[NUM_ROWS];
36+
37+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
38+ temp[i] = FLOAT_TYPE(0);
39+ }
4240
4341 [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
4442 const uint y_idx = i * QUANT_K + y_offset;
4543
46- f16vec2 d = data_a[ib0 + i].d;
47- const FLOAT_TYPE dall = d.x;
48- const FLOAT_TYPE dmin = d.y;
49-
5044 B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
5145 B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
5246 B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
@@ -56,58 +50,84 @@ void main() {
5650 B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
5751 B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
5852
59- uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
60- uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
61-
62- uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
63- uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
64- uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
65- uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
66-
67- uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
68- uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
69- uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
70- uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
71-
72- uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
73- uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
74- uvec2 qs0 = uvec2(unpack8(qs0_u16));
75- uvec2 qs16 = uvec2(unpack8(qs16_u16));
76-
77- FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
78- FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
79- [[unroll]] for (int l = 0; l < 2; ++l) {
80- sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
81- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
82- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
83- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
84- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
85- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
86- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
87- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
88- sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
89- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
90- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
91- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
92- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
93- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
94- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
95- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
53+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
54+ const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
55+ f16vec2 d = data_a[ib0 + i].d;
56+ const FLOAT_TYPE dall = d.x;
57+ const FLOAT_TYPE dmin = d.y;
58+
59+ uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
60+ uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
61+
62+ uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
63+ uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
64+ uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
65+ uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
66+
67+ uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
68+ uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
69+ uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
70+ uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
71+
72+ uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
73+ uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
74+ uvec2 qs0 = uvec2(unpack8(qs0_u16));
75+ uvec2 qs16 = uvec2(unpack8(qs16_u16));
76+
77+ FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
78+ FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
79+ [[unroll]] for (int l = 0; l < 2; ++l) {
80+ sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
81+ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
82+ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
83+ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
84+ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
85+ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
86+ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
87+ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
88+ sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
89+ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
90+ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
91+ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
92+ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
93+ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
94+ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
95+ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
96+ }
97+ temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
9698 }
97- temp = fma(dall, sum1, fma(-dmin, sum2, temp));
9899 }
99100
100- tmp[gl_LocalInvocationID.x] = temp;
101-
102101 // sum up partial sums and write back result
102+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
103+ tmpsh[n][tid] = temp[n];
104+ }
103105 barrier();
104- [[unroll]] for (uint s = gl_WorkGroupSize.x /2; s > 0; s >>= 1) {
106+ [[unroll]] for (uint s = BLOCK_SIZE /2; s > 0; s >>= 1) {
105107 if (tid < s) {
106- tmp[tid] += tmp[tid + s];
108+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
109+ tmpsh[n][tid] += tmpsh[n][tid + s];
110+ }
107111 }
108112 barrier();
109113 }
110114 if (tid == 0) {
111- data_d[d_offset + row] = D_TYPE(tmp[0]);
115+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
116+ data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
117+ }
118+ }
119+ }
120+
121+ void main() {
122+ const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
123+
124+ // do NUM_ROWS at a time, unless there aren't enough remaining rows
125+ if (first_row + NUM_ROWS <= p.stride_d) {
126+ compute_outputs(first_row, NUM_ROWS);
127+ } else {
128+ if (first_row >= p.stride_d) {
129+ return;
130+ }
131+ compute_outputs(first_row, p.stride_d - first_row);
112132 }
113133}
0 commit comments