@@ -1865,13 +1865,24 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         // use cublasGemmBatchedEx
         const int64_t ne23 = ne12*ne13;
 
+#ifdef GGML_USE_MUSA
+        const void ** ptrs_src;
+        void       ** ptrs_dst;
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_src, sizeof(void *)*2*ne23));
+        CUDA_CHECK(cudaMalloc((void **)&ptrs_dst, sizeof(void *)*1*ne23));
+#else // GGML_USE_MUSA
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+#endif // GGML_USE_MUSA
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_f16, src1_f16, dst_t,
+#ifdef GGML_USE_MUSA
+                ptrs_src, ptrs_dst,
+#else // GGML_USE_MUSA
                 ptrs_src.get(), ptrs_dst.get(),
+#endif // GGML_USE_MUSA
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
@@ -1881,15 +1892,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
 
-        CUBLAS_CHECK(
+#ifdef GGML_USE_MUSA
+        cudaDeviceSynchronize();
+        const void ** Aarray = (const void **) (ptrs_src + 0*ne23);
+        const void ** Barray = (const void **) (ptrs_src + 1*ne23);
+        void       ** Carray = (      void **) (ptrs_dst + 0*ne23);
+#else // GGML_USE_MUSA
+        const void ** Aarray = (const void **) (ptrs_src.get() + 0*ne23);
+        const void ** Barray = (const void **) (ptrs_src.get() + 1*ne23);
+        void       ** Carray = (      void **) (ptrs_dst.get() + 0*ne23);
+#endif // GGML_USE_MUSA
+
+        CUBLAS_CHECK(
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+                alpha, Aarray, CUDA_R_16F,   nb01/nb00,
+                       Barray, CUDA_R_16F,   s11,
+                beta,  Carray, cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+#ifdef GGML_USE_MUSA
+        CUDA_CHECK(cudaFree(ptrs_src));
+        CUDA_CHECK(cudaFree(ptrs_dst));
+#endif // GGML_USE_MUSA
     }
 #endif
18951922
@@ -2989,12 +3016,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                 return false;
             }
-#ifdef GGML_USE_MUSA
-            if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
-                !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                return false;
-            }
-#endif // GGML_USE_MUSA
+            // #ifdef GGML_USE_MUSA
+            // if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+            //     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+            //     return false;
+            // }
+            // #endif // GGML_USE_MUSA
             switch (a->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
@@ -3019,11 +3046,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_TYPE_IQ4_NL:
                 case GGML_TYPE_IQ4_XS:
                 case GGML_TYPE_BF16:
-#ifdef GGML_USE_MUSA
-                    if (a->type == GGML_TYPE_Q3_K) {
-                        return false;
-                    }
-#endif // GGML_USE_MUSA
+                    // #ifdef GGML_USE_MUSA
+                    // if (a->type == GGML_TYPE_Q3_K) {
+                    //     return false;
+                    // }
+                    // #endif // GGML_USE_MUSA
                     return true;
                 default:
                     return false;
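For context: across these hunks the MUSA build stops using ggml's stream-ordered pool allocator for the A/B/C pointer arrays that cublasGemmBatchedEx consumes, and instead cudaMalloc's them, synchronizes after k_compute_batched_ptrs has filled them on device, runs the batched GEMM, and cudaFree's them. The two supports_op restrictions that previously disabled batched F16 mul_mat and Q3_K on MUSA are commented out accordingly. Below is a minimal standalone sketch of that allocate, fill-on-device, synchronize, GEMM, free pattern. It is not the ggml code: the toy FP32 sizes, the fill_ptrs kernel, and the CHECK_* macros are all invented for illustration.

#include <cstdio>
#include <cuda_runtime.h>
#include <cublas_v2.h>

#define CHECK_CUDA(call)   do { cudaError_t err_ = (call); \
    if (err_ != cudaSuccess) { printf("CUDA error %d at line %d\n", (int)err_, __LINE__); return 1; } } while (0)
#define CHECK_CUBLAS(call) do { cublasStatus_t st_ = (call); \
    if (st_ != CUBLAS_STATUS_SUCCESS) { printf("cuBLAS error %d at line %d\n", (int)st_, __LINE__); return 1; } } while (0)

// analogue of ggml's k_compute_batched_ptrs: each thread stores the address
// of one matrix of the batch into the device-side pointer arrays
__global__ void fill_ptrs(const float * A, const float * B, float * C,
                          const void ** Aarray, const void ** Barray, void ** Carray,
                          int n, int batch) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= batch) {
        return;
    }
    Aarray[i] = A + (size_t) i*n*n;
    Barray[i] = B + (size_t) i*n*n;
    Carray[i] = C + (size_t) i*n*n;
}

int main() {
    const int n = 4, batch = 8;

    // matrix storage; contents are left uninitialized because only the
    // call pattern matters for this sketch
    float *A, *B, *C;
    CHECK_CUDA(cudaMalloc(&A, sizeof(float)*n*n*batch));
    CHECK_CUDA(cudaMalloc(&B, sizeof(float)*n*n*batch));
    CHECK_CUDA(cudaMalloc(&C, sizeof(float)*n*n*batch));

    // plain cudaMalloc for the pointer arrays, as in the MUSA branch,
    // instead of a stream-ordered pool allocation
    const void ** Aarray;
    const void ** Barray;
    void       ** Carray;
    CHECK_CUDA(cudaMalloc((void **)&Aarray, sizeof(void *)*batch));
    CHECK_CUDA(cudaMalloc((void **)&Barray, sizeof(void *)*batch));
    CHECK_CUDA(cudaMalloc((void **)&Carray, sizeof(void *)*batch));

    fill_ptrs<<<1, batch>>>(A, B, C, Aarray, Barray, Carray, n, batch);
    CHECK_CUDA(cudaGetLastError());
    // make sure the pointer arrays are fully written before the batched
    // GEMM consumes them; the diff does the same on MUSA
    CHECK_CUDA(cudaDeviceSynchronize());

    cublasHandle_t handle;
    CHECK_CUBLAS(cublasCreate(&handle));
    const float alpha = 1.0f, beta = 0.0f;
    CHECK_CUBLAS(cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            n, n, n,
            &alpha, Aarray, CUDA_R_32F, n,
                    Barray, CUDA_R_32F, n,
            &beta,  Carray, CUDA_R_32F, n,
            batch,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT));
    CHECK_CUDA(cudaDeviceSynchronize());

    // the MUSA path frees the pointer arrays right after the call
    CHECK_CUDA(cudaFree(Aarray));
    CHECK_CUDA(cudaFree(Barray));
    CHECK_CUDA(cudaFree(Carray));
    CHECK_CUDA(cudaFree(A));
    CHECK_CUDA(cudaFree(B));
    CHECK_CUDA(cudaFree(C));
    cublasDestroy(handle);
    printf("batched GEMM of %d %dx%d matrices done\n", batch, n, n);
    return 0;
}

A sketch like this should build with: nvcc batched_ptrs.cu -lcublas. Note the explicit cudaDeviceSynchronize between the pointer-setup kernel and the GEMM, mirroring the MUSA branch above; on the CUDA path, same-stream ordering already guarantees the arrays are written before cuBLAS reads them, which is presumably why the commit confines both the synchronize and the cudaMalloc/cudaFree handling to the GGML_USE_MUSA branch.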