
Commit 932914f

Apply suggestions from code review
Co-authored-by: Sayak Paul <[email protected]>
1 parent f794d66 commit 932914f

9 files changed: +73 −25 lines

docs/source/en/optimization/fp16.md

Lines changed: 35 additions & 3 deletions
@@ -152,9 +152,39 @@ Compilation is slow the first time, but once compiled, it is significantly faster
 
 ### Regional compilation
 
-[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
 
-[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately.
+[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently repeated block(s)** of a model, typically a Transformer layer, enabling reuse of the compiled artifact for every subsequent occurrence.
+For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10×**.
+
+To make this effortless, `ModelMixin` exposes the **`compile_repeated_blocks`** API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
+
+```py
+# pip install -U diffusers
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+).to("cuda")
+
+# Compile only the repeated Transformer layers inside the UNet
+pipe.unet.compile_repeated_blocks(fullgraph=True)
+```
+
+To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
+
+
+```py
+class MyUNet(ModelMixin):
+    _repeated_blocks = ("Transformer2DModel",)  # ← compiled by default
+```
+
+For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
+
+**Relation to Accelerate's compile_regions.** There is also a separate API in [Accelerate](https://huggingface.co/docs/accelerate/index), [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, compiles them, and then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
+
+
 
 ```py
 # pip install -U accelerate
@@ -167,6 +197,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
 ).to("cuda")
 pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
 ```
+`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
+
 
 ### Graph breaks
 
@@ -241,4 +273,4 @@ An input is projected into three subspaces, represented by the projection matrices
 
 ```py
 pipeline.fuse_qkv_projections()
-```
+```
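The committed docs keep the example on Stable Diffusion XL. As a minimal sketch (not part of this commit), the same one-liner applies to any model that declares `_repeated_blocks`; the snippet below assumes the `FluxPipeline` and the `black-forest-labs/FLUX.1-dev` checkpoint, whose `FluxTransformer2DModel` gains `_repeated_blocks` in this commit.

```py
# Sketch only, not from the committed docs; assumes access to the FLUX.1-dev checkpoint.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# FluxTransformer2DModel declares _repeated_blocks, so only its
# FluxTransformerBlock / FluxSingleTransformerBlock layers are compiled;
# the compiled code is then reused for every other occurrence of those blocks.
pipe.transformer.compile_repeated_blocks(fullgraph=True)
```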

src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 20 deletions
@@ -1414,39 +1414,28 @@ def compile_repeated_blocks(self, *args, **kwargs):
         can reduce end-to-end compile time substantially, while preserving the
         runtime speed-ups you would expect from a full `torch.compile`.
 
-        The set of sub-modules to compile is discovered in one of two ways:
-
-        1. **`_repeated_blocks`** – Preferred. Define this attribute on your
-           subclass as a list/tuple of class names (strings). Every module whose
-           class name matches will be compiled.
-
-        2. **`_no_split_modules`** – Fallback. If the preferred attribute is
-           missing or empty, we fall back to the legacy Diffusers attribute
-           `_no_split_modules`.
+        The set of sub-modules to compile is discovered by the presence of the
+        **`_repeated_blocks`** attribute in the model definition. Define this
+        attribute on your model subclass as a list/tuple of class names
+        (strings). Every module whose class name matches will be compiled.
 
         Once discovered, each matching sub-module is compiled by calling
-        ``submodule.compile(*args, **kwargs)``. Any positional or keyword
-        arguments you supply to :meth:`compile_repeated_blocks` are forwarded
+        `submodule.compile(*args, **kwargs)`. Any positional or keyword
+        arguments you supply to `compile_repeated_blocks` are forwarded
        verbatim to `torch.compile`.
         """
         repeated_blocks = getattr(self, "_repeated_blocks", None)
 
-        if not repeated_blocks:
-            logger.warning("_repeated_blocks attribute is empty. Using _no_split_modules to find compile regions.")
-
-            repeated_blocks = getattr(self, "_no_split_modules", None)
-
         if not repeated_blocks:
             raise ValueError(
-                "Both _repeated_blocks and _no_split_modules attribute are empty. "
-                "Set _repeated_blocks for the model to benefit from faster compilation. "
+                "`_repeated_blocks` attribute is empty. "
+                f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
             )
-
         has_compiled_region = False
         for submod in self.modules():
             if submod.__class__.__name__ in repeated_blocks:
-                has_compiled_region = True
                 submod.compile(*args, **kwargs)
+                has_compiled_region = True
 
         if not has_compiled_region:
             raise ValueError(
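As an aside (not part of this commit), the discovery step above is just a class-name match over the model's sub-modules; a small hypothetical helper can reproduce it to inspect which blocks would be compiled before calling `compile_repeated_blocks`:

```py
# Hypothetical helper, not in diffusers: mirrors the matching loop in
# compile_repeated_blocks. `model` stands for any loaded diffusers model.
def list_repeated_blocks(model):
    repeated = getattr(model, "_repeated_blocks", None) or []
    return [
        name
        for name, submod in model.named_modules()
        if submod.__class__.__name__ in repeated
    ]
```

For a `UNet2DConditionModel`, for example, this returns the name of every `BasicTransformerBlock`, exactly the sub-modules that `submod.compile(*args, **kwargs)` is applied to.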

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 1 addition & 0 deletions
@@ -407,6 +407,7 @@ class ChromaTransformer2DModel(
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+    _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
 
     @register_to_config

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ class FluxTransformer2DModel(
     _supports_gradient_checkpointing = True
     _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
-    _repeated_blocks = _no_split_modules
+    _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
 
     @register_to_config
     def __init__(

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 6 additions & 0 deletions
@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         "HunyuanVideoPatchEmbed",
         "HunyuanVideoTokenRefiner",
     ]
+    _repeated_blocks = [
+        "HunyuanVideoTransformerBlock",
+        "HunyuanVideoSingleTransformerBlock",
+        "HunyuanVideoPatchEmbed",
+        "HunyuanVideoTokenRefiner",
+    ]
 
     @register_to_config
     def __init__(

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 1 addition & 0 deletions
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["norm"]
+    _repeated_blocks = ["LTXVideoTransformerBlock"]
 
     @register_to_config
     def __init__(

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
     _no_split_modules = ["WanTransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
-    _repeated_blocks = _no_split_modules
+    _repeated_blocks = ["WanTransformerBlock"]
 
     @register_to_config
     def __init__(

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ class conditioning with `class_embed_type` equal to `None`.
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
     _skip_layerwise_casting_patterns = ["norm"]
+    _repeated_blocks = ["BasicTransformerBlock"]
 
     @register_to_config
     def __init__(

tests/models/test_modeling_common.py

Lines changed: 18 additions & 0 deletions
@@ -1936,6 +1936,24 @@ def test_torch_compile_recompilation_and_graph_break(self):
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)
 
+    def test_torch_compile_repeated_blocks(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict).to(torch_device)
+        model.compile_repeated_blocks(fullgraph=True)
+
+        recompile_limit = 1
+        if self.model_class.__name__ == "UNet2DConditionModel":
+            recompile_limit = 2
+
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(recompile_limit=recompile_limit),
+            torch.no_grad(),
+        ):
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
+
     def test_compile_with_group_offloading(self):
         torch._dynamo.config.cache_size_limit = 10000
