@@ -673,7 +673,9 @@ def _run_memory_efficient_xformers_forward(
 
                 # Cross-attention mask is non-causal
                 attn_bias = BlockDiagonalMask.from_seqlens(
-                    attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+                    attn_metadata.seq_lens,
+                    attn_metadata.encoder_seq_lens,
+                    device=query.device)
 
                 # Encoder branch of encoder-decoder model uses
                 # attn_metadata.encoder_seq_lens
@@ -683,7 +685,7 @@ def _run_memory_efficient_xformers_forward(
 
                 # Encoder self-attention mask is non-causal
                 attn_bias = BlockDiagonalMask.from_seqlens(
-                    attn_metadata.encoder_seq_lens)
+                    attn_metadata.encoder_seq_lens, device=query.device)
 
                 # Self-attention block of encoder-only model just
                 # uses the seq_lens directly.
@@ -692,7 +694,7 @@ def _run_memory_efficient_xformers_forward(
 
                 # Encoder self-attention mask is non-causal
                 attn_bias = BlockDiagonalMask.from_seqlens(
-                    attn_metadata.seq_lens)
+                    attn_metadata.seq_lens, device=query.device)
 
                 # Self-attention block of decoder branch just
                 # uses the seq_lens directly
@@ -701,7 +703,7 @@ def _run_memory_efficient_xformers_forward(
 
                 # Decoder self-attention mask is causal
                 attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                    attn_metadata.seq_lens)
+                    attn_metadata.seq_lens, device=query.device)
             else:
                 raise ValueError("Unknown AttentionType: %s", attn_type)
 
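
For context, a minimal standalone sketch (not part of this commit) of how the same kinds of block-diagonal attention biases can be built directly on the attention inputs' device. It assumes an xformers release whose from_seqlens accepts the device keyword used in the diff above; the seq-len values are made up for illustration.

# Standalone sketch, not vLLM code. Assumes from_seqlens supports `device`.
import torch
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                          BlockDiagonalMask)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq_lens = [3, 5]            # hypothetical per-request decoder sequence lengths
encoder_seq_lens = [4, 6]    # hypothetical per-request encoder sequence lengths

# Non-causal cross-attention bias: decoder queries grouped by seq_lens attend
# to encoder keys/values grouped by encoder_seq_lens.
cross_attn_bias = BlockDiagonalMask.from_seqlens(seq_lens,
                                                 encoder_seq_lens,
                                                 device=device)

# Causal decoder self-attention bias over the same sequences.
self_attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens, device=device)

Passing query.device, as the diff does, presumably keeps the bias's internal tensors on the same device as the query/key/value tensors instead of whatever default the mask constructor would otherwise pick.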