@@ -141,7 +141,7 @@ template <ggml_type type, int ncols_dst>
141141__launch_bounds__ (calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142142static __global__ void mul_mat_vec_q(
143143 const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
144- const uint32_t ncols_x, const uint32_t nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
144+ const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
145145 const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
146146 const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
147147 const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
@@ -163,8 +163,8 @@ static __global__ void mul_mat_vec_q(
163163
164164 // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
165165 const uint32_t channel_dst = blockIdx .y ;
166- const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv (channel_dst, channel_ratio);
167- const uint32_t channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
166+ const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv (channel_dst, channel_ratio);
167+ const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo ( channel_dst, nchannels_y) : channel_dst;
168168 const uint32_t sample_dst = blockIdx .z ;
169169 const uint32_t sample_x = fastdiv (sample_dst, sample_ratio);
170170 const uint32_t sample_y = sample_dst;
@@ -248,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
248248 GGML_ASSERT (ncols_x % ggml_blck_size (type) == 0 );
249249 GGML_ASSERT (ncols_dst <= MMVQ_MAX_BATCH_SIZE);
250250
251- const uint3 channel_ratio = init_fastdiv_values (nchannels_dst / nchannels_x);
252- const uint3 sample_ratio = init_fastdiv_values (nsamples_dst / nsamples_x);
251+ const uint3 nchannels_y_fd = ids ? init_fastdiv_values (nchannels_y) : make_uint3 (0 , 0 , 0 );
252+ const uint3 channel_ratio_fd = ids ? make_uint3 (0 , 0 , 0 ) : init_fastdiv_values (nchannels_dst / nchannels_x);
253+ const uint3 sample_ratio_fd = init_fastdiv_values (nsamples_dst / nsamples_x);
253254
254255 const int device = ggml_cuda_get_device ();
255256 const int warp_size = ggml_cuda_info ().devices [device].warp_size ;
@@ -261,65 +262,65 @@ static void mul_mat_vec_q_switch_ncols_dst(
261262 constexpr int c_ncols_dst = 1 ;
262263 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
263264 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
264- (vx, vy, ids, dst, ncols_x, nchannels_y , stride_row_x, stride_col_y, stride_col_dst,
265- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
266- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
265+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd , stride_row_x, stride_col_y, stride_col_dst,
266+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
267+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
267268 } break ;
268269 case 2 : {
269270 constexpr int c_ncols_dst = 2 ;
270271 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
271272 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
272- (vx, vy, ids, dst, ncols_x, nchannels_y , stride_row_x, stride_col_y, stride_col_dst,
273- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
274- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
273+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd , stride_row_x, stride_col_y, stride_col_dst,
274+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
275+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
275276 } break ;
276277 case 3 : {
277278 constexpr int c_ncols_dst = 3 ;
278279 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
279280 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
280- (vx, vy, ids, dst, ncols_x, nchannels_y , stride_row_x, stride_col_y, stride_col_dst,
281- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
282- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
281+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd , stride_row_x, stride_col_y, stride_col_dst,
282+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
283+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
283284 } break ;
284285 case 4 : {
285286 constexpr int c_ncols_dst = 4 ;
286287 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
287288 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
288- (vx, vy, ids, dst, ncols_x, nchannels_y , stride_row_x, stride_col_y, stride_col_dst,
289- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
290- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
289+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd , stride_row_x, stride_col_y, stride_col_dst,
290+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
291+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
291292 } break ;
292293 case 5 : {
293294 constexpr int c_ncols_dst = 5 ;
294295 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
295296 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
296- (vx, vy, ids, dst, ncols_x, nchannels_y , stride_row_x, stride_col_y, stride_col_dst,
297- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
298- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
297+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd , stride_row_x, stride_col_y, stride_col_dst,
298+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
299+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
299300 } break ;
300301 case 6 : {
301302 constexpr int c_ncols_dst = 6 ;
302303 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
303304 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
304- (vx, vy, ids, dst, ncols_x, nchannels_y , stride_row_x, stride_col_y, stride_col_dst,
305- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
306- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
305+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd , stride_row_x, stride_col_y, stride_col_dst,
306+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
307+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
307308 } break ;
308309 case 7 : {
309310 constexpr int c_ncols_dst = 7 ;
310311 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
311312 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
312- (vx, vy, ids, dst, ncols_x, nchannels_y , stride_row_x, stride_col_y, stride_col_dst,
313- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
314- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
313+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd , stride_row_x, stride_col_y, stride_col_dst,
314+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
315+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
315316 } break ;
316317 case 8 : {
317318 constexpr int c_ncols_dst = 8 ;
318319 std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
319320 mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
320- (vx, vy, ids, dst, ncols_x, nchannels_y , stride_row_x, stride_col_y, stride_col_dst,
321- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
322- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
321+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd , stride_row_x, stride_col_y, stride_col_dst,
322+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
323+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
323324 } break ;
324325 default :
325326 GGML_ABORT (" fatal error" );
0 commit comments