
Commit d93381c

anijain2305, sayakpaul, and github-actions[bot] authored
[rfc][compile] compile method for DiffusionPipeline (#11705)
* [rfc][compile] compile method for DiffusionPipeline
* Apply suggestions from code review

  Co-authored-by: Sayak Paul <[email protected]>

* Apply style fixes
* Update docs/source/en/optimization/fp16.md
* check

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 3649d7b commit d93381c

File tree

9 files changed, +101 -3 lines changed

docs/source/en/optimization/fp16.md

Lines changed: 35 additions & 3 deletions
@@ -152,9 +152,39 @@ Compilation is slow the first time, but once compiled, it is significantly faster.
 
 ### Regional compilation
 
-[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
 
-[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately.
+[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently repeated block(s)** of a model, typically a Transformer layer, and reusing the compiled artifact for every subsequent occurrence.
+For many diffusion architectures, this delivers the *same* runtime speedups as full-graph compilation while cutting compile time by **8–10x**.
+
+To make this effortless, [`ModelMixin`] exposes the [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
+
+```py
+# pip install -U diffusers
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+).to("cuda")
+
+# Compile only the repeated Transformer layers inside the UNet
+pipe.unet.compile_repeated_blocks(fullgraph=True)
+```
+
+To enable regional compilation for a new model, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
+
+```py
+class MyUNet(ModelMixin):
+    _repeated_blocks = ("Transformer2DModel",)  # ← compiled by default
+```
+
+For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
+
+**Relation to Accelerate's compile_regions.** [Accelerate](https://huggingface.co/docs/accelerate/index) offers a separate API, [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, and compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
+
 
 ```py
 # pip install -U accelerate
@@ -167,6 +197,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
 ).to("cuda")
 pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
 ```
+
+`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice, this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
 
 ### Graph breaks
 
@@ -241,4 +273,4 @@ An input is projected into three subspaces, represented by the projection matrices
 
 ```py
 pipeline.fuse_qkv_projections()
-```
+```
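
Note: any positional or keyword arguments passed to `compile_repeated_blocks` are forwarded to `torch.compile`, so the usual compilation knobs apply per repeated block. A minimal sketch (not part of this commit), assuming the same SDXL checkpoint as the documentation example above:

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# The same knobs as a plain torch.compile call, applied to every repeated block.
pipe.unet.compile_repeated_blocks(mode="reduce-overhead", fullgraph=True)
```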

src/diffusers/models/modeling_utils.py

Lines changed: 34 additions & 0 deletions
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _keep_in_fp32_modules = None
     _skip_layerwise_casting_patterns = None
     _supports_group_offloading = True
+    _repeated_blocks = []
 
     def __init__(self):
         super().__init__()
@@ -1404,6 +1405,39 @@ def float(self, *args):
         else:
             return super().float(*args)
 
+    def compile_repeated_blocks(self, *args, **kwargs):
+        """
+        Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
+        compiling the entire model. This technique, often called **regional compilation** (see the PyTorch recipe:
+        https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), can reduce end-to-end compile time
+        substantially while preserving the runtime speed-ups you would expect from a full `torch.compile`.
+
+        The set of sub-modules to compile is discovered by the presence of the **`_repeated_blocks`** attribute in
+        the model definition. Define this attribute on your model subclass as a list/tuple of class names (strings).
+        Every module whose class name matches will be compiled.
+
+        Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
+        positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
+        `torch.compile`.
+        """
+        repeated_blocks = getattr(self, "_repeated_blocks", None)
+
+        if not repeated_blocks:
+            raise ValueError(
+                "`_repeated_blocks` attribute is empty. "
+                f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
+            )
+        has_compiled_region = False
+        for submod in self.modules():
+            if submod.__class__.__name__ in repeated_blocks:
+                submod.compile(*args, **kwargs)
+                has_compiled_region = True
+
+        if not has_compiled_region:
+            raise ValueError(
+                f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
+            )
+
     @classmethod
     def _load_pretrained_model(
         cls,
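
To make the discovery logic above concrete, here is a self-contained sketch (not from this commit); `ToyBlock` and `ToyModel` are hypothetical names used only for illustration:

```py
import torch
import torch.nn as nn
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config


class ToyBlock(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.ff(x))


class ToyModel(ModelMixin, ConfigMixin):
    # Modules whose class name appears here are compiled by compile_repeated_blocks();
    # everything else stays eager. An empty list raises a ValueError instead.
    _repeated_blocks = ["ToyBlock"]

    @register_to_config
    def __init__(self, dim=32, num_layers=4):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(num_layers)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = ToyModel()
model.compile_repeated_blocks(fullgraph=True)  # calls submodule.compile() on each ToyBlock

_ = model(torch.randn(2, 32))  # the first call compiles one block; identical blocks reuse the artifact
```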

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 1 addition & 0 deletions
@@ -407,6 +407,7 @@ class ChromaTransformer2DModel(
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+    _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
 
     @register_to_config

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 0 deletions
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
     _supports_gradient_checkpointing = True
     _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
 
     @register_to_config
     def __init__(
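
With `_repeated_blocks` set on `FluxTransformer2DModel`, regional compilation becomes a one-liner on the Flux transformer as well. A hedged usage sketch (not part of this commit); the checkpoint name is only an example:

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# FluxTransformerBlock and FluxSingleTransformerBlock are listed in _repeated_blocks,
# so only those repeated layers are compiled.
pipe.transformer.compile_repeated_blocks(fullgraph=True)

image = pipe("a tiny astronaut hatching from an egg on the moon").images[0]
```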

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 6 additions & 0 deletions
@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         "HunyuanVideoPatchEmbed",
         "HunyuanVideoTokenRefiner",
     ]
+    _repeated_blocks = [
+        "HunyuanVideoTransformerBlock",
+        "HunyuanVideoSingleTransformerBlock",
+        "HunyuanVideoPatchEmbed",
+        "HunyuanVideoTokenRefiner",
+    ]
 
     @register_to_config
     def __init__(

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 1 addition & 0 deletions
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["norm"]
+    _repeated_blocks = ["LTXVideoTransformerBlock"]
 
     @register_to_config
     def __init__(

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 0 deletions
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
     _no_split_modules = ["WanTransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+    _repeated_blocks = ["WanTransformerBlock"]
 
     @register_to_config
     def __init__(

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ class conditioning with `class_embed_type` equal to `None`.
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
     _skip_layerwise_casting_patterns = ["norm"]
+    _repeated_blocks = ["BasicTransformerBlock"]
 
     @register_to_config
     def __init__(
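
Since only `BasicTransformerBlock` is listed for the UNet, `compile_repeated_blocks` targets the attention transformer blocks and leaves submodules not named in `_repeated_blocks` (for example `ResnetBlock2D`) eager. A small inspection sketch (not part of this commit) showing which submodules would be matched:

```py
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

# Every module whose class name appears in _repeated_blocks is a compilation target.
matches = [m for m in unet.modules() if m.__class__.__name__ in unet._repeated_blocks]
print(unet._repeated_blocks, len(matches))  # ['BasicTransformerBlock'] and the number of matched blocks

unet.to("cuda").compile_repeated_blocks(fullgraph=True)
```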

tests/models/test_modeling_common.py

Lines changed: 21 additions & 0 deletions
@@ -1935,6 +1935,27 @@ def test_torch_compile_recompilation_and_graph_break(self):
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)
 
+    def test_torch_compile_repeated_blocks(self):
+        if self.model_class._repeated_blocks is None:
+            pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict).to(torch_device)
+        model.compile_repeated_blocks(fullgraph=True)
+
+        recompile_limit = 1
+        if self.model_class.__name__ == "UNet2DConditionModel":
+            recompile_limit = 2
+
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(recompile_limit=recompile_limit),
+            torch.no_grad(),
+        ):
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
+
     def test_compile_with_group_offloading(self):
         torch._dynamo.config.cache_size_limit = 10000
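
Outside the test suite, a similar recompile-limit check can be used to sanity-check that a second pass reuses the compiled blocks. A hedged sketch (not part of this commit), reusing the SDXL pipeline from the docs; `recompile_limit=2` mirrors the allowance the test makes for `UNet2DConditionModel`:

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.unet.compile_repeated_blocks(fullgraph=True)

prompt = "a photo of an astronaut riding a horse"
# The blocks were compiled with fullgraph=True, so exceeding the recompile limit
# should surface as an error rather than a silent fallback to eager execution.
with torch._dynamo.config.patch(recompile_limit=2):
    _ = pipe(prompt, num_inference_steps=4)
    _ = pipe(prompt, num_inference_steps=4)
```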

0 commit comments
