Skip to content

Commit 2a92c01

Browse files
committed
[rfc][compile] compile method for DiffusionPipeline
1 parent 47ef794 commit 2a92c01

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
266266
_keep_in_fp32_modules = None
267267
_skip_layerwise_casting_patterns = None
268268
_supports_group_offloading = True
269+
_compile_blocks = []
269270

270271
def __init__(self):
271272
super().__init__()
@@ -1402,6 +1403,54 @@ def float(self, *args):
14021403
else:
14031404
return super().float(*args)
14041405

1406+
def compile_repeated_blocks(self, *args, **kwargs):
    """
    Compile *only* the frequently repeated sub-modules of a model (e.g. the
    Transformer layers) instead of compiling the entire model. This
    technique, often called **regional compilation** (see the PyTorch recipe
    https://docs.pytorch.org/tutorials/recipes/regional_compilation.html),
    is intended to reduce end-to-end compile time while keeping the runtime
    speed-ups expected from compiling the whole model.

    The set of sub-modules to compile is discovered in one of two ways:

    1. **`_compile_blocks`** - Preferred. Define this attribute on your
       subclass as a list/tuple of class names (strings). Every module
       yielded by `self.modules()` whose class name matches is compiled.

    2. **`_no_split_modules`** - Fallback. If the preferred attribute is
       missing or empty, a warning is logged and `_no_split_modules` is
       used instead.

    A single class name given as a bare string is treated as a one-element
    collection, so it is matched exactly rather than by substring.

    Args:
        *args: Positional arguments forwarded unchanged to each matching
            sub-module's `compile` method.
        **kwargs: Keyword arguments forwarded unchanged to each matching
            sub-module's `compile` method.

    Raises:
        ValueError: If both `_compile_blocks` and `_no_split_modules` are
            missing or empty, or if no sub-module's class name matches any
            of the configured names.
    """
    compile_blocks = getattr(self, "_compile_blocks", None)

    if not compile_blocks:
        logger.warning("_compile_blocks attribute is empty. Using _no_split_modules to find compile regions.")

        compile_blocks = getattr(self, "_no_split_modules", None)

    if not compile_blocks:
        raise ValueError(
            "Both _compile_blocks and _no_split_modules attribute are empty. "
            "Set _compile_blocks for the model to benefit from faster compilation. "
        )

    # `name in "SomeClassName"` is a substring test, so a bare string would
    # silently match unrelated classes (e.g. "Block" in "FluxTransformerBlock").
    # Wrap it so matching is always by exact class name.
    block_names = (compile_blocks,) if isinstance(compile_blocks, str) else compile_blocks

    has_compiled_region = False
    for submod in self.modules():
        if submod.__class__.__name__ in block_names:
            has_compiled_region = True
            submod.compile(*args, **kwargs)

    # Fail loudly instead of returning a model that was never compiled.
    if not has_compiled_region:
        raise ValueError(
            f"Regional compilation failed because {compile_blocks} classes are not found in the model. "
        )
1453+
14051454
@classmethod
14061455
def _load_pretrained_model(
14071456
cls,

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
227227
_supports_gradient_checkpointing = True
228228
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
229229
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
230+
_compile_blocks = _no_split_modules
230231

231232
@register_to_config
232233
def __init__(

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
345345
_no_split_modules = ["WanTransformerBlock"]
346346
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
347347
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
348+
_compile_blocks = _no_split_modules
348349

349350
@register_to_config
350351
def __init__(

0 commit comments

Comments
 (0)