@@ -29,18 +29,45 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
 }
 
 static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if defined(GGML_USE_HIP)
+    // Load the 16-byte table into four 32-bit unsigned integers.
+    const uint32_t *values = (const uint32_t *)table;
+
+    const uint32_t q_even = q4;
+    const uint32_t q_odd = (q4 >> 4);
+
+    // Perform lookups in the lower half of the table (indices 0-7).
+    uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
+    uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
+
+    // Perform lookups in the upper half of the table (indices 8-15).
+    uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
+    uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
+
+    // Select between the low and high results based on the MSB of each index nibble.
+    uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
+    uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
+    uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
+    uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
+
+    return make_int2(res_x, res_y);
+#elif defined(__CUDA_ARCH__)
     uint32_t v1, v2, v3, v4, mask;
-    const uint32_t * values = (const uint32_t *)table;
+    const uint32_t *values = (const uint32_t *)table;
 
     mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
+    // Perform lookups in the lower half of the table (indices 0-7).
     v1 = __byte_perm(values[0], values[1], q4);
+    // Perform lookups in the upper half of the table (indices 8-15).
     v2 = __byte_perm(values[2], values[3], q4);
+    // Select between the low and high results based on the MSB of each index nibble.
     v3 = __byte_perm(v1, v2, mask);
+    // Same for the upper part of q4.
     v1 = __byte_perm(values[0], values[1], q4 >> 16);
     v2 = __byte_perm(values[2], values[3], q4 >> 16);
     v4 = __byte_perm(v1, v2, mask >> 16);
-
+
+    // Mix the results to get the final int2.
     return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
 #else
     const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
@@ -54,7 +81,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
         table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
 
     return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif
 }
 
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
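
Below is a small host-side sketch, not part of the patch, that illustrates the two-level table lookup the kernel performs. The helper perm_bytes, the example table, and the packed q4 value are all made up for illustration: perm_bytes stands in for the byte-selection behaviour the kernel gets from __builtin_amdgcn_perm / __byte_perm (selector bytes 0-3 pick from the low word, 4-7 from the high word, the only range used here). It shows how masking each nibble with 0x07 indexes within one half of the table, and how OR-ing the nibble's MSB (shifted right by one) into the selector switches that byte to the upper-half result.

// Host-side sketch of the nibble -> 16-entry table lookup (illustrative only).
#include <cstdint>
#include <cstdio>

// Emulates the byte selection used above: selector bytes 0-3 pick bytes of `lo`,
// 4-7 pick bytes of `hi`. Only the 3-bit selector range needed here is handled.
static uint32_t perm_bytes(uint32_t hi, uint32_t lo, uint32_t sel) {
    const uint8_t src[8] = {
        uint8_t(lo >> 0), uint8_t(lo >> 8), uint8_t(lo >> 16), uint8_t(lo >> 24),
        uint8_t(hi >> 0), uint8_t(hi >> 8), uint8_t(hi >> 16), uint8_t(hi >> 24),
    };
    uint32_t out = 0;
    for (int i = 0; i < 4; ++i) {
        out |= uint32_t(src[(sel >> (8*i)) & 0x07]) << (8*i);
    }
    return out;
}

int main() {
    // Example 16-entry table and packed 4-bit indices (made-up values).
    const int8_t table[16] = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
    const uint32_t * values = (const uint32_t *) table;
    const uint32_t q4 = 0x9F03A5C1u;

    // Even nibbles only (bits 0-3 of each byte); the odd nibbles work the same way.
    const uint32_t q_even = q4;

    // One lookup in table[0..7] and one in table[8..15], both using the low 3 bits.
    const uint32_t v_even_low  = perm_bytes(values[1], values[0], q_even & 0x07070707);
    const uint32_t v_even_high = perm_bytes(values[3], values[2], q_even & 0x07070707);

    // The nibble's MSB (0x08) shifted right by one becomes +4 in the selector,
    // switching that byte from the low-half result to the high-half result.
    const uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
    const uint32_t res_even  = perm_bytes(v_even_high, v_even_low, mask_even);

    // Scalar reference: look up each even nibble directly.
    uint32_t ref = 0;
    for (int i = 0; i < 4; ++i) {
        const uint32_t idx = (q4 >> (8*i)) & 0x0F;
        ref |= uint32_t(uint8_t(table[idx])) << (8*i);
    }

    printf("perm: %08x  scalar: %08x\n", (unsigned) res_even, (unsigned) ref);
    return 0;
}

Compiling and running the sketch prints the same value for the permute-based path and the scalar reference, which is the property the patch relies on for both the HIP and the CUDA branch.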