Skip to content

Commit 275d082

Browse files
committed
[Bugfix] Initialize attention bias on the same device as Query/Key/Value
Signed-off-by: Junlin Zhou <[email protected]>
1 parent cbae7af commit 275d082

File tree

1 file changed: +6 additions, -4 deletions

vllm/attention/backends/xformers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,9 @@ def _run_memory_efficient_xformers_forward(

             # Cross-attention mask is non-causal
             attn_bias = BlockDiagonalMask.from_seqlens(
-                attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+                attn_metadata.seq_lens,
+                attn_metadata.encoder_seq_lens,
+                device=query.device)

         # Encoder branch of encoder-decoder model uses
         # attn_metadata.encoder_seq_lens
@@ -683,7 +685,7 @@ def _run_memory_efficient_xformers_forward(

             # Encoder self-attention mask is non-causal
             attn_bias = BlockDiagonalMask.from_seqlens(
-                attn_metadata.encoder_seq_lens)
+                attn_metadata.encoder_seq_lens, device=query.device)

         # Self-attention block of encoder-only model just
         # uses the seq_lens directly.
@@ -692,7 +694,7 @@ def _run_memory_efficient_xformers_forward(

             # Encoder self-attention mask is non-causal
             attn_bias = BlockDiagonalMask.from_seqlens(
-                attn_metadata.seq_lens)
+                attn_metadata.seq_lens, device=query.device)

         # Self-attention block of decoder branch just
         # uses the seq_lens directly
@@ -701,7 +703,7 @@ def _run_memory_efficient_xformers_forward(

             # Decoder self-attention mask is causal
             attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                attn_metadata.seq_lens)
+                attn_metadata.seq_lens, device=query.device)
         else:
             raise ValueError("Unknown AttentionType: %s", attn_type)

0 commit comments

Comments (0)