Skip to content

Commit 748c6a5

Browse files
use fastfdiv for mul_mat_id modulo
1 parent dbde65f commit 748c6a5

File tree

2 files changed

+32
-29
lines changed

2 files changed

+32
-29
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
570570
//
571571
// n/d = (mulhi(n, mp) + n) >> L;
572572
static const uint3 init_fastdiv_values(uint32_t d) {
573+
GGML_ASSERT(d != 0);
574+
573575
// compute L = ceil(log2(d));
574576
uint32_t L = 0;
575577
while (L < 32 && (uint32_t{ 1 } << L) < d) {

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ template <ggml_type type, int ncols_dst>
141141
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142142
static __global__ void mul_mat_vec_q(
143143
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
144-
const uint32_t ncols_x, const uint32_t nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
144+
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
145145
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
146146
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
147147
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
@@ -163,8 +163,8 @@ static __global__ void mul_mat_vec_q(
163163

164164
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
165165
const uint32_t channel_dst = blockIdx.y;
166-
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
167-
const uint32_t channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
166+
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
167+
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
168168
const uint32_t sample_dst = blockIdx.z;
169169
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
170170
const uint32_t sample_y = sample_dst;
@@ -248,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
248248
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
249249
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
250250

251-
const uint3 channel_ratio = init_fastdiv_values(nchannels_dst / nchannels_x);
252-
const uint3 sample_ratio = init_fastdiv_values(nsamples_dst / nsamples_x);
251+
const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
252+
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
253+
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
253254

254255
const int device = ggml_cuda_get_device();
255256
const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -261,65 +262,65 @@ static void mul_mat_vec_q_switch_ncols_dst(
261262
constexpr int c_ncols_dst = 1;
262263
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
263264
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
264-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
265-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
266-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
265+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
266+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
267+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
267268
} break;
268269
case 2: {
269270
constexpr int c_ncols_dst = 2;
270271
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
271272
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
272-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
273-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
274-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
273+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
274+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
275+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
275276
} break;
276277
case 3: {
277278
constexpr int c_ncols_dst = 3;
278279
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
279280
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
280-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
281-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
282-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
281+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
282+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
283+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
283284
} break;
284285
case 4: {
285286
constexpr int c_ncols_dst = 4;
286287
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
287288
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
288-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
289-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
290-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
289+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
290+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
291+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
291292
} break;
292293
case 5: {
293294
constexpr int c_ncols_dst = 5;
294295
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
295296
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
296-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
297-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
298-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
297+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
298+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
299+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
299300
} break;
300301
case 6: {
301302
constexpr int c_ncols_dst = 6;
302303
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
303304
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
304-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
305-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
306-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
305+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
306+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
307+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
307308
} break;
308309
case 7: {
309310
constexpr int c_ncols_dst = 7;
310311
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
311312
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
312-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
313-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
314-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
313+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
314+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
315+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
315316
} break;
316317
case 8: {
317318
constexpr int c_ncols_dst = 8;
318319
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
319320
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
320-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
321-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
322-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
321+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
322+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
323+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
323324
} break;
324325
default:
325326
GGML_ABORT("fatal error");

0 commit comments

Comments
 (0)