`docs/source/en/optimization/fp16.md`
Compilation is slow the first time, but once compiled, it is significantly faster.
### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation, yet cuts compile time by **8–10×**.
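The idea is not diffusers-specific. Below is a minimal sketch of regional compilation with plain `torch.compile`, using a hypothetical model built from identical residual blocks (all names and shapes here are illustrative, not part of any library API). Because every block shares the same code, the compiler can reuse the compiled artifact for each one instead of tracing the whole model as one large graph.

```py
import torch
from torch import nn

class Block(nn.Module):
    """A small repeated unit, standing in for a Transformer layer."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.net(x)

class Model(nn.Module):
    def __init__(self, dim=64, depth=8):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = Model()

# Regional compilation: compile each repeated block instead of the whole model.
# Identical blocks hit the same compile cache, so only the first block pays the
# full compilation cost.
for i, block in enumerate(model.blocks):
    model.blocks[i] = torch.compile(block, fullgraph=True)

out = model(torch.randn(2, 64))
```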
To make this effortless, `ModelMixin` exposes the **`compile_repeated_blocks`** API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# Compile only the repeated Transformer layers inside the UNet
pipe.unet.compile_repeated_blocks(fullgraph=True)
```
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
```py
class MyUNet(ModelMixin):
    _repeated_blocks = ("Transformer2DModel",)  # ← compiled by default
```
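To see how the pieces fit together, here is a self-contained sketch of the opt-in end to end; `MyBlock`, the model layout, and the dimensions are hypothetical, and only the `_repeated_blocks` attribute and the `compile_repeated_blocks` call mirror the API described above:

```py
import torch
from torch import nn
from diffusers import ModelMixin

class MyBlock(nn.Module):
    """Hypothetical repeated unit; in a real model this would be a Transformer layer."""
    def __init__(self, dim=32):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))

class MyUNet(ModelMixin):
    # Class names listed here are the blocks compile_repeated_blocks will compile.
    _repeated_blocks = ("MyBlock",)

    def __init__(self, dim=32, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(MyBlock(dim) for _ in range(depth))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = MyUNet()
model.compile_repeated_blocks(fullgraph=True)  # compiles every MyBlock instance
out = model(torch.randn(1, 32))
```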
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
**Relation to Accelerate's compile_regions.** [Accelerate](https://huggingface.co/docs/accelerate/index) also provides a separate API, [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, compiles them, and then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
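For comparison, a hedged sketch of the Accelerate route, assuming `compile_regions` is importable from `accelerate.utils` and forwards keyword arguments to `torch.compile` (check the linked source for the exact signature):

```py
import torch
from accelerate.utils import compile_regions  # assumed import path
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# Automatic: accelerate walks the UNet, finds the repeated blocks itself, and
# compiles them; the rest of the model is compiled separately.
pipe.unet = compile_regions(pipe.unet, fullgraph=True)

# Explicit: diffusers compiles exactly the blocks listed in `_repeated_blocks`.
# pipe.unet.compile_repeated_blocks(fullgraph=True)
```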
### Graph breaks