Commit 7720366

CUDA: use v_perm_b32 to replace byte_perm on AMD GPUs
1 parent bd4b6cd · commit 7720366

File tree

1 file changed: +31 −4 lines

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 31 additions & 4 deletions
@@ -29,18 +29,45 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
 }
 
 static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if defined(GGML_USE_HIP)
+    // Load the 16-byte table into four 32-bit unsigned integers.
+    const uint32_t *values = (const uint32_t *)table;
+
+    const uint32_t q_even = q4;
+    const uint32_t q_odd = (q4 >> 4);
+
+    // Perform lookups in the lower half of the table (indices 0-7).
+    uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
+    uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
+
+    // Perform lookups in the upper half of the table (indices 8-15).
+    uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
+    uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
+
+    // Select between the low and high results based on the MSB of each index nibble.
+    uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
+    uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
+    uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
+    uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
+
+    return make_int2(res_x, res_y);
+#elif defined(__CUDA_ARCH__)
     uint32_t v1, v2, v3, v4, mask;
-    const uint32_t * values = (const uint32_t *)table;
+    const uint32_t *values = (const uint32_t *)table;
 
     mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
+    // Perform lookups in the lower half of the table (indices 0-7).
     v1 = __byte_perm(values[0], values[1], q4);
+    // Perform lookups in the upper half of the table (indices 8-15).
     v2 = __byte_perm(values[2], values[3], q4);
+    // Select between the low and high results based on the MSB of each index nibble.
     v3 = __byte_perm(v1, v2, mask);
+    // Same for the upper part of q4.
     v1 = __byte_perm(values[0], values[1], q4 >> 16);
     v2 = __byte_perm(values[2], values[3], q4 >> 16);
     v4 = __byte_perm(v1, v2, mask >> 16);
-
+
+    // Mix the results to get the final int2.
     return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
 #else
     const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
@@ -54,7 +81,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
         table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
 
     return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif
 }
 
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
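For reference, here is a small host-side sketch (not part of the commit) of the lookup trick used by the new HIP branch. The helper emu_amd_perm, the sample table, and the q4 value are illustrative assumptions; the helper models only the subset of v_perm_b32 / __builtin_amdgcn_perm behavior the kernel relies on (each selector byte 0-3 picks a byte from the second source, 4-7 from the first; selector values >= 8 never occur here because indices are masked with 0x07070707 or built from 0x03020100). It checks the permuted result against a plain per-nibble table lookup and assumes a little-endian host.

// Host-side emulation sketch -- illustrative only, not part of ggml.
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Models v_perm_b32 for selector byte values 0-7: the byte pool is
// {b -> selectors 0-3, a -> selectors 4-7}, one selector byte per result byte.
static uint32_t emu_amd_perm(uint32_t a, uint32_t b, uint32_t sel) {
    const uint8_t pool[8] = {
        (uint8_t)(b >>  0), (uint8_t)(b >>  8), (uint8_t)(b >> 16), (uint8_t)(b >> 24),
        (uint8_t)(a >>  0), (uint8_t)(a >>  8), (uint8_t)(a >> 16), (uint8_t)(a >> 24),
    };
    uint32_t r = 0;
    for (int i = 0; i < 4; ++i) {
        r |= (uint32_t) pool[(sel >> (8 * i)) & 0x07] << (8 * i);
    }
    return r;
}

int main(void) {
    // Arbitrary 16-entry signed lookup table and eight packed 4-bit indices.
    const int8_t table[16] = { -127, -104, -83, -65, -49, -35, -22, -10,
                                  1,   13,  25,  38,  53,  69,  89, 113 };
    const uint32_t q4 = 0xF183A05Cu;

    uint32_t values[4];
    memcpy(values, table, sizeof(values));  // table bytes as four 32-bit words (little-endian)

    const uint32_t q_even = q4;       // low nibble of each byte
    const uint32_t q_odd  = q4 >> 4;  // high nibble of each byte

    // Lookups in the lower half of the table (indices 0-7) ...
    const uint32_t v_even_low  = emu_amd_perm(values[1], values[0], q_even & 0x07070707);
    const uint32_t v_odd_low   = emu_amd_perm(values[1], values[0], q_odd  & 0x07070707);
    // ... and in the upper half (indices 8-15).
    const uint32_t v_even_high = emu_amd_perm(values[3], values[2], q_even & 0x07070707);
    const uint32_t v_odd_high  = emu_amd_perm(values[3], values[2], q_odd  & 0x07070707);

    // Blend: selector byte i picks from the low result, i + 4 from the high one,
    // depending on bit 3 of the corresponding index nibble.
    const uint32_t res_x = emu_amd_perm(v_even_high, v_even_low,
                                        0x03020100 | ((q_even & 0x08080808) >> 1));
    const uint32_t res_y = emu_amd_perm(v_odd_high,  v_odd_low,
                                        0x03020100 | ((q_odd  & 0x08080808) >> 1));

    // Plain per-nibble reference lookup.
    uint32_t ref_x = 0, ref_y = 0;
    for (int i = 0; i < 4; ++i) {
        ref_x |= (uint32_t)(uint8_t) table[(q4 >> (8 * i    )) & 0x0F] << (8 * i);
        ref_y |= (uint32_t)(uint8_t) table[(q4 >> (8 * i + 4)) & 0x0F] << (8 * i);
    }

    printf("perm: %08x %08x\nref : %08x %08x\n", res_x, res_y, ref_x, ref_y);
    return (res_x == ref_x && res_y == ref_y) ? 0 : 1;
}

The existing CUDA branch uses the same two-level scheme with __byte_perm, processing q4 in two 16-bit halves: two half-table lookups per half, a blend whose selector nibble gains +4 when bit 3 of the index nibble is set (0x32103210 | ((q4 & 0x88888888) >> 1)), and a final even/odd interleave (selectors 0x6420 and 0x7531) that splits the eight looked-up bytes into the two ints of the returned int2.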
