Skip to content

Commit 83ba712

Browse files
sayakpaul and anijain2305
authored and committed
Merge branch 'main' into compile_utils
2 parents 75e665b + f20b83a commit 83ba712

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

43 files changed

+171
-102
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,15 +1414,10 @@ def compile_repeated_blocks(self, *args, **kwargs):
14141414
can reduce end-to-end compile time substantially, while preserving the
14151415
runtime speed-ups you would expect from a full `torch.compile`.
14161416
1417-
The set of sub-modules to compile is discovered in one of two ways:
1418-
1419-
1. **`_repeated_blocks`** – Preferred. Define this attribute on your
1420-
subclass as a list/tuple of class names (strings). Every module whose
1421-
class name matches will be compiled.
1422-
1423-
2. **`_no_split_modules`** – Fallback. If the preferred attribute is
1424-
missing or empty, we fall back to the legacy Diffusers attribute
1425-
`_no_split_modules`.
1417+
The set of sub-modules to compile is discovered by the presence of
1418+
**`_repeated_blocks`** attribute in the model definition. Define this
1419+
attribute on your model subclass as a list/tuple of class names
1420+
(strings). Every module whose class name matches will be compiled.
14261421
14271422
Once discovered, each matching sub-module is compiled by calling
14281423
`submodule.compile(*args, **kwargs)`. Any positional or keyword
@@ -1431,22 +1426,16 @@ class name matches will be compiled.
14311426
"""
14321427
repeated_blocks = getattr(self, "_repeated_blocks", None)
14331428

1434-
if not repeated_blocks:
1435-
logger.warning("`_repeated_blocks` attribute is empty. Using `_no_split_modules` to find compile regions.")
1436-
1437-
repeated_blocks = getattr(self, "_no_split_modules", None)
1438-
14391429
if not repeated_blocks:
14401430
raise ValueError(
1441-
"Both `_repeated_blocks` and `_no_split_modules` attribute are empty. "
1442-
"Set `_repeated_blocks` for the model to benefit from faster compilation. "
1431+
"`_repeated_blocks` attribute is empty. "
1432+
f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
14431433
)
1444-
14451434
has_compiled_region = False
14461435
for submod in self.modules():
14471436
if submod.__class__.__name__ in repeated_blocks:
1448-
has_compiled_region = True
14491437
submod.compile(*args, **kwargs)
1438+
has_compiled_region = True
14501439

14511440
if not has_compiled_region:
14521441
raise ValueError(

src/diffusers/models/transformers/transformer_chroma.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ class ChromaTransformer2DModel(
407407

408408
_supports_gradient_checkpointing = True
409409
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
410+
_repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
410411
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
411412

412413
@register_to_config

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ class FluxTransformer2DModel(
227227
_supports_gradient_checkpointing = True
228228
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
229229
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
230-
_repeated_blocks = _no_split_modules
230+
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
231231

232232
@register_to_config
233233
def __init__(

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
870870
"HunyuanVideoPatchEmbed",
871871
"HunyuanVideoTokenRefiner",
872872
]
873+
_repeated_blocks = [
874+
"HunyuanVideoTransformerBlock",
875+
"HunyuanVideoSingleTransformerBlock",
876+
"HunyuanVideoPatchEmbed",
877+
"HunyuanVideoTokenRefiner",
878+
]
873879

874880
@register_to_config
875881
def __init__(

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
328328

329329
_supports_gradient_checkpointing = True
330330
_skip_layerwise_casting_patterns = ["norm"]
331+
_repeated_blocks = ["LTXVideoTransformerBlock"]
331332

332333
@register_to_config
333334
def __init__(

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
345345
_no_split_modules = ["WanTransformerBlock"]
346346
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
347347
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
348-
_repeated_blocks = _no_split_modules
348+
_repeated_blocks = ["WanTransformerBlock"]
349349

350350
@register_to_config
351351
def __init__(

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ class conditioning with `class_embed_type` equal to `None`.
167167
_supports_gradient_checkpointing = True
168168
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
169169
_skip_layerwise_casting_patterns = ["norm"]
170+
_repeated_blocks = ["BasicTransformerBlock"]
170171

171172
@register_to_config
172173
def __init__(

src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
replace_example_docstring,
4242
)
4343
from ...utils.import_utils import is_transformers_version
44-
from ...utils.torch_utils import randn_tensor
44+
from ...utils.torch_utils import empty_device_cache, randn_tensor
4545
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
4646
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
4747

@@ -267,9 +267,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
267267

268268
if self.device.type != "cpu":
269269
self.to("cpu", silence_dtype_warnings=True)
270-
device_mod = getattr(torch, device.type, None)
271-
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
272-
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
270+
empty_device_cache(device.type)
273271

274272
model_sequence = [
275273
self.text_encoder.text_model,

src/diffusers/pipelines/consisid/consisid_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype):
294294
295295
Parameters:
296296
- model_path: Path to the directory containing model files.
297-
- device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
297+
- device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
298298
- dtype: Data type (e.g., torch.float32) for model inference.
299299
300300
Returns:

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
scale_lora_layers,
3838
unscale_lora_layers,
3939
)
40-
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
40+
from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
4141
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
4242
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
4343
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -1339,7 +1339,7 @@ def __call__(
13391339
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
13401340
self.unet.to("cpu")
13411341
self.controlnet.to("cpu")
1342-
torch.cuda.empty_cache()
1342+
empty_device_cache()
13431343

13441344
if not output_type == "latent":
13451345
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[

0 commit comments

Comments (0)