@@ -4489,6 +4489,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
44894489 *dsti = __float2half (*xi);
44904490}
44914491
// Copy a single f16 element; used as the per-element functor for the
// templated copy kernel (cpy_f32_f16<cpy_1_f16_f16>).
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
    const half * src = (const half *) cxi;
    half       * dst = (half *) cdsti;

    *dst = *src;
}
4498+
44924499template <cpy_kernel_t cpy_1>
44934500static __global__ void cpy_f32_f16 (const char * cx, char * cdst, const int ne,
44944501 const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4742,6 +4749,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
47424749 dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
47434750}
47444751
// im2col (f32 -> f16). Launch layout (see im2col_f32_f16_cuda):
//   grid  = (IC, OH, OW), block = (N, KH, KW)
// Each thread writes exactly one element of the flattened destination,
// zero-filling positions that fall into the implicit padding region.
static __global__ void im2col_f32_f16(
        const float * x, half * dst,
        int ofs0, int ofs1, int IW, int IH, int CHW,
        int s0, int s1, int p0, int p1, int d0, int d1) {
    // Source pixel coordinates for this (output position, kernel tap) pair.
    const int src_x = blockIdx.z * s0 + threadIdx.z * d0 - p0;
    const int src_y = blockIdx.y * s1 + threadIdx.y * d1 - p1;

    // Destination index: (n, oh, ow) selects the row, (ic, kh, kw) the column.
    const int dst_row = threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z;
    const int dst_col = blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z;
    const int offset_dst = dst_row * CHW + dst_col;

    const bool inside = src_y >= 0 && src_y < IH && src_x >= 0 && src_x < IW;
    if (inside) {
        // ofs0/ofs1 are element (not byte) offsets for the batch/channel dims.
        const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
        dst[offset_dst] = __float2half(x[offset_src + src_y * IW + src_x]);
    } else {
        dst[offset_dst] = __float2half(0.0f);
    }
}
4770+
47454771template <int qk, int qr, dequantize_kernel_t dq>
47464772static void get_rows_cuda (const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
47474773 const dim3 block_dims (CUDA_GET_ROWS_BLOCK_SIZE, 1 , 1 );
@@ -5642,6 +5668,16 @@ static void ggml_cpy_f32_f16_cuda(
56425668 (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
56435669}
56445670
// Host-side launcher for the f16 -> f16 tensor copy. Shapes/strides of the
// source (ne0x/nb0x) and destination (ne1x/nb1x) are forwarded to the kernel.
static void ggml_cpy_f16_f16_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {

    // One thread per element, rounded up to whole blocks (ceil-div).
    const int block_count = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_f32_f16<cpy_1_f16_f16><<<block_count, CUDA_CPY_BLOCK_SIZE, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
5680+
56455681static void scale_f32_cuda (const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
56465682 const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1 ) / CUDA_SCALE_BLOCK_SIZE;
56475683 scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0 , stream>>> (x, dst, scale, k);
@@ -5725,6 +5761,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
57255761 soft_max_f32<<<block_nums, block_dims, 0 , stream>>> (x, dst, ncols_x);
57265762}
57275763
// Host-side launcher for the im2col kernel.
// Grid covers (channels, output rows, output cols); each block covers
// (batch, kernel height, kernel width). CHW = IC*KH*KW is the column count
// of the unfolded destination.
static void im2col_f32_f16_cuda(const float * x, half * dst,
        int OH, int IW, int IH, int OW, int IC,
        int KH, int KW, int N, int ofs0, int ofs1,
        int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
    const dim3 grid_dims(IC, OH, OW);
    const dim3 block_dims(N, KH, KW);
    im2col_f32_f16<<<grid_dims, block_dims, 0, stream>>>(
        x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
5772+
57285773// buffer pool for cuda
57295774#define MAX_CUDA_BUFFERS 256
57305775
@@ -6522,8 +6567,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
65226567 src1_as_f16 = (half *) ggml_cuda_pool_malloc (ne * sizeof (half), &src1_as);
65236568 to_fp16_cuda (src1_ddf_i, src1_as_f16, ne, stream);
65246569 }
6525- const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
6526-
6570+ const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
65276571 size_t dst_as = 0 ;
65286572 half * dst_f16 = (half *) ggml_cuda_pool_malloc (row_diff*src1_ncols * sizeof (half), &dst_as);
65296573
@@ -6698,6 +6742,45 @@ inline void ggml_cuda_op_alibi(
66986742 (void ) src1_dd;
66996743}
67006744
// im2col: unfold src1 (input, f32) into dst (f16) so that convolution with
// src0 (the kernel tensor) reduces to a matrix multiplication.
// dst->op_params layout: [s0, s1, p0, p1, d0, d1, is_2D]
// (strides, padding, dilation; is_2D == 1 selects 2D conv, otherwise 1D).
inline void ggml_cuda_op_im2col(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F16);

    const int32_t * opts = (const int32_t *) dst->op_params;
    const int32_t s0 = opts[0];
    const int32_t s1 = opts[1];
    const int32_t p0 = opts[2];
    const int32_t p1 = opts[3];
    const int32_t d0 = opts[4];
    const int32_t d1 = opts[5];

    const bool is_2D = opts[6] == 1;

    // Input geometry; for 1D convolutions the H dimension collapses to 1.
    const int64_t N  = src1->ne[is_2D ? 3 : 2];
    const int64_t IC = src1->ne[is_2D ? 2 : 1];
    const int64_t IH = is_2D ? src1->ne[1] : 1;
    const int64_t IW =         src1->ne[0];

    // Kernel geometry comes from src0's shape (its data is not read here).
    const int64_t KH = is_2D ? src0->ne[1] : 1;
    const int64_t KW =         src0->ne[0];

    const int64_t OH = is_2D ? dst->ne[2] : 1;
    const int64_t OW =         dst->ne[1];

    // nb[] is a byte stride; convert to element offsets (src1 is f32).
    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / sizeof(float);
    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / sizeof(float);

    im2col_f32_f16_cuda(src1_dd, (half *) dst_dd,
        OH, IW, IH, OW, IC, KH, KW, N,
        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);

    (void) src0_dd; // only src0's shape is needed, not its data
}
6783+
67016784inline void ggml_cuda_op_diag_mask_inf (
67026785 const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
67036786 const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7610,6 +7693,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
76107693 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
76117694 ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
76127695 ne10, ne11, nb10, nb11, nb12, main_stream);
7696+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
7697+ ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
7698+ ne10, ne11, nb10, nb11, nb12, main_stream);
76137699 } else {
76147700 fprintf (stderr, " %s: unsupported type combination (%s to %s)\n " , __func__,
76157701 ggml_type_name (src0->type ), ggml_type_name (src1->type ));
@@ -7641,6 +7727,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
76417727 ggml_cuda_op_flatten (src0, src1, dst, ggml_cuda_op_alibi);
76427728}
76437729
// Entry point for GGML_OP_IM2COL: route through the generic flatten helper,
// which handles device buffers/streams and invokes ggml_cuda_op_im2col.
void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}
7733+
76447734static void ggml_cuda_nop (const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
76457735 (void ) src0;
76467736 (void ) src1;
@@ -7934,6 +8024,15 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
79348024 return false ;
79358025 }
79368026
8027+ if (tensor->op == GGML_OP_MUL_MAT) {
8028+ if (tensor->src [0 ]->ne [3 ] != tensor->src [1 ]->ne [3 ]) {
8029+ #ifndef NDEBUG
8030+ fprintf (stderr, " %s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n " , __func__, tensor->name , tensor->src [0 ]->ne [3 ], tensor->src [1 ]->ne [3 ]);
8031+ #endif
8032+ return false ;
8033+ }
8034+ }
8035+
79378036 switch (tensor->op ) {
79388037 case GGML_OP_REPEAT:
79398038 func = ggml_cuda_repeat;
@@ -8012,6 +8111,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
80128111 case GGML_OP_ALIBI:
80138112 func = ggml_cuda_alibi;
80148113 break ;
8114+ case GGML_OP_IM2COL:
8115+ func = ggml_cuda_im2col;
8116+ break ;
80158117 default :
80168118 return false ;
80178119 }
0 commit comments