
Commit 9f669e7

feat: enable attention dispatch for hunyuan video (#12591)
1 parent 8ac17cd commit 9f669e7
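
This commit reroutes HunyuanVideoAttnProcessor2_0 from a hard-coded F.scaled_dot_product_attention call to diffusers' dispatch_attention_fn, which selects an attention backend at call time, and keeps query/key/value in [batch, seq_len, heads, head_dim] layout rather than transposing to [batch, heads, seq_len, head_dim]. A minimal usage sketch follows; the set_attention_backend helper and the checkpoint id are assumptions about the surrounding diffusers API, not part of this commit:

    # Hypothetical usage: pick an attention backend for the HunyuanVideo transformer.
    # set_attention_backend and the checkpoint id below are assumed, not from this commit.
    import torch
    from diffusers import HunyuanVideoTransformer3DModel

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        "hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    transformer.set_attention_backend("flash")  # e.g. "native", "flash", "sage"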

File tree

1 file changed: +39 -20 lines changed

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 39 additions & 20 deletions
@@ -24,6 +24,7 @@
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention, AttentionProcessor
 from ..cache_utils import CacheMixin
 from ..embeddings import (
@@ -42,6 +43,9 @@
 
 
 class HunyuanVideoAttnProcessor2_0:
+    _attention_backend = None
+    _parallel_config = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -64,9 +68,9 @@ def __call__(
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
 
         # 2. QK normalization
         if attn.norm_q is not None:
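
The dropped .transpose(1, 2) calls above keep attention inputs in [batch, seq_len, heads, head_dim] instead of [batch, heads, seq_len, head_dim]; that is why the rotary-embedding slicing and torch.cat calls in the next hunk move from dim=2 to dim=1 and pass sequence_dim=1. A standalone shape check of the two layouts (illustrative, not part of the diff):

    import torch

    batch, seq_len, heads, head_dim = 2, 16, 8, 64
    x = torch.randn(batch, seq_len, heads * head_dim)

    old = x.unflatten(2, (heads, -1)).transpose(1, 2)  # before: [batch, heads, seq_len, head_dim]
    new = x.unflatten(2, (heads, -1))                  # after:  [batch, seq_len, heads, head_dim]
    assert old.shape == (batch, heads, seq_len, head_dim)
    assert new.shape == (batch, seq_len, heads, head_dim)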
@@ -81,46 +85,61 @@ def __call__(
         if attn.add_q_proj is None and encoder_hidden_states is not None:
             query = torch.cat(
                 [
-                    apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
-                    query[:, :, -encoder_hidden_states.shape[1] :],
+                    apply_rotary_emb(
+                        query[:, : -encoder_hidden_states.shape[1]],
+                        image_rotary_emb,
+                        sequence_dim=1,
+                    ),
+                    query[:, -encoder_hidden_states.shape[1] :],
                 ],
-                dim=2,
+                dim=1,
             )
             key = torch.cat(
                 [
-                    apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
-                    key[:, :, -encoder_hidden_states.shape[1] :],
+                    apply_rotary_emb(
+                        key[:, : -encoder_hidden_states.shape[1]],
+                        image_rotary_emb,
+                        sequence_dim=1,
+                    ),
+                    key[:, -encoder_hidden_states.shape[1] :],
                 ],
-                dim=2,
+                dim=1,
             )
         else:
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
 
         # 4. Encoder condition QKV projection and normalization
         if attn.add_q_proj is not None and encoder_hidden_states is not None:
             encoder_query = attn.add_q_proj(encoder_hidden_states)
             encoder_key = attn.add_k_proj(encoder_hidden_states)
             encoder_value = attn.add_v_proj(encoder_hidden_states)
 
-            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
 
             if attn.norm_added_q is not None:
                 encoder_query = attn.norm_added_q(encoder_query)
             if attn.norm_added_k is not None:
                 encoder_key = attn.norm_added_k(encoder_key)
 
-            query = torch.cat([query, encoder_query], dim=2)
-            key = torch.cat([key, encoder_key], dim=2)
-            value = torch.cat([value, encoder_value], dim=2)
+            query = torch.cat([query, encoder_query], dim=1)
+            key = torch.cat([key, encoder_key], dim=1)
+            value = torch.cat([value, encoder_value], dim=1)
 
         # 5. Attention
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 
         # 6. Output projection
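
With inputs already in [batch, seq_len, heads, head_dim], the output only needs flatten(2, 3) rather than a transpose followed by a flatten. dispatch_attention_fn receives the backend and parallel_config stored on the processor class, so the kernel can be swapped without touching this code. A rough sketch of the dispatch idea; sketch_dispatch is a stand-in written for illustration, not diffusers' actual implementation:

    import torch
    import torch.nn.functional as F

    def sketch_dispatch(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, backend=None):
        # Inputs are [batch, seq_len, heads, head_dim]; the default path transposes
        # into the layout F.scaled_dot_product_attention expects, then back.
        if backend is None or backend == "native":
            out = F.scaled_dot_product_attention(
                query.transpose(1, 2),
                key.transpose(1, 2),
                value.transpose(1, 2),
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
            )
            return out.transpose(1, 2)
        # Named backends ("flash", "sage", ...) would route to their kernels here.
        raise NotImplementedError(f"backend {backend!r} not wired in this sketch")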
