@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
166166
167167 // note: tracking the other way around is not necessary for now
168168 // seq_cpl[s0][s1] = true;
169+
170+ has_cpl = true ;
169171 }
170172 }
171173 }
@@ -403,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
403405 return n_outputs;
404406}
405407
408+ uint32_t llama_batch_allocr::get_n_used () const {
409+ return n_used;
410+ }
411+
406412std::vector<int32_t > & llama_batch_allocr::get_out_ids () {
407413 return out_ids;
408414}
@@ -418,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
418424void llama_batch_allocr::split_reset () {
419425 out_ids.clear ();
420426
427+ n_used = 0 ;
428+
421429 used.clear ();
422430 used.resize (get_n_tokens (), false );
423431
@@ -442,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
442450 idxs.push_back (cur_idx);
443451
444452 used[cur_idx] = true ;
453+ ++n_used;
445454
446455 ++cur_idx;
447456
@@ -458,6 +467,12 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
458467}
459468
460469llama_ubatch llama_batch_allocr::split_equal (uint32_t n_ubatch, bool sequential) {
470+ if (sequential && has_cpl) {
471+ LLAMA_LOG_ERROR (" %s: sequential split is not supported when there are coupled sequences in the input batch\n " , __func__);
472+
473+ return {};
474+ }
475+
461476 std::vector<seq_set_t > cur_seq_set;
462477
463478 llama_seq_id last_seq_id = -1 ;
@@ -536,6 +551,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential)
536551 idxs_per_seq[s].push_back (idx);
537552
538553 used[idx] = true ;
554+ ++n_used;
539555
540556 ++cur_idx[s];
541557 }
@@ -577,6 +593,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
577593 idxs.push_back (cur_idx);
578594
579595 used[cur_idx] = true ;
596+ ++n_used;
580597
581598 if (idxs.size () >= n_ubatch) {
582599 break ;
0 commit comments