@@ -281,43 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
281281}
282282
283283void llm_graph_input_attn_kv_unified::set_input (const llama_ubatch * ubatch) {
284- if (self_k_idxs) {
285- mctx->set_input_k_idxs (self_k_idxs, ubatch);
286- }
287-
288- if (self_v_idxs) {
289- mctx->set_input_v_idxs (self_v_idxs, ubatch);
290- }
284+ mctx->set_input_k_idxs (self_k_idxs, ubatch);
285+ mctx->set_input_v_idxs (self_v_idxs, ubatch);
291286
292- if (self_kq_mask) {
293- mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
294- }
287+ mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
295288}
296289
297290void llm_graph_input_attn_kv_unified_iswa::set_input (const llama_ubatch * ubatch) {
298- if (self_k_idxs) {
299- mctx->get_base ()->set_input_k_idxs (self_k_idxs, ubatch);
300- }
301-
302- if (self_v_idxs) {
303- mctx->get_base ()->set_input_v_idxs (self_v_idxs, ubatch);
304- }
305-
306- if (self_k_idxs_swa) {
307- mctx->get_swa ()->set_input_k_idxs (self_k_idxs_swa, ubatch);
308- }
291+ mctx->get_base ()->set_input_k_idxs (self_k_idxs, ubatch);
292+ mctx->get_base ()->set_input_v_idxs (self_v_idxs, ubatch);
309293
310- if (self_v_idxs_swa) {
311- mctx->get_swa ()->set_input_v_idxs (self_v_idxs_swa, ubatch);
312- }
294+ mctx->get_base ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
313295
314- if (self_kq_mask) {
315- mctx->get_base ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
316- }
296+ mctx->get_swa ()->set_input_k_idxs (self_k_idxs_swa, ubatch);
297+ mctx->get_swa ()->set_input_v_idxs (self_v_idxs_swa, ubatch);
317298
318- if (self_kq_mask_swa) {
319- mctx->get_swa ()->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
320- }
299+ mctx->get_swa ()->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
321300}
322301
323302void llm_graph_input_attn_cross::set_input (const llama_ubatch * ubatch) {
@@ -357,17 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
357336}
358337
359338void llm_graph_input_mem_hybrid::set_input (const llama_ubatch * ubatch) {
360- if (self_k_idxs) {
361- mctx->get_attn ()->set_input_k_idxs (self_k_idxs, ubatch);
362- }
363-
364- if (self_v_idxs) {
365- mctx->get_attn ()->set_input_v_idxs (self_v_idxs, ubatch);
366- }
339+ mctx->get_attn ()->set_input_k_idxs (self_k_idxs, ubatch);
340+ mctx->get_attn ()->set_input_v_idxs (self_v_idxs, ubatch);
367341
368- if (self_kq_mask) {
369- mctx->get_attn ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
370- }
342+ mctx->get_attn ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
371343
372344 const int64_t n_rs = mctx->get_recr ()->get_n_rs ();
373345
0 commit comments