1515// #define IGPU_TRACE(...) std::cout << "#> ggml-igpu: " << __VA_ARGS__ << std::endl
1616#define IGPU_TRACE (...)
1717
18+ #define BLOC_V1
19+ #ifdef BLOC_V1
1820#include " mulmat-bf16bloc.h"
21+ #endif
1922
2023/*
2124#> version bloc-bf16 V0.
3639
3740namespace ggml ::backend::igpu {
3841
39- // taille de repacking:
40- static constexpr std::size_t BLOC_M0 = 16 ;
41- // static constexpr std::size_t BLOC_N0 = 16;
42- static constexpr std::size_t BLOC_K0 = 16 ;
43- static constexpr std::size_t BLOC_K1 = 512 ;
44-
4542 static bool IS_WEIGHT = true ;
4643 static bool IS_OTHER = true ;
4744 enum class BUFFER_TYPE {
@@ -126,39 +123,8 @@ namespace ggml::backend::igpu {
126123 const auto la = tensor.nb [1 ]/tensor.nb [0 ];
127124 bfloat16_t * ref = (bfloat16_t *)data;
128125 bfloat16_t * bloc = (bfloat16_t *)tensor.data ;
129- // TODO: @ optimiser...
130- /*
131- //# pragma omp parallel for num_threads(4)
132- # pragma omp parallel for
133- for (std::size_t i=0; i<M; i++) {
134- for (std::size_t k=0; k<K; k++) {
135- bloc[posBloc2D<BLOC_K1,BLOC_M0,1,TYPE_BLOC::PERFECT>(K, M, k, i)] = ref[pos2D(la, M, k, i)];
136- }
137- }
138- */
139- // Ca sera important quand on fera le codage en fp8...
140- # pragma omp parallel for num_threads(2) collapse(2) // private(tmp)
141- for (std::size_t k2=0 ; k2<K; k2+=BLOC_K1) {
142- for (std::size_t i1=0 ; i1<M; i1+=BLOC_M0) {
143- for (std::size_t k1=0 ; k1<BLOC_K1; k1+=16 ) {
144- bfloat16_t tmp[16 ][16 ];
145- for (std::size_t i0=0 ; i0<BLOC_M0; i0++) {
146- # pragma omp simd
147- for (std::size_t k0=0 ; k0<16 ; k0++) {
148- tmp[i0][k0] = ref[pos2D (la, M, k2+k1+k0, i1+i0)];
149- }
150- }
151- for (std::size_t k0=0 ; k0<16 ; k0++) {
152- # pragma omp simd
153- for (std::size_t i0=0 ; i0<BLOC_M0; i0++) {
154- bloc[posBloc2D<BLOC_K1,BLOC_M0,1 ,TYPE_BLOC::PERFECT>(K, M, k2+k1+k0, i1+i0)] = tmp[i0][k0];
155- }
156- }
157- // bloc[posBloc2D<BLOC_K1,BLOC_M0,1,TYPE_BLOC::PERFECT>(K, M, k, i)] = ref[pos2D(la, M, k, i)];
158- }
159- }
160- }
161126
127+ op_mul_mat::repack (ref, la, bloc, M, K);
162128 }
163129 void get_tensor (const ggml_tensor & tensor, void * data, std::size_t offset, std::size_t size) override {
164130 const auto K = tensor.ne [0 ];
@@ -252,10 +218,7 @@ namespace ggml::backend::igpu {
252218 ggml::cpp::backend::backend (dev), m_deviceId(deviceId)
253219 {
254220 IGPU_TRACE (" backend[" << get_name () << " ]: create <" << params << " >" );
255- // if (B_cache.ensure_size(K1 * block_size<N0*N1>(N) * M4)) {
256- if (B_cache.ensure_size (BLOC_K1 * 1024 )) { // @ optimiser la taille de N < 768 * 16
257- IGPU_TRACE (" B_cache[" << BLOC_K1 << " , " << 1024 <<" ]" );
258- }
221+ op_mul_mat::init_caches ();
259222 }
260223
261224 virtual ~backend () {
@@ -312,95 +275,9 @@ namespace ggml::backend::igpu {
312275 GGML_ASSERT (B->nb [0 ] == sizeof (float32_t ));
313276 GGML_ASSERT (C->nb [0 ] == sizeof (float32_t ));
314277 GGML_ASSERT (K % 16 == 0 );
315- GGML_ASSERT (K % BLOC_K1 == 0 );
316278 GGML_ASSERT (M % (4 *2 *16 ) == 0 ); // pas une contraite forte mais plus simple pour l'instant.
317- // TODO: voir comment decouper ca
318- // > ggml::backend::igpu::sgemm_wmma<M1,N1,M2,M4,K1>(a1,b1,c1, M,N,K, K,M);
319- // if (N==1) {} else
320- if (N == 0 ) {
321- std::cout << " BUG? " << N << " / " << A->name << " | "
322- << A->ne [0 ] << " ," << A->ne [1 ] << " "
323- << B->ne [0 ] << " ," << B->ne [1 ] << " "
324- << C->ne [0 ] << " ," << C->ne [1 ] << " "
325- << std::endl;
326- } else
327- if (N<=16 ) {
328- if (M%(4 *2 *16 *16 )==0 ) { // M=2048
329- sgemm_wmma<4 ,1 ,2 ,16 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
330- } else if (M%(4 *2 *16 *8 )==0 ) { // M=1024
331- sgemm_wmma<4 ,1 ,2 ,8 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
332- } else if (M%(4 *2 *16 *4 )==0 ) { // M=512
333- sgemm_wmma<4 ,1 ,2 ,4 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
334- } else if (M%(4 *2 *16 *2 )==0 ) { // M=256
335- sgemm_wmma<4 ,1 ,2 ,2 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
336- } else if (M%(4 *2 *16 *1 )==0 ) { // M=128
337- sgemm_wmma<4 ,1 ,2 ,1 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
338- } else { // est-ce que l'on fait les cas 3,5,6,7,... ?
339- // on va s'arreter la pour l'instant:
340- }
341- } else if (N<=32 ) {
342- if (M%(4 *2 *16 *16 )==0 ) { // M=2048
343- sgemm_wmma<4 ,2 ,2 ,16 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
344- } else if (M%(4 *2 *16 *8 )==0 ) { // M=1024
345- sgemm_wmma<4 ,2 ,2 ,8 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
346- } else if (M%(4 *2 *16 *4 )==0 ) { // M=512
347- sgemm_wmma<4 ,2 ,2 ,4 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
348- } else if (M%(4 *2 *16 *2 )==0 ) { // M=256
349- sgemm_wmma<4 ,2 ,2 ,2 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
350- } else if (M%(4 *2 *16 *1 )==0 ) { // M=128
351- sgemm_wmma<4 ,2 ,2 ,1 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
352- } else { // est-ce que l'on fait les cas 3,5,6,7,... ?
353- // on va s'arreter la pour l'instant:
354- }
355- } else if (N<=48 ) { // 3 blocs pour N => 4CU / M
356- if (M%(4 *2 *16 *8 )==0 ) { // M=1024
357- sgemm_wmma<4 ,1 ,2 ,8 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
358- } else if (M%(4 *2 *16 *4 )==0 ) { // M=512
359- sgemm_wmma<4 ,1 ,2 ,4 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
360- } else if (M%(4 *2 *16 *2 )==0 ) { // M=256
361- sgemm_wmma<4 ,1 ,2 ,2 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
362- } else if (M%(4 *2 *16 *1 )==0 ) { // M=128
363- sgemm_wmma<4 ,1 ,2 ,1 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
364- } else { // est-ce que l'on fait les cas 3,5,6,7,... ?
365- // on va s'arreter la pour l'instant:
366- }
367- } else if (N<=64 ) { // N1=2 => 2CU/N => 6 restant
368- if (M%(4 *2 *16 *8 )==0 ) { // M=1024
369- sgemm_wmma<4 ,2 ,2 ,8 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
370- } else if (M%(4 *2 *16 *4 )==0 ) { // M=512
371- sgemm_wmma<4 ,2 ,2 ,4 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
372- } else if (M%(4 *2 *16 *2 )==0 ) { // M=256
373- sgemm_wmma<4 ,2 ,2 ,2 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
374- } else if (M%(4 *2 *16 *1 )==0 ) { // M=128
375- sgemm_wmma<4 ,2 ,2 ,1 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
376- } else { // est-ce que l'on fait les cas 3,5,6,7,... ?
377- // on va s'arreter la pour l'instant:
378- }
379- } else if (N<=192 ) {
380- if (M%(4 *2 *16 *4 )==0 ) { // M=512
381- sgemm_wmma<4 ,2 ,2 ,4 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
382- } else if (M%(4 *2 *16 *2 )==0 ) { // M=256
383- sgemm_wmma<4 ,2 ,2 ,2 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
384- } else if (M%(4 *2 *16 *1 )==0 ) { // M=128
385- sgemm_wmma<4 ,2 ,2 ,1 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
386- } else { // est-ce que l'on fait les cas 3,5,6,7,... ?
387- // on va s'arreter la pour l'instant:
388- }
389- } else if (N<=384 ) {
390- if (M%(4 *2 *16 *2 )==0 ) { // M=256
391- sgemm_wmma<4 ,2 ,2 ,2 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
392- } else if (M%(4 *2 *16 *1 )==0 ) { // M=128
393- sgemm_wmma<4 ,2 ,2 ,1 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
394- } else { // est-ce que l'on fait les cas 3,5,6,7,... ?
395- // on va s'arreter la pour l'instant:
396- }
397- } else {
398- if (M%(4 *2 *16 *1 )==0 ) { // M=128
399- sgemm_wmma<4 ,2 ,2 ,1 ,BLOC_K1>((bfloat16_t *)A->data , (float32_t *)B->data , (float32_t *)C->data , M,N,K, lb,lc);
400- } else { // est-ce que l'on fait les cas 3,5,6,7,... ?
401- // on va s'arreter la pour l'instant:
402- }
403- }
279+
280+ op_mul_mat::compute ((const bfloat16_t *)A->data , (const float32_t *)B->data , (float32_t *)C->data , M,N,K, la,lb,lc);
404281 // matmul_ref((bfloat16_t*)A->data, (float32_t*)B->data, (float32_t*)C->data, M,N,K, la,lb,lc);
405282 }
406283 }
@@ -428,9 +305,6 @@ namespace ggml::backend::igpu {
428305 buffer_type* m_extra_buffer_type;
429306 buffer_type* m_device_buffer_type;
430307 buffer_type* m_host_buffer_type;
431- std::size_t K_MAX = 0 ;
432- std::size_t M_MAX = 0 ;
433- std::size_t N_MAX = 768 ; // taille optimale pour 780...
434308
435309 public:
436310 device (const std::string& name, int deviceId, const std::string& desc = " ..." ) : m_name(name), m_desc(desc), m_id(deviceId) {
@@ -516,7 +390,8 @@ namespace ggml::backend::igpu {
516390 case GGML_OP_MUL_MAT:
517391 {
518392 const struct ggml_tensor * A = op.src [0 ]; // les poids
519- const struct ggml_tensor * B = op.src [1 ]; // le
393+ const struct ggml_tensor * B = op.src [1 ]; // l'entrée
394+ const struct ggml_tensor * C = &op; // la sortie
520395
521396 if (!ggml_is_contiguous (A)) return false ;
522397 if (!ggml_is_contiguous (B)) return false ;
@@ -530,34 +405,7 @@ namespace ggml::backend::igpu {
530405 // return true;
531406 return false ;
532407 case GGML_TYPE_BF16:
533- // qq limites...
534- if (A->ne [0 ]*A->ne [1 ] >= 0x80000000 ) {
535- IGPU_TRACE ( op.name << " (" << A->name << " ) K*M trop grand: " << A->ne [0 ] << " , " << A->ne [1 ]);
536- return false ;
537- }
538- if ((A->ne [0 ] % BLOC_K1) != 0 ) {
539- IGPU_TRACE ( op.name << " : K non supporte : " << BLOC_K1 << " /" << A->ne [1 ]);
540- return false ;
541- }
542- if ((A->ne [1 ] % (4 *2 *16 )) != 0 ) {
543- IGPU_TRACE ( op.name << " : M non supporte : " << 4 *2 *16 << " /" << A->ne [2 ]);
544- return false ;
545- }
546- // TODO: memoriser les tailles max de M et K pour allocation des caches
547- if (K_MAX<A->ne [0 ]) {
548- K_MAX=A->ne [0 ];
549- IGPU_TRACE (" K_MAX: " << K_MAX);
550- }
551- if (M_MAX<A->ne [1 ]) {
552- M_MAX=A->ne [1 ];
553- IGPU_TRACE (" M_MAX: " << M_MAX);
554- }
555- // ordre:
556- // - supports_op
557- // - init_tensor
558- // - set_tensor
559- // GGML_LOG_INFO("ggml-igpu: MATMUL(%s): supported!\n", A->name);
560- return true ;
408+ return op_mul_mat::supported<bfloat16_t , float32_t , float32_t >(*A,*B,*C);
561409 default :
562410 return false ;
563411 }
0 commit comments