Skip to content

Commit 0671a16

Browse files
committed
un peu de refacto
1 parent 1827499 commit 0671a16

File tree

3 files changed

+208
-163
lines changed

3 files changed

+208
-163
lines changed

ggml/src/ggml-igpu/ggml-igpu.cpp

Lines changed: 10 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
//#define IGPU_TRACE(...) std::cout << "#> ggml-igpu: " << __VA_ARGS__ << std::endl
1616
#define IGPU_TRACE(...)
1717

18+
#define BLOC_V1
19+
#ifdef BLOC_V1
1820
#include "mulmat-bf16bloc.h"
21+
#endif
1922

2023
/*
2124
#> version bloc-bf16 V0.
@@ -36,12 +39,6 @@
3639

3740
namespace ggml::backend::igpu {
3841

39-
// taille de repacking:
40-
static constexpr std::size_t BLOC_M0 = 16;
41-
// static constexpr std::size_t BLOC_N0 = 16;
42-
static constexpr std::size_t BLOC_K0 = 16;
43-
static constexpr std::size_t BLOC_K1 = 512;
44-
4542
static bool IS_WEIGHT = true;
4643
static bool IS_OTHER = true;
4744
enum class BUFFER_TYPE {
@@ -126,39 +123,8 @@ namespace ggml::backend::igpu {
126123
const auto la = tensor.nb[1]/tensor.nb[0];
127124
bfloat16_t* ref = (bfloat16_t*)data;
128125
bfloat16_t* bloc = (bfloat16_t*)tensor.data;
129-
// TODO: @ optimiser...
130-
/*
131-
//# pragma omp parallel for num_threads(4)
132-
# pragma omp parallel for
133-
for (std::size_t i=0; i<M; i++) {
134-
for (std::size_t k=0; k<K; k++) {
135-
bloc[posBloc2D<BLOC_K1,BLOC_M0,1,TYPE_BLOC::PERFECT>(K, M, k, i)] = ref[pos2D(la, M, k, i)];
136-
}
137-
}
138-
*/
139-
// Ca sera important quand on fera le codage en fp8...
140-
# pragma omp parallel for num_threads(2) collapse(2) // private(tmp)
141-
for (std::size_t k2=0; k2<K; k2+=BLOC_K1) {
142-
for (std::size_t i1=0; i1<M; i1+=BLOC_M0) {
143-
for (std::size_t k1=0; k1<BLOC_K1; k1+=16) {
144-
bfloat16_t tmp[16][16];
145-
for (std::size_t i0=0; i0<BLOC_M0; i0++) {
146-
# pragma omp simd
147-
for (std::size_t k0=0; k0<16; k0++) {
148-
tmp[i0][k0] = ref[pos2D(la, M, k2+k1+k0, i1+i0)];
149-
}
150-
}
151-
for (std::size_t k0=0; k0<16; k0++) {
152-
# pragma omp simd
153-
for (std::size_t i0=0; i0<BLOC_M0; i0++) {
154-
bloc[posBloc2D<BLOC_K1,BLOC_M0,1,TYPE_BLOC::PERFECT>(K, M, k2+k1+k0, i1+i0)] = tmp[i0][k0];
155-
}
156-
}
157-
// bloc[posBloc2D<BLOC_K1,BLOC_M0,1,TYPE_BLOC::PERFECT>(K, M, k, i)] = ref[pos2D(la, M, k, i)];
158-
}
159-
}
160-
}
161126

