Commit 73989a6

Apply suggestions from code review
Co-authored-by: Sayak Paul <[email protected]>
1 parent f794d66 commit 73989a6

2 files changed: +40 -8 lines changed


docs/source/en/optimization/fp16.md

Lines changed: 35 additions & 3 deletions
@@ -152,9 +152,39 @@ Compilation is slow the first time, but once compiled, it is significantly faste
 
 ### Regional compilation
 
-[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
 
-[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately.
+[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
+For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10 ×**.
+
+To make this effortless, `ModelMixin` exposes **`compile_repeated_blocks`** API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
+
+```py
+# pip install -U diffusers
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+).to("cuda")
+
+# Compile only the repeated Transformer layers inside the UNet
+pipe.unet.compile_repeated_blocks(fullgraph=True)
+```
+
+To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
+
+
+```py
+class MyUNet(ModelMixin):
+    _repeated_blocks = ("Transformer2DModel",) # ← compiled by default
+```
+
+For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
+
+**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
+
+
 
 ```py
 # pip install -U accelerate
@@ -167,6 +197,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
 ).to("cuda")
 pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
 ```
+`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
+
 
 ### Graph breaks
 
@@ -241,4 +273,4 @@ An input is projected into three subspaces, represented by the projection matric
 
 ```py
 pipeline.fuse_qkv_projections()
-```
+```
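
Note on the documented mechanism: the added text says the helper compiles exactly the sub-modules whose class names appear in `_repeated_blocks`. The selection step can be illustrated with a rough, torch-free sketch. Everything here is a hypothetical stand-in (`DummyBlock`, `DummyModel`, `compile_matching_blocks`, and the toy `Transformer2DModel` / `ResnetBlock2D` classes); it only assumes objects that offer `named_modules()` and `compile()`, which `torch.nn.Module` provides in recent PyTorch releases. The real helper may differ in detail.

```python
class DummyBlock:
    """Hypothetical stand-in for a layer; it only records compile() calls."""

    def __init__(self):
        self.compile_calls = []

    def compile(self, *args, **kwargs):
        # A real nn.Module.compile would hand the module to torch.compile here.
        self.compile_calls.append((args, kwargs))


class Transformer2DModel(DummyBlock):
    """Toy class reusing the name from the docs example."""


class ResnetBlock2D(DummyBlock):
    """Toy class that is NOT listed as a repeated block."""


class DummyModel:
    # Class names (as strings) of the blocks that should be compiled.
    _repeated_blocks = ("Transformer2DModel",)

    def __init__(self):
        self.children = {
            "down.0": ResnetBlock2D(),
            "down.1": Transformer2DModel(),
            "mid.0": Transformer2DModel(),
        }

    def named_modules(self):
        # Yields (name, module) pairs, loosely mirroring nn.Module.named_modules().
        yield from self.children.items()


def compile_matching_blocks(model, *args, **kwargs):
    """Compile only sub-modules whose class name is listed in _repeated_blocks."""
    targets = set(model._repeated_blocks)
    compiled = []
    for name, submodule in model.named_modules():
        if type(submodule).__name__ in targets:
            # Arguments are forwarded unchanged to the sub-module's compile().
            submodule.compile(*args, **kwargs)
            compiled.append(name)
    return compiled


model = DummyModel()
compiled_names = compile_matching_blocks(model, fullgraph=True)
print(compiled_names)
```

Matching on the class name rather than on the module path is what lets a single entry in `_repeated_blocks` cover every occurrence of that block in the model.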

src/diffusers/models/modeling_utils.py

Lines changed: 5 additions & 5 deletions
@@ -1425,21 +1425,21 @@ class name matches will be compiled.
         `_no_split_modules`.
 
         Once discovered, each matching sub-module is compiled by calling
-        ``submodule.compile(*args, **kwargs)``. Any positional or keyword
-        arguments you supply to :meth:`compile_repeated_blocks` are forwarded
+        `submodule.compile(*args, **kwargs)`. Any positional or keyword
+        arguments you supply to `compile_repeated_blocks` are forwarded
         verbatim to `torch.compile`.
         """
         repeated_blocks = getattr(self, "_repeated_blocks", None)
 
         if not repeated_blocks:
-            logger.warning("_repeated_blocks attribute is empty. Using _no_split_modules to find compile regions.")
+            logger.warning("`_repeated_blocks` attribute is empty. Using `_no_split_modules` to find compile regions.")
 
             repeated_blocks = getattr(self, "_no_split_modules", None)
 
         if not repeated_blocks:
             raise ValueError(
-                "Both _repeated_blocks and _no_split_modules attribute are empty. "
-                "Set _repeated_blocks for the model to benefit from faster compilation. "
+                "Both `_repeated_blocks` and `_no_split_modules` attribute are empty. "
+                "Set `_repeated_blocks` for the model to benefit from faster compilation. "
             )
 
         has_compiled_region = False
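
Note on the lookup order in this hunk: the method prefers `_repeated_blocks`, falls back to `_no_split_modules` with a warning, and raises if both are empty. A minimal standalone sketch of that order follows. `resolve_compile_targets` and the three toy classes are hypothetical names, `BasicTransformerBlock` is only an example value, and the standard `logging` module stands in for the library's own logger.

```python
import logging

logger = logging.getLogger(__name__)


def resolve_compile_targets(model):
    """Return the class names to compile, preferring `_repeated_blocks`."""
    repeated_blocks = getattr(model, "_repeated_blocks", None)

    if not repeated_blocks:
        # Missing or empty: warn and fall back to `_no_split_modules`.
        logger.warning(
            "`_repeated_blocks` attribute is empty. "
            "Using `_no_split_modules` to find compile regions."
        )
        repeated_blocks = getattr(model, "_no_split_modules", None)

    if not repeated_blocks:
        # Neither attribute gives a usable list, so there is nothing to compile.
        raise ValueError(
            "Both `_repeated_blocks` and `_no_split_modules` are empty."
        )

    return list(repeated_blocks)


class WithRepeated:
    _repeated_blocks = ("Transformer2DModel",)
    _no_split_modules = ["BasicTransformerBlock"]


class OnlyNoSplit:
    _repeated_blocks = []
    _no_split_modules = ["BasicTransformerBlock"]


class Neither:
    pass


explicit = resolve_compile_targets(WithRepeated())
fallback = resolve_compile_targets(OnlyNoSplit())

try:
    resolve_compile_targets(Neither())
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Using `if not repeated_blocks` treats a missing attribute, `None`, and an empty list or tuple the same way, which is why one check covers all three cases.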
