File tree Expand file tree Collapse file tree 3 files changed +37
-0
lines changed Expand file tree Collapse file tree 3 files changed +37
-0
lines changed Original file line number Diff line number Diff line change @@ -286,6 +286,8 @@ def __init__(
286286
287287 self .gradient_checkpointing = False
288288
289+ self .compile_region_classes = (FluxTransformerBlock , FluxSingleTransformerBlock )
290+
289291 @property
290292 # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
291293 def attn_processors (self ) -> Dict [str , AttentionProcessor ]:
Original file line number Diff line number Diff line change @@ -403,6 +403,8 @@ def __init__(
403403
404404 self .gradient_checkpointing = False
405405
406+ self .compile_region_classes = (WanTransformerBlock ,)
407+
406408 def forward (
407409 self ,
408410 hidden_states : torch .Tensor ,
Original file line number Diff line number Diff line change @@ -2027,6 +2027,39 @@ def _maybe_raise_error_if_group_offload_active(
20272027 return True
20282028 return False
20292029
2030+ def compile (
2031+ self ,
2032+ compile_regions_for_transformer : bool = True ,
2033+ transformer_module_name : str = "transformer" ,
2034+ other_modules_names : List [str ] = [],
2035+ ** compile_kwargs ,
2036+ ):
2037+ transformer = getattr (self , transformer_module_name , None )
2038+ if transformer is None :
2039+ raise ValueError (
2040+ f"{ transformer_module_name } not found in the pipeline. Set `transformer_module_name` to the correct module name."
2041+ )
2042+
2043+ if compile_regions_for_transformer :
2044+ compile_region_classes = getattr (transformer , "compile_region_classes" , None )
2045+ if compile_region_classes is None :
2046+ raise ValueError (
2047+ f"{ transformer_module_name } does not have `compile_region_classes` attribute. Set `compile_regions_for_transformer` to False."
2048+ )
2049+
2050+ for submod in transformer .modules ():
2051+ if isinstance (submod , compile_region_classes ):
2052+ submod .compile (** compile_kwargs )
2053+ else :
2054+ transformer .compile (** compile_kwargs )
2055+
2056+ for module_name in other_modules_names :
2057+ module = getattr (self , module_name , None )
2058+ if module is None :
2059+ raise ValueError (
2060+ f"{ module_name } not found in the pipeline. Set `other_modules_names` to the correct module names."
2061+ )
2062+
20302063
20312064class StableDiffusionMixin :
20322065 r"""
You can’t perform that action at this time.
0 commit comments