Skip to content

Commit f794d66

Browse files
committed
[rfc][compile] compile method for DiffusionPipeline
1 parent 7bc0a07 commit f794d66

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
266266
_keep_in_fp32_modules = None
267267
_skip_layerwise_casting_patterns = None
268268
_supports_group_offloading = True
269+
_repeated_blocks = []
269270

270271
def __init__(self):
271272
super().__init__()
@@ -1404,6 +1405,54 @@ def float(self, *args):
14041405
else:
14051406
return super().float(*args)
14061407

1408+
def compile_repeated_blocks(self, *args, **kwargs):
1409+
"""
1410+
Compiles *only* the frequently repeated sub-modules of a model (e.g. the
1411+
Transformer layers) instead of compiling the entire model. This
1412+
technique—often called **regional compilation** (see the PyTorch recipe
1413+
https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
1414+
can reduce end-to-end compile time substantially, while preserving the
1415+
runtime speed-ups you would expect from a full `torch.compile`.
1416+
1417+
The set of sub-modules to compile is discovered in one of two ways:
1418+
1419+
1. **`_repeated_blocks`** – Preferred. Define this attribute on your
1420+
subclass as a list/tuple of class names (strings). Every module whose
1421+
class name matches will be compiled.
1422+
1423+
2. **`_no_split_modules`** – Fallback. If the preferred attribute is
1424+
missing or empty, we fall back to the legacy Diffusers attribute
1425+
`_no_split_modules`.
1426+
1427+
Once discovered, each matching sub-module is compiled by calling
1428+
``submodule.compile(*args, **kwargs)``. Any positional or keyword
1429+
arguments you supply to :meth:`compile_repeated_blocks` are forwarded
1430+
verbatim to `torch.compile`.
1431+
"""
1432+
repeated_blocks = getattr(self, "_repeated_blocks", None)
1433+
1434+
if not repeated_blocks:
1435+
logger.warning("_repeated_blocks attribute is empty. Using _no_split_modules to find compile regions.")
1436+
1437+
repeated_blocks = getattr(self, "_no_split_modules", None)
1438+
1439+
if not repeated_blocks:
1440+
raise ValueError(
1441+
"Both _repeated_blocks and _no_split_modules attribute are empty. "
1442+
"Set _repeated_blocks for the model to benefit from faster compilation. "
1443+
)
1444+
1445+
has_compiled_region = False
1446+
for submod in self.modules():
1447+
if submod.__class__.__name__ in repeated_blocks:
1448+
has_compiled_region = True
1449+
submod.compile(*args, **kwargs)
1450+
1451+
if not has_compiled_region:
1452+
raise ValueError(
1453+
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
1454+
)
1455+
14071456
@classmethod
14081457
def _load_pretrained_model(
14091458
cls,

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
227227
_supports_gradient_checkpointing = True
228228
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
229229
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
230+
_repeated_blocks = _no_split_modules
230231

231232
@register_to_config
232233
def __init__(

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
345345
_no_split_modules = ["WanTransformerBlock"]
346346
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
347347
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
348+
_repeated_blocks = _no_split_modules
348349

349350
@register_to_config
350351
def __init__(

0 commit comments

Comments
 (0)