@@ -4470,6 +4470,65 @@ void ggml_compute_forward_get_rows(
44704470 // }
44714471}
44724472
4473+ static void ggml_compute_forward_set_rows_f32 (
4474+ const ggml_compute_params * params,
4475+ ggml_tensor * dst) {
4476+
4477+ const ggml_tensor * src0 = dst->src [0 ];
4478+ const ggml_tensor * src1 = dst->src [1 ];
4479+
4480+ GGML_TENSOR_BINARY_OP_LOCALS
4481+
4482+ const int64_t nc = ne00;
4483+ const int64_t nr = ggml_nelements (src1);
4484+
4485+ assert (ne0 == nc);
4486+ assert (ne02 == ne11);
4487+ assert (nb00 == sizeof (float ));
4488+ assert (ggml_nrows (src0) == nr);
4489+
4490+ const int ith = params->ith ;
4491+ const int nth = params->nth ;
4492+
4493+ // rows per thread
4494+ const int dr = (nr + nth - 1 )/nth;
4495+
4496+ // row range for this thread
4497+ const int ir0 = dr*ith;
4498+ const int ir1 = MIN (ir0 + dr, nr);
4499+
4500+ for (int64_t i = ir0; i < ir1; ++i) {
4501+ const int64_t i12 = i/(ne11*ne10);
4502+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4503+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4504+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4505+
4506+ GGML_ASSERT (i01 >= 0 && i01 < ne1);
4507+
4508+ ggml_cpu_fp32_to_fp16 (
4509+ (const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
4510+ (ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc);
4511+ }
4512+ }
4513+
4514+ void ggml_compute_forward_set_rows (
4515+ const ggml_compute_params * params,
4516+ ggml_tensor * dst) {
4517+
4518+ const ggml_tensor * src0 = dst->src [0 ];
4519+
4520+ switch (src0->type ) {
4521+ case GGML_TYPE_F32:
4522+ {
4523+ ggml_compute_forward_set_rows_f32 (params, dst);
4524+ } break ;
4525+ default :
4526+ {
4527+ GGML_ABORT (" fatal error" );
4528+ }
4529+ }
4530+ }
4531+
44734532// ggml_compute_forward_get_rows_back
44744533
44754534static void ggml_compute_forward_get_rows_back_f32_f16 (
0 commit comments