 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class WanAttnProcessor2_0:
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
+
+
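A minimal, self-contained sketch of why the fused path above matches the separate projections it replaces (hypothetical sizes, plain torch.nn.Linear): concatenating the Q/K/V weights and biases along the output dimension and chunking the fused output recovers each individual projection.

import torch
import torch.nn as nn

dim, inner_dim = 8, 8  # hypothetical sizes
to_q, to_k, to_v = (nn.Linear(dim, inner_dim, bias=True) for _ in range(3))

# Fused linear: weights and biases stacked along the output dimension
to_qkv = nn.Linear(dim, 3 * inner_dim, bias=True)
to_qkv.weight.data = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
to_qkv.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])

x = torch.randn(2, 16, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)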
+class WanAttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )
 
     def __call__(
         self,
-        attn: Attention,
+        attn: "WanAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
             # 512 is the context length of the text encoder, hardcoded for now
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
 
         if rotary_emb is not None:
 
@@ -77,8 +104,7 @@ def apply_rotary_emb(
                 freqs_cos: torch.Tensor,
                 freqs_sin: torch.Tensor,
             ):
-                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
-                x1, x2 = x[..., 0], x[..., 1]
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
                 cos = freqs_cos[..., 0::2]
                 sin = freqs_sin[..., 1::2]
                 out = torch.empty_like(hidden_states)
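The new unflatten(-1, (-1, 2)).unbind(-1) line splits the last dimension into the same interleaved even/odd halves as the removed view/index pair; a small equivalence sketch with hypothetical shapes:

import torch

hidden_states = torch.randn(1, 32, 4, 64)  # hypothetical (batch, seq_len, heads, head_dim)

# Removed formulation: reshape the last dim into pairs and index out the two halves
x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
x1_old, x2_old = x[..., 0], x[..., 1]

# New formulation: same pairing via unflatten + unbind
x1_new, x2_new = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)

assert torch.equal(x1_old, x1_new) and torch.equal(x2_old, x2_new)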
@@ -92,23 +118,34 @@ def apply_rotary_emb(
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
-
-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
 
-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
+
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
-            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
 
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
         if hidden_states_img is not None:
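The dropped transpose(1, 2) calls reflect a layout change: dispatch_attention_fn is fed tensors in (batch, seq_len, heads, head_dim) order, whereas F.scaled_dot_product_attention expects (batch, heads, seq_len, head_dim). A hedged sketch of the equivalence the new code relies on, with a hypothetical sdpa_bshd helper standing in for the dispatcher's default SDPA path:

import torch
import torch.nn.functional as F

def sdpa_bshd(q, k, v, attn_mask=None):
    # Hypothetical stand-in for dispatch_attention_fn's default path:
    # accepts and returns (batch, seq_len, heads, head_dim), hiding the transposes.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=attn_mask
    )
    return out.transpose(1, 2)

q, k, v = (torch.randn(1, 32, 4, 16) for _ in range(3))  # (batch, seq_len, heads, head_dim)

# Old processor: transpose around SDPA, transpose back, then flatten the head dims
out_old = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2).flatten(2, 3)

# New processor: stay in (batch, seq_len, heads, head_dim) and flatten directly
out_new = sdpa_bshd(q, k, v).flatten(2, 3)

assert torch.allclose(out_old, out_new)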
@@ -119,6 +156,119 @@ def apply_rotary_emb(
         return hidden_states
 
 
+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead. "
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
+
+
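Because the shim overrides __new__ and returns a WanAttnProcessor, constructing the deprecated name keeps existing call sites working; a quick sketch of the expected behavior:

# Sketch: uses the two processor classes defined above in this module.
processor = WanAttnProcessor2_0()  # emits a deprecation warning
assert isinstance(processor, WanAttnProcessor)
assert not isinstance(processor, WanAttnProcessor2_0)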
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
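A hedged round-trip sketch of how fusion is expected to be exercised on the new module (names taken from this diff, shapes hypothetical); the fused and unfused outputs should agree up to floating-point noise:

import torch

# Sketch: WanAttention and WanAttnProcessor are the classes defined in this file.
attn = WanAttention(dim=64, heads=4, dim_head=16, processor=WanAttnProcessor())
x = torch.randn(1, 32, 64)  # (batch, seq_len, dim)

out_ref = attn(x)          # unfused path: separate to_q / to_k / to_v
attn.fuse_projections()    # builds attn.to_qkv from the concatenated weights
out_fused = attn(x)
attn.unfuse_projections()  # drops to_qkv and falls back to the separate projections

assert torch.allclose(out_ref, out_fused, atol=1e-5)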
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
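A minimal usage sketch of the cross-attention configuration of WanAttention, mirroring attn2 in the transformer block further below (shapes and sizes are hypothetical; the positional call matches the new block-level convention):

import torch

# Sketch: self.attn2-style setup with text conditioning already projected to dim.
attn = WanAttention(
    dim=128, heads=8, dim_head=16, cross_attention_dim_head=16, processor=WanAttnProcessor()
)
hidden_states = torch.randn(2, 77, 128)           # video tokens, (batch, seq_len, dim)
encoder_hidden_states = torch.randn(2, 512, 128)  # text tokens, hypothetical length 512

# Same positional convention as attn2(norm_hidden_states, encoder_hidden_states, None, None)
out = attn(hidden_states, encoder_hidden_states, None, None)
assert out.shape == hidden_states.shape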
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
@@ -247,8 +397,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
-        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
 
         return freqs_cos, freqs_sin
 
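The reshape now places the sequence axis second, (1, seq_len, 1, head_dim) instead of (1, 1, seq_len, head_dim), so the frequencies broadcast against the unpermuted (batch, seq_len, heads, head_dim) query/key produced by the updated processor. A broadcasting-only sanity sketch with hypothetical sizes (the actual rotation is apply_rotary_emb above):

import torch

batch, seq_len, heads, head_dim = 2, 16, 4, 32
query = torch.randn(batch, seq_len, heads, head_dim)

freqs_cos = torch.randn(1, seq_len, 1, head_dim)  # new layout from WanRotaryPosEmbed
freqs_sin = torch.randn(1, seq_len, 1, head_dim)

# Per-position frequencies broadcast across the head axis; shape is preserved.
rotated_like = query * freqs_cos + query * freqs_sin
assert rotated_like.shape == query.shape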
@@ -269,33 +419,24 @@ def __init__(
 
         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
         )
 
         # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
@@ -332,12 +473,12 @@ def forward(
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
@@ -350,7 +491,9 @@ def forward(
         return hidden_states
 
 
-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan model.
 
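With AttentionMixin added to the bases, the per-module fuse_projections/unfuse_projections above can be driven from the model level. A hedged sketch, assuming AttentionMixin exposes fuse_qkv_projections()/unfuse_qkv_projections() helpers (those helpers are not shown in this diff, and "<repo-id>" is a placeholder):

# Sketch only: model-level toggling of fused projections via the new mixin.
model = WanTransformer3DModel.from_pretrained("<repo-id>", subfolder="transformer")
model.fuse_qkv_projections()    # each WanAttention builds its fused to_qkv / to_kv / to_added_kv
# ... run inference ...
model.unfuse_qkv_projections()  # restore the separate projection layers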