@@ -109,13 +109,13 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
109109#define NUM_WARPS (BLOCK_SIZE / WARP)
110110
111111#ifdef MUL_MAT_ID
112- shared u16vec2 row_ids[4096 ];
112+ shared u16vec2 row_ids[BN ];
113113uint _ne1;
114114
115115#ifdef MUL_MAT_ID_USE_SUBGROUPS
116116shared uvec4 ballots_sh[NUM_WARPS];
117117
118- void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
118+ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic ) {
119119 _ne1 = 0;
120120 uint num_elements = p.nei1 * p.nei0;
121121 uint nei0shift = findLSB(p.nei0);
@@ -165,11 +165,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
165165 barrier();
166166
167167 uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
168- if (in_range && id == expert_idx) {
169- row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
168+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN ) {
169+ row_ids[_ne1 + idx - ic * BN ] = u16vec2(ii0, ii1);
170170 }
171171 _ne1 += total;
172172 iter &= 15;
173+ if (_ne1 >= (ic + 1) * BN) {
174+ break;
175+ }
173176 }
174177 barrier();
175178}
@@ -242,16 +245,18 @@ void main() {
242245#ifdef MUL_MAT_ID
243246#ifdef MUL_MAT_ID_USE_SUBGROUPS
244247 if (bitCount(p.nei0) == 1) {
245- load_row_ids(expert_idx, true);
248+ load_row_ids(expert_idx, true, ic );
246249 } else {
247- load_row_ids(expert_idx, false);
250+ load_row_ids(expert_idx, false, ic );
248251 }
249252#else
250253 _ne1 = 0;
251- for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
252- for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
254+ for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN ; ii1++) {
255+ for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN ; ii0++) {
253256 if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
254- row_ids[_ne1] = u16vec2(ii0, ii1);
257+ if (_ne1 >= ic * BN) {
258+ row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
259+ }
255260 _ne1++;
256261 }
257262 }
@@ -797,7 +802,7 @@ void main() {
797802 [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
798803#if LOAD_VEC_B == 8
799804#ifdef MUL_MAT_ID
800- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
805+ const u16vec2 row_idx = row_ids[loadc_b + l];
801806 const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
802807#else
803808 const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -813,7 +818,7 @@ void main() {
813818 buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
814819#elif LOAD_VEC_B == 4
815820#ifdef MUL_MAT_ID
816- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
821+ const u16vec2 row_idx = row_ids[loadc_b + l];
817822 const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
818823#else
819824 const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -832,7 +837,7 @@ void main() {
832837#else
833838 const uint row_i = ic * BN + loadc_b + l;
834839 if (row_i < _ne1 && block + loadr_b < end_k) {
835- const u16vec2 row_idx = row_ids[row_i ];
840+ const u16vec2 row_idx = row_ids[loadc_b + l ];
836841 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
837842 } else {
838843 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
@@ -903,7 +908,7 @@ void main() {
903908 const uint row_i = dc + cm_col * TN + col + store_c;
904909 if (row_i >= _ne1) break;
905910
906- const u16vec2 row_idx = row_ids[row_i];
911+ const u16vec2 row_idx = row_ids[row_i - ic * BN ];
907912
908913 if (dr + cm_row * TM + store_r < p.M) {
909914 data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
@@ -953,7 +958,7 @@ void main() {
953958 const uint row_i = dc_warp + cc;
954959 if (row_i >= _ne1) break;
955960
956- const u16vec2 row_idx = row_ids[row_i];
961+ const u16vec2 row_idx = row_ids[row_i - ic * BN ];
957962#endif // MUL_MAT_ID
958963 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
959964#ifdef MUL_MAT_ID
0 commit comments