@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
266266 _keep_in_fp32_modules = None
267267 _skip_layerwise_casting_patterns = None
268268 _supports_group_offloading = True
269+ _repeated_blocks = []
269270
270271 def __init__ (self ):
271272 super ().__init__ ()
@@ -1404,6 +1405,54 @@ def float(self, *args):
14041405 else :
14051406 return super ().float (* args )
14061407
def compile_repeated_blocks(self, *args, **kwargs):
    """
    Compile *only* the frequently repeated sub-modules of a model (e.g. the
    Transformer layers) instead of compiling the entire model. This
    technique, often called **regional compilation** (see the PyTorch recipe
    https://docs.pytorch.org/tutorials/recipes/regional_compilation.html),
    is intended to reduce end-to-end compile time while keeping the runtime
    speed-ups expected from a full `torch.compile`.

    The set of sub-modules to compile is discovered in one of two ways:

    1. **`_repeated_blocks`** - Preferred. Define this attribute on your
       subclass as a list/tuple of class names (strings). Every module whose
       class name matches exactly will be compiled. A single bare string is
       treated as one class name.

    2. **`_no_split_modules`** - Fallback. If the preferred attribute is
       missing or empty, a warning is logged and `_no_split_modules` is used
       instead.

    Every module yielded by ``self.modules()`` whose class name matches is
    compiled by calling ``submodule.compile(*args, **kwargs)``.

    Args:
        *args: Positional arguments forwarded unchanged to each matching
            sub-module's ``compile`` method.
        **kwargs: Keyword arguments forwarded unchanged to each matching
            sub-module's ``compile`` method.

    Raises:
        ValueError: If both `_repeated_blocks` and `_no_split_modules` are
            missing or empty, or if no module in the model has a class name
            listed in the selected attribute.
    """
    repeated_blocks = getattr(self, "_repeated_blocks", None)

    if not repeated_blocks:
        logger.warning("_repeated_blocks attribute is empty. Using _no_split_modules to find compile regions.")

        repeated_blocks = getattr(self, "_no_split_modules", None)

    if not repeated_blocks:
        raise ValueError(
            "Both _repeated_blocks and _no_split_modules attribute are empty. "
            "Set _repeated_blocks for the model to benefit from faster compilation. "
        )

    # A bare string would turn the membership test below into a substring
    # search (e.g. "Block" in "TransformerBlock"), silently compiling modules
    # that were never requested. Normalise to a tuple of whole class names.
    if isinstance(repeated_blocks, str):
        block_names = (repeated_blocks,)
    else:
        block_names = tuple(repeated_blocks)

    has_compiled_region = False
    for submod in self.modules():
        if submod.__class__.__name__ in block_names:
            has_compiled_region = True
            submod.compile(*args, **kwargs)

    if not has_compiled_region:
        raise ValueError(
            f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
        )
1455+
14071456 @classmethod
14081457 def _load_pretrained_model (
14091458 cls ,
0 commit comments