@@ -564,7 +564,6 @@ def benchmark(func, *args, **kwargs):
 #
 # * Cross Attention
 # * Fully masked rows no longer cause NaNs
-# * Modifying attention score: ALiBi with FlexAttention and NJT
 # * Packed Projection

 ###############################################################################
@@ -668,66 +667,6 @@ def benchmark(func, *args, **kwargs):
 # appropriately makes it possible to properly express empty sequences.


-###############################################################################
-# FlexAttention + NJT
-# ---------------------------------------------------------------------
-# NJT also composes with ``flex_attention``, a generalization of scaled dot
-# product attention that allows arbitrary modifications to the attention
-# scores via a ``score_mod`` callable. The example below takes the ``alibi_mod``
-# ``score_mod`` that implements `ALiBi <https://arxiv.org/abs/2108.12409>`_ from
-# `attention gym <https://github.com/meta-pytorch/attention-gym>`_ and uses it
-# with nested input tensors.
-
-from torch.nn.attention.flex_attention import flex_attention
-
-
-def generate_alibi_bias(H: int):
-    """Returns an alibi bias score_mod given the number of heads H
-    Args:
-        H: number of heads
-    Returns:
-        alibi_bias: alibi bias score_mod
-    """
-
-    def alibi_mod(score, b, h, q_idx, kv_idx):
-        scale = torch.exp2(-((h + 1) * 8.0 / H))
-        bias = (q_idx - kv_idx) * scale
-        return score + bias
-
-    return alibi_mod
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-n_heads, D = 8, E_q // 8
-alibi_score_mod = generate_alibi_bias(n_heads)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
-
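###############################################################################
# As a quick sanity check, the same ``score_mod`` can also be applied to
# regular dense tensors, since ``flex_attention`` does not require nested
# inputs. The snippet below is an illustrative sketch only; it assumes the
# ``device``, ``n_heads``, and ``D`` values from the surrounding code, and the
# sequence length of 16 is arbitrary.

q_dense = torch.randn(2, n_heads, 16, D, device=device)
k_dense = torch.randn(2, n_heads, 16, D, device=device)
v_dense = torch.randn(2, n_heads, 16, D, device=device)
# Same ALiBi bias, applied to a small dense batch for easy inspection.
out_dense_alibi = flex_attention(q_dense, k_dense, v_dense, score_mod=alibi_score_mod)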
-###############################################################################
-# In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
-# with NJTs via the ``create_nested_block_mask`` function. This is useful for
-# taking advantage of the sparsity of the mask to speed up the attention computation.
-# In particular, the function creates a sparse block mask for a "stacked sequence" of all
-# the variable length sequences in the NJT combined into one, while properly masking out
-# inter-sequence attention. In the following example, we show how to create a
-# causal block mask using this utility.
-
-from torch.nn.attention.flex_attention import create_nested_block_mask
-
-
-def causal_mask(b, h, q_idx, kv_idx):
-    return q_idx >= kv_idx
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex = flex_attention(query, key, value, block_mask=block_mask)
-
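###############################################################################
# For comparison, the dense-tensor analogue of the above would use
# ``create_block_mask`` with explicit query and key lengths instead of deriving
# them from an NJT. The snippet below is an illustrative sketch only; it assumes
# ``device``, ``n_heads``, and ``D`` from the surrounding code, and the sequence
# length of 16 is arbitrary.

from torch.nn.attention.flex_attention import create_block_mask

S = 16
dense_block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device)
q_dense = torch.randn(2, n_heads, S, D, device=device)
k_dense = torch.randn(2, n_heads, S, D, device=device)
v_dense = torch.randn(2, n_heads, S, D, device=device)
# The same causal mask_mod, now materialized as a dense block mask.
out_dense_causal = flex_attention(q_dense, k_dense, v_dense, block_mask=dense_block_mask)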
 ###############################################################################
 # Packed Projection
 # -----------------