3 files changed: +10 −5 lines changed

@@ -1738,7 +1738,8 @@ struct server_context {
         }

         // process in chunks of params.n_batch
-        int32_t n_batch = params.n_batch;
+        int32_t n_batch  = llama_n_batch(ctx);
+        int32_t n_ubatch = llama_n_ubatch(ctx);

         // next, batch any pending prompts without exceeding n_batch
         if (params.cont_batching || batch.n_tokens == 0) {
@@ -1811,7 +1812,7 @@ struct server_context {

                 if (slot.embedding) {
                     // this prompt is too large to process - discard it
-                    if (slot.n_prompt_tokens > n_batch) {
+                    if (slot.n_prompt_tokens > n_ubatch) {
                         slot.state = SLOT_STATE_PROCESSING;
                         slot.command = SLOT_COMMAND_NONE;
                         slot.release();
@@ -8774,6 +8774,8 @@ static int llama_decode_internal(

     GGML_ASSERT(n_tokens_all <= cparams.n_batch);

+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
     }
@@ -9011,9 +9013,6 @@ static int llama_decode_internal(
             case LLAMA_POOLING_TYPE_CLS:
             case LLAMA_POOLING_TYPE_MEAN:
                 {
-                    // FIXME: this may not work if the sequences are split into different batches
-                    GGML_ASSERT(n_tokens_all == n_tokens);
-
                     GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);

                     // extract sequence embeddings
@@ -13076,6 +13075,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
     return ctx->cparams.n_batch;
 }

+uint32_t llama_n_ubatch(const struct llama_context * ctx) {
+    return ctx->cparams.n_ubatch;
+}
+
 uint32_t llama_n_seq_max(const struct llama_context * ctx) {
     return ctx->kv_self.size;
 }
@@ -378,6 +378,7 @@ extern "C" {

     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);

     LLAMA_API enum llama_vocab_type llama_vocab_type    (const struct llama_model * model);
0 commit comments