@@ -33,9 +33,6 @@ llama_context::llama_context(
         throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }
 
-    const char * LLAMA_HT = getenv("LLAMA_HT");
-    cparams.kv_unified = (LLAMA_HT && atoi(LLAMA_HT) > 0) ? false : true;
-
     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
     cparams.yarn_ext_factor  = params.yarn_ext_factor;
@@ -104,7 +101,8 @@ llama_context::llama_context(
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
-    cparams.op_offload = params.op_offload;
+    cparams.op_offload   = params.op_offload;
+    cparams.attn_streams = params.attn_streams;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -115,6 +113,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: attn_streams  = %s\n",   __func__, cparams.attn_streams ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
@@ -270,7 +269,7 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -314,6 +313,10 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
+            // TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
+            //
+            //   auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+            //
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
@@ -478,7 +481,7 @@ bool llama_context::kv_self_update(bool optimize) {
             throw std::runtime_error("failed to initialize memory context");
         }
 
-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
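The same worst-case sizing now appears in both the constructor and kv_self_update(). As a rough illustration only (a hypothetical standalone helper, not code from this patch), the sizing reduces to:

#include <algorithm>
#include <cstdint>

// Hypothetical sketch of the worst-case reservation sizing used in the hunks
// above: with attn_streams enabled, every sequence gets its own KV stream, so
// the reserved graph must cover all n_seq_max streams; with a unified KV cache
// a single stream is the worst case.
struct reserve_dims {
    uint32_t n_tokens;
    uint32_t n_seqs;
};

static reserve_dims worst_case_reserve_dims(uint32_t n_ctx, uint32_t n_ubatch,
                                            uint32_t n_seq_max, bool attn_streams) {
    reserve_dims dims;
    dims.n_seqs   = attn_streams ? n_seq_max : 1;
    dims.n_tokens = std::min(n_ctx, n_ubatch);
    return dims;
}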
@@ -2192,6 +2195,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf                     =*/ true,
         /*.op_offload                  =*/ true,
         /*.swa_full                    =*/ true,
+        /*.attn_streams                =*/ false,
     };
 
     return result;
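For context, a minimal caller-side sketch of opting in to the new flag, assuming the public llama.h entry points llama_model_load_from_file()/llama_init_from_model() and the attn_streams field added above; the model path and n_seq_max value are placeholders:

#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (!model) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max    = 4;    // placeholder: several parallel sequences
    cparams.attn_streams = true; // opt in to one KV stream per sequence (default is false)

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        llama_model_free(model);
        return 1;
    }

    // ... decode work ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}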