127+
op_mul_mat::repack(ref, la, bloc, M, K);
162128
}
163129
void get_tensor(const ggml_tensor & tensor, void * data, std::size_t offset, std::size_t size) override {
164130
const auto K = tensor.ne[0];
@@ -252,10 +218,7 @@ namespace ggml::backend::igpu {
252218
ggml::cpp::backend::backend(dev), m_deviceId(deviceId)
253219
{
254220
IGPU_TRACE("backend[" << get_name() << "]: create <" << params << ">");
255-
// if (B_cache.ensure_size(K1 * block_size<N0*N1>(N) * M4)) {
256-
if (B_cache.ensure_size(BLOC_K1 * 1024)) { // @ optimiser la taille de N < 768 * 16
257-
IGPU_TRACE("B_cache[" << BLOC_K1 << ", " << 1024 <<"]");
258-
}
221+
op_mul_mat::init_caches();
259222
}
260223

261224
virtual ~backend() {
@@ -312,95 +275,9 @@ namespace ggml::backend::igpu {
312275
GGML_ASSERT(B->nb[0] == sizeof(float32_t));
313276
GGML_ASSERT(C->nb[0] == sizeof(float32_t));
314277
GGML_ASSERT(K % 16 == 0);
315-
GGML_ASSERT(K % BLOC_K1 == 0);
316278
GGML_ASSERT(M % (4*2*16) == 0); // pas une contraite forte mais plus simple pour l'instant.
317-
// TODO: voir comment decouper ca
318-
//> ggml::backend::igpu::sgemm_wmma<M1,N1,M2,M4,K1>(a1,b1,c1, M,N,K, K,M);
319-
// if (N==1) {} else
320-
if (N == 0) {
321-
std::cout << "BUG? " << N << " / " << A->name << " | "
322-
<< A->ne[0] << "," << A->ne[1] << " "
323-
<< B->ne[0] << "," << B->ne[1] << " "
324-
<< C->ne[0] << "," << C->ne[1] << " "
325-
<< std::endl;
326-
} else
327-
if (N<=16) {
328-
if (M%(4*2*16*16)==0) { // M=2048
329-
sgemm_wmma<4,1,2,16,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
330-
} else if (M%(4*2*16*8)==0) { // M=1024
331-
sgemm_wmma<4,1,2,8,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
332-
} else if (M%(4*2*16*4)==0) { // M=512
333-
sgemm_wmma<4,1,2,4,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
334-
} else if (M%(4*2*16*2)==0) { // M=256
335-
sgemm_wmma<4,1,2,2,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
336-
} else if (M%(4*2*16*1)==0) { // M=128
337-
sgemm_wmma<4,1,2,1,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
338-
} else { // est-ce que l'on fait les cas 3,5,6,7,... ?
339-
// on va s'arreter la pour l'instant:
340-
}
341-
} else if (N<=32) {
342-
if (M%(4*2*16*16)==0) { // M=2048
343-
sgemm_wmma<4,2,2,16,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
344-
} else if (M%(4*2*16*8)==0) { // M=1024
345-
sgemm_wmma<4,2,2,8,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
346-
} else if (M%(4*2*16*4)==0) { // M=512
347-
sgemm_wmma<4,2,2,4,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
348-
} else if (M%(4*2*16*2)==0) { // M=256
349-
sgemm_wmma<4,2,2,2,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
350-
} else if (M%(4*2*16*1)==0) { // M=128
351-
sgemm_wmma<4,2,2,1,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
352-
} else { // est-ce que l'on fait les cas 3,5,6,7,... ?
353-
// on va s'arreter la pour l'instant:
354-
}
355-
} else if (N<=48) { // 3 blocs pour N => 4CU / M
356-
if (M%(4*2*16*8)==0) { // M=1024
357-
sgemm_wmma<4,1,2,8,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
358-
} else if (M%(4*2*16*4)==0) { // M=512
359-
sgemm_wmma<4,1,2,4,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
360-
} else if (M%(4*2*16*2)==0) { // M=256
361-
sgemm_wmma<4,1,2,2,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
362-
} else if (M%(4*2*16*1)==0) { // M=128
363-
sgemm_wmma<4,1,2,1,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
364-
} else { // est-ce que l'on fait les cas 3,5,6,7,... ?
365-
// on va s'arreter la pour l'instant:
366-
}
367-
} else if (N<=64) { // N1=2 => 2CU/N => 6 restant
368-
if (M%(4*2*16*8)==0) { // M=1024
369-
sgemm_wmma<4,2,2,8,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
370-
} else if (M%(4*2*16*4)==0) { // M=512
371-
sgemm_wmma<4,2,2,4,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
372-
} else if (M%(4*2*16*2)==0) { // M=256
373-
sgemm_wmma<4,2,2,2,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
374-
} else if (M%(4*2*16*1)==0) { // M=128
375-
sgemm_wmma<4,2,2,1,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
376-
} else { // est-ce que l'on fait les cas 3,5,6,7,... ?
377-
// on va s'arreter la pour l'instant:
378-
}
379-
} else if (N<=192) {
380-
if (M%(4*2*16*4)==0) { // M=512
381-
sgemm_wmma<4,2,2,4,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
382-
} else if (M%(4*2*16*2)==0) { // M=256
383-
sgemm_wmma<4,2,2,2,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
384-
} else if (M%(4*2*16*1)==0) { // M=128
385-
sgemm_wmma<4,2,2,1,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
386-
} else { // est-ce que l'on fait les cas 3,5,6,7,... ?
387-
// on va s'arreter la pour l'instant:
388-
}
389-
} else if (N<=384) {
390-
if (M%(4*2*16*2)==0) { // M=256
391-
sgemm_wmma<4,2,2,2,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
392-
} else if (M%(4*2*16*1)==0) { // M=128
393-
sgemm_wmma<4,2,2,1,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
394-
} else { // est-ce que l'on fait les cas 3,5,6,7,... ?
395-
// on va s'arreter la pour l'instant:
396-
}
397-
} else {
398-
if (M%(4*2*16*1)==0) { // M=128
399-
sgemm_wmma<4,2,2,1,BLOC_K1>((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, lb,lc);
400-
} else { // est-ce que l'on fait les cas 3,5,6,7,... ?
401-
// on va s'arreter la pour l'instant:
402-
}
403-
}
279+
280+
op_mul_mat::compute((const bfloat16_t*)A->data, (const float32_t*)B->data, (float32_t*)C->data, M,N,K, la,lb,lc);
404281
//matmul_ref((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, la,lb,lc);
405282
}
406283
}
@@ -428,9 +305,6 @@ namespace ggml::backend::igpu {
428305
buffer_type* m_extra_buffer_type;
429306
buffer_type* m_device_buffer_type;
430307
buffer_type* m_host_buffer_type;
431-
std::size_t K_MAX = 0;
432-
std::size_t M_MAX = 0;
433-
std::size_t N_MAX = 768; // taille optimale pour 780...
434308

435309
public:
436310
device(const std::string& name, int deviceId, const std::string& desc = "...") : m_name(name), m_desc(desc), m_id(deviceId) {
@@ -516,7 +390,8 @@ namespace ggml::backend::igpu {
516390
case GGML_OP_MUL_MAT:
517391
{
518392
const struct ggml_tensor * A = op.src[0]; // les poids
519-
const struct ggml_tensor * B = op.src[1]; // le
393+
const struct ggml_tensor * B = op.src[1]; // l'entrée
394+
const struct ggml_tensor * C = &op; // la sortie
520395

521396
if (!ggml_is_contiguous(A)) return false;
522397
if (!ggml_is_contiguous(B)) return false;
@@ -530,34 +405,7 @@ namespace ggml::backend::igpu {
530405
//return true;
531406
return false;
532407
case GGML_TYPE_BF16:
533-
// qq limites...
534-
if (A->ne[0]*A->ne[1] >= 0x80000000) {
535-
IGPU_TRACE( op.name << "(" << A->name << ") K*M trop grand: " << A->ne[0] << ", " << A->ne[1]);
536-
return false;
537-
}
538-
if ((A->ne[0] % BLOC_K1) != 0) {
539-
IGPU_TRACE( op.name << ": K non supporte : " << BLOC_K1 << "/" << A->ne[1]);
540-
return false;
541-
}
542-
if ((A->ne[1] % (4*2*16)) != 0) {
543-
IGPU_TRACE( op.name << ": M non supporte : " << 4*2*16 << "/" << A->ne[2]);
544-
return false;
545-
}
546-
// TODO: memoriser les tailles max de M et K pour allocation des caches
547-
if (K_MAX<A->ne[0]) {
548-
K_MAX=A->ne[0];
549-
IGPU_TRACE("K_MAX: " << K_MAX);
550-
}
551-
if (M_MAX<A->ne[1]) {
552-
M_MAX=A->ne[1];
553-
IGPU_TRACE("M_MAX: " << M_MAX);
554-
}
555-
// ordre:
556-
// - supports_op
557-
// - init_tensor
558-
// - set_tensor
559-
// GGML_LOG_INFO("ggml-igpu: MATMUL(%s): supported!\n", A->name);
560-
return true;
408+
return op_mul_mat::supported<bfloat16_t, float32_t, float32_t>(*A,*B,*C);
561409
default:
562410
return false;
563411
}

0 commit comments

Comments
 (0)