@@ -24,6 +24,7 @@
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention, AttentionProcessor
 from ..cache_utils import CacheMixin
 from ..embeddings import (
@@ -42,6 +43,9 @@
 
 
 class HunyuanVideoAttnProcessor2_0:
+    _attention_backend = None
+    _parallel_config = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -64,9 +68,9 @@ def __call__(
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
 
         # 2. QK normalization
         if attn.norm_q is not None:
@@ -81,46 +85,61 @@ def __call__(
         if attn.add_q_proj is None and encoder_hidden_states is not None:
             query = torch.cat(
                 [
-                    apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
-                    query[:, :, -encoder_hidden_states.shape[1] :],
+                    apply_rotary_emb(
+                        query[:, : -encoder_hidden_states.shape[1]],
+                        image_rotary_emb,
+                        sequence_dim=1,
+                    ),
+                    query[:, -encoder_hidden_states.shape[1] :],
                 ],
-                dim=2,
+                dim=1,
             )
             key = torch.cat(
                 [
-                    apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
-                    key[:, :, -encoder_hidden_states.shape[1] :],
+                    apply_rotary_emb(
+                        key[:, : -encoder_hidden_states.shape[1]],
+                        image_rotary_emb,
+                        sequence_dim=1,
+                    ),
+                    key[:, -encoder_hidden_states.shape[1] :],
                 ],
-                dim=2,
+                dim=1,
             )
         else:
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
 
         # 4. Encoder condition QKV projection and normalization
         if attn.add_q_proj is not None and encoder_hidden_states is not None:
             encoder_query = attn.add_q_proj(encoder_hidden_states)
             encoder_key = attn.add_k_proj(encoder_hidden_states)
             encoder_value = attn.add_v_proj(encoder_hidden_states)
 
-            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
 
             if attn.norm_added_q is not None:
                 encoder_query = attn.norm_added_q(encoder_query)
             if attn.norm_added_k is not None:
                 encoder_key = attn.norm_added_k(encoder_key)
 
-            query = torch.cat([query, encoder_query], dim=2)
-            key = torch.cat([key, encoder_key], dim=2)
-            value = torch.cat([value, encoder_value], dim=2)
+            query = torch.cat([query, encoder_query], dim=1)
+            key = torch.cat([key, encoder_key], dim=1)
+            value = torch.cat([value, encoder_value], dim=1)
 
         # 5. Attention
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 
         # 6. Output projection
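In short, the updated processor keeps query/key/value in the sequence-first `(batch, seq_len, heads, head_dim)` layout produced by `unflatten`, applies the rotary embedding along `sequence_dim=1`, concatenates the text-conditioning tokens along `dim=1`, and routes attention through `dispatch_attention_fn` with the class-level `_attention_backend` / `_parallel_config` (both `None` by default, leaving backend selection and context parallelism to the dispatcher). Only a final `flatten(2, 3)` is needed afterwards, since the transpose pair that used to bracket `F.scaled_dot_product_attention` now lives inside the attention backend. Below is a minimal sketch of that layout convention using plain PyTorch; `sdpa_seq_first` is a hypothetical stand-in for the real dispatcher, not diffusers API.

```python
import torch
import torch.nn.functional as F


def sdpa_seq_first(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # Stand-in for a sequence-first attention backend: accept (batch, seq, heads, dim)
    # tensors, transpose into the head-first layout F.scaled_dot_product_attention
    # expects, and transpose the result back to sequence-first.
    out = F.scaled_dot_product_attention(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )
    return out.transpose(1, 2)


batch, seq_len, heads, head_dim = 2, 16, 4, 8
hidden_states = torch.randn(batch, seq_len, heads * head_dim)

# As in the updated processor: split heads but stay sequence-first, no transpose.
query = key = value = hidden_states.unflatten(2, (heads, -1))  # (batch, seq, heads, dim)

out = sdpa_seq_first(query, key, value)  # sequence-first in, sequence-first out
out = out.flatten(2, 3)                  # (batch, seq, heads * head_dim), as in the diff
```

The helper only illustrates the layout contract; the actual backend selection and context-parallel handling happen inside `dispatch_attention_fn` from `..attention_dispatch`.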