1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type
3218
3319
3420@dataclass
@@ -38,40 +24,90 @@ class AttentionProcessorMetadata:
3824
@dataclass
class TransformerBlockMetadata:
    """Metadata describing where a transformer block's outputs live in its return value.

    `return_hidden_states_index` / `return_encoder_hidden_states_index` give the tuple
    positions of the two outputs (`None` when the block does not return that output).
    """

    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    # Set by TransformerBlockRegistry.register(); used to inspect the block's
    # `forward` signature so positional args can be looked up by parameter name.
    _cls: Type = None
    # Lazily-built map of `forward` parameter name -> positional index (excluding `self`).
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        """Return the value passed for parameter `identifier` in a `forward` call.

        Checks `kwargs` first, then resolves the positional index of `identifier`
        from the signature of `_cls.forward` (cached after the first lookup).

        Raises:
            ValueError: if `_cls` is unset, `identifier` is not a parameter of
                `forward`, or too few positional `args` were supplied.
        """
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]
        if self._cached_parameter_indices is None:
            if self._cls is None:
                raise ValueError("Model class is not set for metadata.")
            parameters = list(inspect.signature(self._cls.forward).parameters.keys())
            parameters = parameters[1:]  # skip `self`
            self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
        # Validate on every call (not just the cold path) so cache hits raise the
        # same ValueError instead of a bare KeyError/IndexError.
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            # args[index] needs at least index + 1 positional arguments.
            raise ValueError(f"Expected at least {index + 1} arguments but got {len(args)}.")
        return args[index]
50+
4551
class AttentionProcessorRegistry:
    """Lazy registry mapping attention processor classes to their metadata."""

    _registry = {}
    # TODO(aryan): registration must stay lazy for the time being. Running the
    # registration function eagerly at module scope would trigger circular imports
    # with the model modules referenced in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        """Associate `metadata` with `model_class` (bootstraps built-ins first)."""
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        """Return the metadata registered for `model_class`.

        Raises:
            ValueError: if `model_class` was never registered.
        """
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        # Flip the flag *before* running the registrations so the re-entrant
        # `register()` calls made inside the bootstrap function are no-ops here.
        if not cls._is_registered:
            cls._is_registered = True
            _register_attention_processors_metadata()
77+
5978
class TransformerBlockRegistry:
    """Lazy registry mapping transformer block classes to their metadata."""

    _registry = {}
    # TODO(aryan): registration must stay lazy for the time being. Running the
    # registration function eagerly at module scope would trigger circular imports
    # with the model modules referenced in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        """Associate `metadata` with `model_class` (bootstraps built-ins first).

        Also stamps `model_class` onto the metadata so it can later inspect the
        block's `forward` signature.
        """
        cls._register()
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        """Return the metadata registered for `model_class`.

        Raises:
            ValueError: if `model_class` was never registered.
        """
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        # Flip the flag *before* running the registrations so the re-entrant
        # `register()` calls made inside the bootstrap function are no-ops here.
        if not cls._is_registered:
            cls._is_registered = True
            _register_transformer_blocks_metadata()
105+
73106
74107def _register_attention_processors_metadata ():
108+ from ..models .attention_processor import AttnProcessor2_0
109+ from ..models .transformers .transformer_cogview4 import CogView4AttnProcessor
110+
75111 # AttnProcessor2_0
76112 AttentionProcessorRegistry .register (
77113 model_class = AttnProcessor2_0 ,
@@ -90,11 +126,24 @@ def _register_attention_processors_metadata():
90126
91127
def _register_transformer_blocks_metadata():
    """Register metadata for every supported transformer block class.

    Imports are local to avoid circular imports at module load time.
    """
    from ..models.attention import BasicTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
    from ..models.transformers.transformer_hunyuan_video import (
        HunyuanVideoSingleTransformerBlock,
        HunyuanVideoTokenReplaceSingleTransformerBlock,
        HunyuanVideoTokenReplaceTransformerBlock,
        HunyuanVideoTransformerBlock,
    )
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_wan import WanTransformerBlock

    # (block class, return_hidden_states_index, return_encoder_hidden_states_index).
    # `None` means the block does not return that output. Note the Flux blocks
    # return (encoder_hidden_states, hidden_states), hence the swapped indices.
    block_metadata = [
        (BasicTransformerBlock, 0, None),
        (CogVideoXBlock, 0, 1),
        (CogView4TransformerBlock, 0, 1),
        (FluxTransformerBlock, 1, 0),
        (FluxSingleTransformerBlock, 1, 0),
        (HunyuanVideoTransformerBlock, 0, 1),
        (HunyuanVideoSingleTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceSingleTransformerBlock, 0, 1),
        (LTXVideoTransformerBlock, 0, None),
        (MochiTransformerBlock, 0, 1),
        (WanTransformerBlock, 0, None),
    ]
    for block_cls, hidden_states_index, encoder_hidden_states_index in block_metadata:
        TransformerBlockRegistry.register(
            model_class=block_cls,
            metadata=TransformerBlockMetadata(
                return_hidden_states_index=hidden_states_index,
                return_encoder_hidden_states_index=encoder_hidden_states_index,
            ),
        )
@@ -223,49 +261,4 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
223261
# Per-processor aliases consumed by the attention processor registrations; each points
# to the generic skip function whose return shape matches that processor's outputs
# (hidden_states only vs. (hidden_states, encoder_hidden_states)).
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
226-
227-
228- def _skip_block_output_fn___hidden_states_0___ret___hidden_states (self , * args , ** kwargs ):
229- hidden_states = kwargs .get ("hidden_states" , None )
230- if hidden_states is None and len (args ) > 0 :
231- hidden_states = args [0 ]
232- return hidden_states
233-
234-
235- def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states (self , * args , ** kwargs ):
236- hidden_states = kwargs .get ("hidden_states" , None )
237- encoder_hidden_states = kwargs .get ("encoder_hidden_states" , None )
238- if hidden_states is None and len (args ) > 0 :
239- hidden_states = args [0 ]
240- if encoder_hidden_states is None and len (args ) > 1 :
241- encoder_hidden_states = args [1 ]
242- return hidden_states , encoder_hidden_states
243-
244-
245- def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states (self , * args , ** kwargs ):
246- hidden_states = kwargs .get ("hidden_states" , None )
247- encoder_hidden_states = kwargs .get ("encoder_hidden_states" , None )
248- if hidden_states is None and len (args ) > 0 :
249- hidden_states = args [0 ]
250- if encoder_hidden_states is None and len (args ) > 1 :
251- encoder_hidden_states = args [1 ]
252- return encoder_hidden_states , hidden_states
253-
254-
255- _skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
256- _skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
257- _skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
258- _skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
259- _skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
260- _skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
261- _skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
262- _skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
263- _skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
264- _skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
265- _skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
266- _skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
267264# fmt: on
268-
269-
270- _register_attention_processors_metadata ()
271- _register_transformer_blocks_metadata ()