@@ -15523,39 +15523,6 @@ static void llama_graph_compute(
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }
 
-// Optionally swaps the batch and single-tok threadpools.
-// Returns the number of threads, and if a valid threadpool exists, returns it too.
-static std::pair<int32_t, ggml_compute_threadpool_t> llama_swap_threadpools(
-        llama_context & lctx,
-        int32_t n_tokens) {
-
-    const auto & cparams = lctx.cparams;
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-
-    ggml_compute_threadpool_t threadpool = nullptr; // nullptr -> disposable threadpool
-
-    // A batch threadpool without a non-batch threadpool isn't supported.
-    GGML_ASSERT(!lctx.threadpool_batch || lctx.threadpool);
-
-    if (lctx.threadpool_batch && lctx.threadpool) {
-        // Switch between the 2 threadpools as needed
-        if (n_tokens > 1) {
-            ggml_pause_threadpool(lctx.threadpool);
-            threadpool = lctx.threadpool_batch;
-            n_threads = cparams.n_threads_batch;
-        } else {
-            ggml_pause_threadpool(lctx.threadpool_batch);
-            threadpool = lctx.threadpool;
-            n_threads = cparams.n_threads;
-        }
-    } else if (lctx.threadpool) {
-        threadpool = lctx.threadpool;
-        n_threads = cparams.n_threads;
-    }
-    return std::make_pair(n_threads, threadpool);
-}
-
-
 // decode a batch of tokens by evaluating the transformer
 //
 // - lctx: llama context
@@ -15662,11 +15629,8 @@ static int llama_decode_internal(
         lctx.n_outputs = n_outputs_new;
     }
 
-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-        llama_swap_threadpools(lctx, n_tokens);
-
-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
     GGML_ASSERT(n_threads > 0);
 
@@ -15906,11 +15870,9 @@ static int llama_encode_internal(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;
 
-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-        llama_swap_threadpools(lctx, n_tokens);
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
     GGML_ASSERT(n_threads > 0);
 
     ggml_backend_sched_reset(lctx.sched);
@@ -17500,36 +17462,15 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
 
 void llama_attach_threadpool(
         struct llama_context * ctx,
-        ggml_compute_threadpool_t threadpool) {
-    ctx->threadpool = threadpool;
-}
-
-void llama_attach_batch_threadpool(
-        struct llama_context * ctx,
+        ggml_compute_threadpool_t threadpool,
         ggml_compute_threadpool_t threadpool_batch) {
-    ctx->threadpool_batch = threadpool_batch;
+    ctx->threadpool = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }
 
 void llama_detach_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_batch_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_threadpools(struct llama_context * ctx) {
-    llama_detach_threadpool(ctx);
-    llama_detach_batch_threadpool(ctx);
-}
-
-void llama_pause_threadpools(struct llama_context * ctx) {
-    if (ctx->threadpool) {
-        ggml_pause_threadpool(ctx->threadpool);
-    }
-    if (ctx->threadpool_batch) {
-        ggml_pause_threadpool(ctx->threadpool_batch);
-    }
+    ctx->threadpool = nullptr;
+    ctx->threadpool_batch = nullptr;
 }
 
 void llama_backend_free(void) {
0 commit comments