38 changes: 35 additions & 3 deletions docs/source/en/optimization/fp16.md
@@ -152,9 +152,39 @@ Compilation is slow the first time, but once compiled, it is significantly faster

### Regional compilation

[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.

[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately.
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently repeated block(s)** of a model, typically a Transformer layer, and reusing the compiled artifacts for every subsequent occurrence of that block.
For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation while cutting compile time by **8–10x**.

To make this effortless, [`ModelMixin`] exposes the [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:

```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")

# Compile only the repeated Transformer layers inside the UNet
pipe.unet.compile_repeated_blocks(fullgraph=True)
```
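
After that one call, the pipeline is used exactly as before: the first generation pays the compilation cost for the repeated block, and subsequent generations reuse the compiled code. A minimal usage sketch (the prompt and step count are illustrative):

```py
# The first call triggers compilation of the repeated block; later calls reuse the compiled code.
image = pipe(
    "an astronaut riding a horse on the moon",  # illustrative prompt
    num_inference_steps=30,
).images[0]
image.save("astronaut.png")
```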

To enable regional compilation for a new model, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:


```py
class MyUNet(ModelMixin):
_repeated_blocks = ("Transformer2DModel",) # ← compiled by default
```

For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
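
The snippet below expands that attribute into a self-contained toy example. `MyUNet` and `MyTransformerBlock` are hypothetical stand-ins for a real diffusers model and the layer sizes are arbitrary; treat it as a minimal sketch of the mechanism, not a production model:

```py
import torch
from torch import nn
from diffusers import ModelMixin


class MyTransformerBlock(nn.Module):
    """A hypothetical repeated block; in a real model this would be a Transformer layer."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        return torch.nn.functional.gelu(self.proj(hidden_states))


class MyUNet(ModelMixin):
    # Class names listed here are picked up by `compile_repeated_blocks`.
    _repeated_blocks = ("MyTransformerBlock",)

    def __init__(self, dim: int = 64, num_layers: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(MyTransformerBlock(dim) for _ in range(num_layers))

    def forward(self, hidden_states):
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states


model = MyUNet()
# Marks every MyTransformerBlock for compilation; the compiled code is shared across all four instances.
model.compile_repeated_blocks(fullgraph=True)

# Compilation actually happens lazily, on the first forward pass.
_ = model(torch.randn(2, 16, 64))
```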

**Relation to Accelerate compile_regions.** There is also a separate API in [Accelerate](https://huggingface.co/docs/accelerate/index), [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, and compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.



```py
# pip install -U accelerate
@@ -167,6 +197,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
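
Because every argument is forwarded to `torch.compile`, you also keep full control over the compilation flags. A short sketch, reusing `pipe` from the example above (the specific flag choices are illustrative, not a recommendation):

```py
# Arguments are forwarded verbatim to torch.compile, so the usual flags apply.
pipe.unet.compile_repeated_blocks(
    fullgraph=True,       # fail loudly on graph breaks instead of silently splitting the graph
    mode="max-autotune",  # spend extra compile time searching for faster kernels
    dynamic=False,        # assume static shapes inside the repeated blocks
)
```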


### Graph breaks

@@ -241,4 +273,4 @@ An input is projected into three subspaces, represented by the projection matrices

```py
pipeline.fuse_qkv_projections()
```
```
34 changes: 34 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
_repeated_blocks = []

def __init__(self):
super().__init__()
@@ -1404,6 +1405,39 @@ def float(self, *args):
else:
return super().float(*args)

def compile_repeated_blocks(self, *args, **kwargs):
"""
Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
compiling the entire model. This technique, often called **regional compilation** (see the PyTorch recipe at
https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), can reduce end-to-end compile time
substantially while preserving the runtime speed-ups you would expect from a full `torch.compile`.

The set of sub-modules to compile is discovered through the **`_repeated_blocks`** attribute on the model
class. Define this attribute on your model subclass as a list/tuple of class names (strings); every module
whose class name matches will be compiled.

Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
`torch.compile`.
"""
repeated_blocks = getattr(self, "_repeated_blocks", None)

if not repeated_blocks:
raise ValueError(
"`_repeated_blocks` attribute is empty. "
f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
)
has_compiled_region = False
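# Walk every sub-module and compile those whose class name appears in `_repeated_blocks`.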
for submod in self.modules():
if submod.__class__.__name__ in repeated_blocks:
submod.compile(*args, **kwargs)
has_compiled_region = True

if not has_compiled_region:
raise ValueError(
f"Regional compilation failed because none of the classes in {repeated_blocks} were found in the model."
)

@classmethod
def _load_pretrained_model(
cls,
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_chroma.py
@@ -407,6 +407,7 @@ class ChromaTransformer2DModel(

_supports_gradient_checkpointing = True
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
_repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]

@register_to_config
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

@register_to_config
def __init__(
@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
_repeated_blocks = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]

@register_to_config
def __init__(
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_ltx.py
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTXVideoTransformerBlock"]

@register_to_config
def __init__(
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]

@register_to_config
def __init__(
1 change: 1 addition & 0 deletions src/diffusers/models/unets/unet_2d_condition.py
@@ -167,6 +167,7 @@ class conditioning with `class_embed_type` equal to `None`.
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["BasicTransformerBlock"]

@register_to_config
def __init__(
21 changes: 21 additions & 0 deletions tests/models/test_modeling_common.py
@@ -1935,6 +1935,27 @@ def test_torch_compile_recompilation_and_graph_break(self):
_ = model(**inputs_dict)
_ = model(**inputs_dict)

def test_torch_compile_repeated_blocks(self):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")

init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict).to(torch_device)
model.compile_repeated_blocks(fullgraph=True)

recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
Review comment on lines +1948 to +1949 (Contributor Author): @sayakpaul For UNet, the shape changes between different instances of the transformer model, so we need 2 entries.


with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
torch.no_grad(),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)

def test_compile_with_group_offloading(self):
torch._dynamo.config.cache_size_limit = 10000
