@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
266266 _keep_in_fp32_modules = None
267267 _skip_layerwise_casting_patterns = None
268268 _supports_group_offloading = True
269+ _compile_blocks = []
269270
270271 def __init__ (self ):
271272 super ().__init__ ()
@@ -1402,6 +1403,54 @@ def float(self, *args):
14021403 else :
14031404 return super ().float (* args )
14041405
1406+ def compile_repeated_blocks (self , * args , ** kwargs ):
1407+ """
1408+ Compiles *only* the frequently repeated sub-modules of a model (e.g. the
1409+ Transformer layers) instead of compiling the entire model. This
1410+ technique—often called **regional compilation** (see the PyTorch recipe
1411+ https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
1412+ can reduce end-to-end compile time substantially, while preserving the
1413+ runtime speed-ups you would expect from a full `torch.compile`.
1414+
1415+ The set of sub-modules to compile is discovered in one of two ways:
1416+
1417+ 1. **`_compile_blocks`** – Preferred. Define this attribute on your
1418+ subclass as a list/tuple of class names (strings). Every module whose
1419+ class name matches will be compiled.
1420+
1421+ 2. **`_no_split_modules`** – Fallback. If the preferred attribute is
1422+ missing or empty, we fall back to the legacy Diffusers attribute
1423+ `_no_split_modules`.
1424+
1425+ Once discovered, each matching sub-module is compiled by calling
1426+ ``submodule.compile(*args, **kwargs)``. Any positional or keyword
1427+ arguments you supply to :meth:`compile_repeated_blocks` are forwarded
1428+ verbatim to `torch.compile`.
1429+ """
1430+ compile_blocks = getattr (self , "_compile_blocks" , None )
1431+
1432+ if not compile_blocks :
1433+ logger .warning ("_compile_blocks attribute is empty. Using _no_split_modules to find compile regions." )
1434+
1435+ compile_blocks = getattr (self , "_no_split_modules" , None )
1436+
1437+ if not compile_blocks :
1438+ raise ValueError (
1439+ "Both _compile_blocks and _no_split_modules attribute are empty. "
1440+ "Set _compile_blocks for the model to benefit from faster compilation. "
1441+ )
1442+
1443+ has_compiled_region = False
1444+ for submod in self .modules ():
1445+ if submod .__class__ .__name__ in compile_blocks :
1446+ has_compiled_region = True
1447+ submod .compile (* args , ** kwargs )
1448+
1449+ if not has_compiled_region :
1450+ raise ValueError (
1451+ f"Regional compilation failed because { compile_blocks } classes are not found in the model. "
1452+ )
1453+
14051454 @classmethod
14061455 def _load_pretrained_model (
14071456 cls ,
0 commit comments