 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define cublasCreate hipblasCreate
 #define cublasGemmEx hipblasGemmEx
+#define cublasGemmBatchedEx hipblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
 #define cublasHandle_t hipblasHandle_t
 #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
 #define cublasSetStream hipblasSetStream
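
Note on the two added defines: they extend the existing cuBLAS-to-hipBLAS alias table so the batched GEMM calls introduced later in this patch also build for ROCm. Below is a minimal, self-contained sketch of the aliasing mechanism only (toy stub names; it does not verify that hipBLAS actually exposes hipblasGemmBatchedEx / hipblasGemmStridedBatchedEx with compatible signatures, which is what the real defines rely on):

```cpp
#include <cstdio>

// Stand-in for the ROCm entry point; in the real build this comes from hipBLAS.
static int hipblasGemmStridedBatchedEx_stub(void) { return 0; }

// Same pattern as the defines above: call sites keep the cuBLAS spelling and the
// preprocessor rewrites them to the hipBLAS name at compile time.
#define cublasGemmStridedBatchedEx_demo hipblasGemmStridedBatchedEx_stub

int main() {
    printf("status = %d\n", cublasGemmStridedBatchedEx_demo());
    return 0;
}
```
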
@@ -4326,13 +4328,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous

     const half * x = (const half *) vx;

-    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
     const int channel_x = channel / channel_x_divisor;

-    const int nrows_y = ncols_x;
+    const int nrows_y   = ncols_x;
     const int nrows_dst = nrows_x;
-    const int row_dst = row_x;
+    const int row_dst   = row_x;

     const int idst = channel*nrows_dst + row_dst;

@@ -4345,13 +4347,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
             break;
         }

-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const float xi = __half2float(x[ix]);
-
         const int row_y = col_x;

+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         const int iy = channel*nrows_y + row_y;

+        const float xi = __half2float(x[ix]);
+
         tmp += xi * y[iy];
     }

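
For reference while reading the kernel hunks above, here is a hedged CPU re-statement of the same addressing (not project code, float stands in for half): x is the non-contiguous f16 matrix walked via row_stride_x / channel_stride_x, y holds ncols_x values per channel, and channel_x_divisor broadcasts src0 channels across src1 channels. The reorder in the diff only moves the ix/iy computation ahead of the half-to-float conversion; the math is unchanged.

```cpp
#include <cstdio>
#include <vector>

// CPU reference of mul_mat_vec_nc_f16_f32's indexing (illustrative only).
static void mul_mat_vec_nc_ref(const std::vector<float> & x, const std::vector<float> & y,
                               std::vector<float> & dst,
                               int ncols_x, int nrows_x, int row_stride_x,
                               int nchannels_y, int channel_stride_x, int channel_x_divisor) {
    const int nrows_y   = ncols_x;  // one y value per column of x
    const int nrows_dst = nrows_x;

    for (int channel = 0; channel < nchannels_y; ++channel) {
        const int channel_x = channel / channel_x_divisor;  // broadcast src0 channels
        for (int row_x = 0; row_x < nrows_x; ++row_x) {
            float tmp = 0.0f;
            for (int col_x = 0; col_x < ncols_x; ++col_x) {
                const int row_y = col_x;
                const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
                const int iy = channel*nrows_y + row_y;
                tmp += x[ix] * y[iy];  // the kernel converts x[ix] from half here
            }
            dst[channel*nrows_dst + row_x] = tmp;
        }
    }
}

int main() {
    // 2 columns, 2 rows, 1 channel, contiguous strides -- just to exercise the indexing
    std::vector<float> x = {1, 2, 3, 4}, y = {10, 20}, dst(2);
    mul_mat_vec_nc_ref(x, y, dst, 2, 2, 2, 1, 4, 1);
    printf("%g %g\n", dst[0], dst[1]);  // 50 110
    return 0;
}
```
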
@@ -7013,7 +7015,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
 }

 static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
-    GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -7023,11 +7026,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];

-    const int64_t ne12 = src1->ne[2];
-
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];

+    const int64_t ne12 = src1->ne[2];
+
     CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

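
The hunk above swaps the old contiguity assert for transposition checks. As a hedged aside, simplified from my reading of ggml.c and ignoring quantized block sizes, the layout predicates involved look roughly like the sketch below; it shows why a tensor that is strided (non-contiguous) across dims 2 and 3, as the KQV src0 is, still passes the new asserts:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct toy_tensor {    // stand-in for the ne/nb fields of ggml_tensor
    int64_t ne[4];     // elements per dimension
    size_t  nb[4];     // stride in bytes per dimension
    size_t  type_size; // bytes per element (2 for f16)
};

static bool toy_is_transposed(const toy_tensor & t) {
    return t.nb[0] > t.nb[1];  // only the first two strides matter
}

static bool toy_is_permuted(const toy_tensor & t) {
    return t.nb[0] > t.nb[1] || t.nb[1] > t.nb[2] || t.nb[2] > t.nb[3];
}

static bool toy_is_contiguous(const toy_tensor & t) {
    return t.nb[0] == t.type_size     &&
           t.nb[1] == t.nb[0]*t.ne[0] &&
           t.nb[2] == t.nb[1]*t.ne[1] &&
           t.nb[3] == t.nb[2]*t.ne[2];
}

int main() {
    // a KQV-like view: dim 2 is strided (padded rows), but no dims are swapped
    toy_tensor v = {{128, 64, 32, 1}, {2, 256, 256*64*2, 256*64*2*32}, 2};
    printf("transposed=%d permuted=%d contiguous=%d\n",
           toy_is_transposed(v), toy_is_permuted(v), toy_is_contiguous(v));
    return 0;
}
```
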
@@ -7046,6 +7049,159 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

+static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
+    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+    const int64_t ne1 = ggml_nelements(src1);
+    const int64_t ne  = ggml_nelements(dst);
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
+    half * src0_as_f16 = (half *) src0_ddq;
+
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    // convert src1 to fp16
+    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+    GGML_ASSERT(to_fp16_cuda != nullptr);
+
+    size_t src1_as = 0;
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+    size_t dst_as = 0;
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const half alpha_f16 = 1.0f;
+    const half beta_f16  = 0.0f;
+
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                        cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3],   CUDA_R_16F, nb01/sizeof(half),
+                                        (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                            &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
+                            CUBLAS_COMPUTE_16F,
+                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
+    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+        // use cublasGemmStridedBatchedEx
+        CUBLAS_CHECK(
+        cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                            (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC
+                ne12*ne13,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    } else {
+        // use cublasGemmBatchedEx
+        // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
+        const int ne23 = ne12*ne13;
+
+        // TODO: avoid this alloc
+        void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
+                ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
+                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
+            }
+        }
+
+        // allocate device memory for pointers
+        void ** ptrs_as = nullptr;
+        CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
+
+        // TODO: this does not work for some reason -- not sure why?
+        //size_t ptrs_s = 0;
+        //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+
+        // copy pointers to device
+        CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
+
+        free(ptrs);
+
+        CUBLAS_CHECK(
+        cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                ne23,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        // free device memory for pointers
+        CUDA_CHECK(cudaFree(ptrs_as));
+        //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+    }
+#endif
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
+}
+
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
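
One more hedged illustration for the new batched function above: its cublasGemmBatchedEx fallback packs the A/B/C pointers of all ne12*ne13 GEMMs into a single host array of 3*ne23 entries and uploads it with one cudaMemcpy. The host-only sketch below (example batch sizes are made up) shows how the broadcast factors r2 and r3 decide which src0 slice each src1 batch is paired with:

```cpp
#include <cstdio>

int main() {
    const int ne02 = 2, ne03 = 1;             // src0 batch dims (assumed example sizes)
    const int ne12 = 6, ne13 = 1;             // src1 batch dims
    const int r2 = ne12/ne02, r3 = ne13/ne03; // broadcast factors, as in the patch
    const int ne23 = ne12*ne13;               // total number of GEMMs in the batch

    for (int i13 = 0; i13 < ne13; ++i13) {
        for (int i12 = 0; i12 < ne12; ++i12) {
            const int i02 = i12 / r2;         // each src0 slice is reused r2 times
            const int i03 = i13 / r3;
            const int batch = i12 + i13*ne12;
            printf("GEMM %d: A <- src0[%d,%d], B <- src1[%d,%d], C <- dst[%d,%d]\n",
                   batch, i02, i03, i12, i13, i12, i13);
            // in the patch: ptrs[0*ne23 + batch] = A, ptrs[1*ne23 + batch] = B,
            //               ptrs[2*ne23 + batch] = C, then one cudaMemcpy to device
        }
    }
    return 0;
}
```
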
@@ -7058,10 +7214,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }

+    // debug helpers
+    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
     if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // KQ
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // KQV
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
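
Finally, a rough, assumption-laden sketch of how the updated dispatch reads for typical attention shapes; the helper and shapes below are illustrative only and ignore the permuted/contiguous distinctions between the two mat-vec kernels:

```cpp
#include <cstdint>
#include <cstdio>

// Toy decision helper mirroring the order of the branches above (simplified).
static const char * pick_mul_mat_branch(bool all_on_device, bool src0_f16, bool src1_f32,
                                        bool transposed, int64_t ne1, int64_t ne2, int64_t ne3) {
    if (all_on_device && src0_f16 && ne1 == 1) {
        return "custom mat-vec kernel (KQ/KQV, single token)";
    }
    if (all_on_device && src0_f16 && src1_f32 && !transposed && ne2*ne3 > 1) {
        return "ggml_cuda_mul_mat_mat_batched_cublas";
    }
    return "ggml_cuda_op_mul_mat fallback";
}

int main() {
    // e.g. decoding one token vs. processing a 512-token prompt with 32 heads
    printf("decode: %s\n", pick_mul_mat_branch(true, true, true, false, 1,   32, 1));
    printf("prompt: %s\n", pick_mul_mat_branch(true, true, true, false, 512, 32, 1));
    return 0;
}
```
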