
Commit 0454fbb

a-r-r-o-w and DN6 authored
First Block Cache (#11180)
* update * modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code) * remove debug logs * update * cache context for different batches of data * fix hs residual bug for single return outputs; support ltx * fix controlnet flux * support flux, ltx i2v, ltx condition * update * update * Update docs/source/en/api/cache.md * Update src/diffusers/hooks/hooks.py Co-authored-by: Dhruv Nair <[email protected]> * address review comments pt. 1 * address review comments pt. 2 * cache context refacotr; address review pt. 3 * address review comments * metadata registration with decorators instead of centralized * support cogvideox * support mochi * fix * remove unused function * remove central registry based on review * update --------- Co-authored-by: Dhruv Nair <[email protected]>
1 parent cbc8ced commit 0454fbb

32 files changed: +902 -170 lines changed

docs/source/en/api/cache.md

Lines changed: 6 additions & 0 deletions
@@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
 [[autodoc]] FasterCacheConfig

 [[autodoc]] apply_faster_cache
+
+### FirstBlockCacheConfig
+
+[[autodoc]] FirstBlockCacheConfig
+
+[[autodoc]] apply_first_block_cache
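For reference, a minimal sketch of how the newly documented API might be applied to a pipeline's transformer. The pipeline choice, checkpoint, and threshold value are illustrative assumptions, not taken from this diff:

import torch
from diffusers import FluxPipeline, FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# On steps where the first transformer block's output changes very little,
# reuse the cached outputs of the remaining blocks instead of recomputing them.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

image = pipe("A photo of a cat", num_inference_steps=28).images[0]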

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -133,9 +133,11 @@
     _import_structure["hooks"].extend(
         [
             "FasterCacheConfig",
+            "FirstBlockCacheConfig",
             "HookRegistry",
             "PyramidAttentionBroadcastConfig",
             "apply_faster_cache",
+            "apply_first_block_cache",
             "apply_pyramid_attention_broadcast",
         ]
     )
@@ -751,9 +753,11 @@
     else:
         from .hooks import (
             FasterCacheConfig,
+            FirstBlockCacheConfig,
             HookRegistry,
             PyramidAttentionBroadcastConfig,
             apply_faster_cache,
+            apply_first_block_cache,
             apply_pyramid_attention_broadcast,
         )
         from .models import (
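With both the lazy-import table and the eager branch updated, the new symbols become importable directly from the package root:

from diffusers import FirstBlockCacheConfig, apply_first_block_cache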

src/diffusers/hooks/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -1,8 +1,23 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from ..utils import is_torch_available


 if is_torch_available():
     from .faster_cache import FasterCacheConfig, apply_faster_cache
+    from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
     from .group_offloading import apply_group_offloading
     from .hooks import HookRegistry, ModelHook
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

src/diffusers/hooks/_common.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..models.attention_processor import Attention, MochiAttention
+
+
+_ATTENTION_CLASSES = (Attention, MochiAttention)
+
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
+_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
+
+_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
+    {
+        *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+        *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    }
+)
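These identifier tuples give the hooks a name-based way to locate transformer block lists on heterogeneous models. A small sketch of that kind of lookup; the helper function below is hypothetical, and only _ALL_TRANSFORMER_BLOCK_IDENTIFIERS comes from this diff:

import torch.nn as nn

from diffusers.hooks._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS


def iter_transformer_block_lists(model: nn.Module):
    # Hypothetical helper: yield every ModuleList attribute whose name matches
    # one of the known transformer-block identifiers.
    for name in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
        blocks = getattr(model, name, None)
        if isinstance(blocks, nn.ModuleList):
            yield name, blocks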

src/diffusers/hooks/_helpers.py

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Type
+
+
+@dataclass
+class AttentionProcessorMetadata:
+    skip_processor_output_fn: Callable[[Any], Any]
+
+
+@dataclass
+class TransformerBlockMetadata:
+    return_hidden_states_index: int = None
+    return_encoder_hidden_states_index: int = None
+
+    _cls: Type = None
+    _cached_parameter_indices: Dict[str, int] = None
+
+    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        if identifier in kwargs:
+            return kwargs[identifier]
+        if self._cached_parameter_indices is not None:
+            return args[self._cached_parameter_indices[identifier]]
+        if self._cls is None:
+            raise ValueError("Model class is not set for metadata.")
+        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+        parameters = parameters[1:]  # skip `self`
+        self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+        if identifier not in self._cached_parameter_indices:
+            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+        index = self._cached_parameter_indices[identifier]
+        if index >= len(args):
+            raise ValueError(f"Expected {index} arguments but got {len(args)}.")
+        return args[index]
+
+
+class AttentionProcessorRegistry:
+    _registry = {}
+    # TODO(aryan): this is only required for the time being because we need to do the registrations
+    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+    # import errors because of the models imported in this file.
+    _is_registered = False
+
+    @classmethod
+    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
+        cls._register()
+        cls._registry[model_class] = metadata
+
+    @classmethod
+    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
+        cls._register()
+        if model_class not in cls._registry:
+            raise ValueError(f"Model class {model_class} not registered.")
+        return cls._registry[model_class]
+
+    @classmethod
+    def _register(cls):
+        if cls._is_registered:
+            return
+        cls._is_registered = True
+        _register_attention_processors_metadata()
+
+
+class TransformerBlockRegistry:
+    _registry = {}
+    # TODO(aryan): this is only required for the time being because we need to do the registrations
+    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+    # import errors because of the models imported in this file.
+    _is_registered = False
+
+    @classmethod
+    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
+        cls._register()
+        metadata._cls = model_class
+        cls._registry[model_class] = metadata
+
+    @classmethod
+    def get(cls, model_class: Type) -> TransformerBlockMetadata:
+        cls._register()
+        if model_class not in cls._registry:
+            raise ValueError(f"Model class {model_class} not registered.")
+        return cls._registry[model_class]
+
+    @classmethod
+    def _register(cls):
+        if cls._is_registered:
+            return
+        cls._is_registered = True
+        _register_transformer_blocks_metadata()
+
+
+def _register_attention_processors_metadata():
+    from ..models.attention_processor import AttnProcessor2_0
+    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+
+    # AttnProcessor2_0
+    AttentionProcessorRegistry.register(
+        model_class=AttnProcessor2_0,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
+        ),
+    )
+
+    # CogView4AttnProcessor
+    AttentionProcessorRegistry.register(
+        model_class=CogView4AttnProcessor,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
+        ),
+    )
+
+
+def _register_transformer_blocks_metadata():
+    from ..models.attention import BasicTransformerBlock
+    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
+    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
+    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+    from ..models.transformers.transformer_hunyuan_video import (
+        HunyuanVideoSingleTransformerBlock,
+        HunyuanVideoTokenReplaceSingleTransformerBlock,
+        HunyuanVideoTokenReplaceTransformerBlock,
+        HunyuanVideoTransformerBlock,
+    )
+    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
+    from ..models.transformers.transformer_mochi import MochiTransformerBlock
+    from ..models.transformers.transformer_wan import WanTransformerBlock
+
+    # BasicTransformerBlock
+    TransformerBlockRegistry.register(
+        model_class=BasicTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=None,
+        ),
+    )
+
+    # CogVideoX
+    TransformerBlockRegistry.register(
+        model_class=CogVideoXBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=1,
+        ),
+    )
+
+    # CogView4
+    TransformerBlockRegistry.register(
+        model_class=CogView4TransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=1,
+        ),
+    )
+
+    # Flux
+    TransformerBlockRegistry.register(
+        model_class=FluxTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=1,
+            return_encoder_hidden_states_index=0,
+        ),
+    )
+    TransformerBlockRegistry.register(
+        model_class=FluxSingleTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=1,
+            return_encoder_hidden_states_index=0,
+        ),
+    )
+
+    # HunyuanVideo
+    TransformerBlockRegistry.register(
+        model_class=HunyuanVideoTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=1,
+        ),
+    )
+    TransformerBlockRegistry.register(
+        model_class=HunyuanVideoSingleTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=1,
+        ),
+    )
+    TransformerBlockRegistry.register(
+        model_class=HunyuanVideoTokenReplaceTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=1,
+        ),
+    )
+    TransformerBlockRegistry.register(
+        model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=1,
+        ),
+    )
+
+    # LTXVideo
+    TransformerBlockRegistry.register(
+        model_class=LTXVideoTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=None,
+        ),
+    )
+
+    # Mochi
+    TransformerBlockRegistry.register(
+        model_class=MochiTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=1,
+        ),
+    )
+
+    # Wan
+    TransformerBlockRegistry.register(
+        model_class=WanTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=None,
+        ),
+    )
+
+
+# fmt: off
+def _skip_attention___ret___hidden_states(self, *args, **kwargs):
+    hidden_states = kwargs.get("hidden_states", None)
+    if hidden_states is None and len(args) > 0:
+        hidden_states = args[0]
+    return hidden_states
+
+
+def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
+    hidden_states = kwargs.get("hidden_states", None)
+    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+    if hidden_states is None and len(args) > 0:
+        hidden_states = args[0]
+    if encoder_hidden_states is None and len(args) > 1:
+        encoder_hidden_states = args[1]
+    return hidden_states, encoder_hidden_states
+
+
+_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
+# fmt: on
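To show how this metadata is meant to be consumed, here is a hedged sketch of a caching hook normalizing a block's output regardless of the order in which hidden_states and encoder_hidden_states are returned. The helper function is hypothetical; only TransformerBlockRegistry and TransformerBlockMetadata come from this diff:

from diffusers.hooks._helpers import TransformerBlockRegistry


def split_block_output(block, output):
    # Hypothetical helper: use the block's registered metadata to pull hidden_states
    # (and, when present, encoder_hidden_states) out of its positional return value.
    metadata = TransformerBlockRegistry.get(type(block))
    if not isinstance(output, tuple):
        return output, None
    hidden_states = output[metadata.return_hidden_states_index]
    encoder_hidden_states = None
    if metadata.return_encoder_hidden_states_index is not None:
        encoder_hidden_states = output[metadata.return_encoder_hidden_states_index]
    return hidden_states, encoder_hidden_states


# Call arguments can likewise be recovered by name even when passed positionally:
#   metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)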
