From 0871dc6edadd1f4ddd6d674252cbf344d7dffe9f Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 21 Dec 2024 02:00:06 +0100 Subject: [PATCH 01/12] update --- scripts/convert_ltx_to_diffusers.py | 158 +++++++++++++++-- .../models/autoencoders/autoencoder_kl_ltx.py | 162 ++++++++++++++++-- 2 files changed, 292 insertions(+), 28 deletions(-) diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index f4398a2e687c..3a9e861f4ee6 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -1,7 +1,9 @@ import argparse from typing import Any, Dict +from pathlib import Path import torch +from accelerate import init_empty_weights from safetensors.torch import load_file from transformers import T5EncoderModel, T5Tokenizer @@ -21,7 +23,9 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): "k_norm": "norm_k", } -TRANSFORMER_SPECIAL_KEYS_REMAP = {} +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "vae": remove_keys_, +} VAE_KEYS_RENAME_DICT = { # decoder @@ -54,10 +58,33 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): "per_channel_statistics.std-of-means": "latents_std", } +VAE_091_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # common + "per_channel_scale1": "scale1", + "per_channel_scale2": "scale2", + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", +} + VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.channel": remove_keys_, "per_channel_statistics.mean-of-means": remove_keys_, "per_channel_statistics.mean-of-stds": remove_keys_, + "model.diffusion_model": remove_keys_, +} + +VAE_091_SPECIAL_KEYS_REMAP = { + "timestep_scale_multiplier": remove_keys_, } @@ -80,13 +107,16 @@ def convert_transformer( ckpt_path: str, dtype: torch.dtype, ): - PREFIX_KEY = "" + PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(load_file(ckpt_path)) - transformer = LTXVideoTransformer3DModel().to(dtype=dtype) + with init_empty_weights(): + transformer = LTXVideoTransformer3DModel() for key in list(original_state_dict.keys()): - new_key = key[len(PREFIX_KEY) :] + new_key = key[:] + if new_key.startswith(PREFIX_KEY): + new_key = key[len(PREFIX_KEY) :] for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) update_state_dict_inplace(original_state_dict, key, new_key) @@ -97,16 +127,21 @@ def convert_transformer( continue handler_fn_inplace(key, original_state_dict) - transformer.load_state_dict(original_state_dict, strict=True) + transformer.load_state_dict(original_state_dict, strict=True, assign=True) return transformer -def convert_vae(ckpt_path: str, dtype: torch.dtype): +def convert_vae(ckpt_path: str, config, dtype: torch.dtype): + PREFIX_KEY = "vae." 
+ original_state_dict = get_state_dict(load_file(ckpt_path)) - vae = AutoencoderKLLTXVideo().to(dtype=dtype) + with init_empty_weights(): + vae = AutoencoderKLLTXVideo(**config) for key in list(original_state_dict.keys()): new_key = key[:] + if new_key.startswith(PREFIX_KEY): + new_key = key[len(PREFIX_KEY) :] for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) update_state_dict_inplace(original_state_dict, key, new_key) @@ -117,9 +152,107 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype): continue handler_fn_inplace(key, original_state_dict) - vae.load_state_dict(original_state_dict, strict=True) + vae.load_state_dict(original_state_dict, strict=True, assign=True) return vae +# OURS_VAE_CONFIG = { +# "_class_name": "CausalVideoAutoencoder", +# "dims": 3, +# "in_channels": 3, +# "out_channels": 3, +# "latent_channels": 128, +# "blocks": [ +# ["res_x", 4], +# ["compress_all", 1], +# ["res_x_y", 1], +# ["res_x", 3], +# ["compress_all", 1], +# ["res_x_y", 1], +# ["res_x", 3], +# ["compress_all", 1], +# ["res_x", 3], +# ["res_x", 4], +# ], +# "scaling_factor": 1.0, +# "norm_layer": "pixel_norm", +# "patch_size": 4, +# "latent_log_var": "uniform", +# "use_quant_conv": False, +# "causal_decoder": False, +# } + +# { +# "_class_name": "CausalVideoAutoencoder", +# "dims": 3, "in_channels": 3, "out_channels": 3, "latent_channels": 128, +# "encoder_blocks": [["res_x", {"num_layers": 4}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x", {"num_layers": 3}], ["res_x", {"num_layers": 4}]], + +# previous decoder +# mid: resx +# resx +# compress_all, resx +# resxy, compress_all, resx +# resxy, compress_all, resx + +# "decoder_blocks": [["res_x", {"num_layers": 5, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 6, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 7, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 8, "inject_noise": false}]], + +# current decoder +# mid: resx +# compress_all, resx +# compress_all, resx +# compress_all, resx + +# "scaling_factor": 1.0, "norm_layer": "pixel_norm", "patch_size": 4, "latent_log_var": "uniform", "use_quant_conv": false, "causal_decoder": false, "timestep_conditioning": true +# } + +def get_vae_config(version: str) -> Dict[str, Any]: + if version == "0.9.0": + config = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (128, 256, 512, 512), + "decoder_block_out_channels": (128, 256, 512, 512), + "layers_per_block": (4, 3, 3, 3, 4), + "decoder_layers_per_block": (4, 3, 3, 3, 4), + "spatio_temporal_scaling": (True, True, True, False), + "decoder_spatio_temporal_scaling": (True, True, True, False), + "decoder_inject_noise": (False, False, False, False), + "upsample_residual": (False, False, False, False), + "upsample_factor": (1, 1, 1, 1), + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "scaling_factor": 1.0, + "encoder_causal": True, + "decoder_causal": False, + "timestep_conditioning": False, + } + elif version == "0.9.1": + config = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (128, 256, 512, 512), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 3, 3, 3, 4), + 
"decoder_layers_per_block": (5, 6, 7, 8), + "spatio_temporal_scaling": (True, True, True, False), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, True, True, True), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": True, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "scaling_factor": 1.0, + "encoder_causal": True, + "decoder_causal": False, + } + VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) + VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP) + return config + def get_args(): parser = argparse.ArgumentParser() @@ -139,6 +272,7 @@ def get_args(): parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") + parser.add_argument("--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model") return parser.parse_args() @@ -161,6 +295,7 @@ def get_args(): transformer = None dtype = DTYPE_MAPPING[args.dtype] variant = VARIANT_MAPPING[args.dtype] + output_path = Path(args.output_path) if args.save_pipeline: assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None @@ -169,13 +304,14 @@ def get_args(): transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) if not args.save_pipeline: transformer.save_pretrained( - args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant + output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant ) if args.vae_ckpt_path is not None: - vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype) + config = get_vae_config(args.version) + vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype) if not args.save_pipeline: - vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant) + vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant) if args.save_pipeline: text_encoder_id = "google/t5-v1_1-xxl" diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index ff202b980b95..602e7886aa7d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -22,6 +22,7 @@ from ...loaders import FromOriginalModelMixin from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from ..normalization import RMSNorm @@ -109,7 +110,9 @@ def __init__( elementwise_affine: bool = False, non_linearity: str = "swish", is_causal: bool = True, - ): + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: super().__init__() out_channels = out_channels or in_channels @@ -134,19 +137,45 @@ def __init__( self.conv_shortcut = LTXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal ) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + + self.scale1 = None + self.scale2 = None + if inject_noise: + self.scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.scale2 = 
nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward(self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = inputs hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + scale_1, shift_1, scale_2, shift_2 = self.scale_shift_table.unbind(dim=0) + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) + if self.scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn(spatial_shape, device=hidden_states.device, dtype=hidden_states.dtype) + hidden_states = hidden_states + (spatial_noise * self.scale1)[None, :, None, :, :] + hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_1) + shift_1 + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) + if self.scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn(spatial_shape, device=hidden_states.device, dtype=hidden_states.dtype) + hidden_states = hidden_states + (spatial_noise * self.scale2)[None, :, None, :, :] + if self.norm3 is not None: inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) @@ -163,12 +192,16 @@ def __init__( in_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, is_causal: bool = True, + residual: bool = False, + upscale_factor: int = 1, ) -> None: super().__init__() self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor - out_channels = in_channels * stride[0] * stride[1] * stride[2] + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor self.conv = LTXCausalConv3d( in_channels=in_channels, @@ -178,9 +211,19 @@ def __init__( is_causal=is_causal, ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + hidden_states = self.conv(hidden_states) hidden_states = hidden_states.reshape( batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width @@ -188,6 +231,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + if self.residual: + hidden_states = hidden_states + residual + return hidden_states @@ -329,9 +375,15 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, ) -> None: super().__init__() + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + resnets = [] for _ in 
range(num_layers): resnets.append( @@ -342,15 +394,27 @@ def __init__( eps=resnet_eps, non_linearity=resnet_act_fn, is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, ) ) self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: r"""Forward method of the `LTXMidBlock3D` class.""" + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -360,9 +424,9 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) else: - hidden_states = resnet(hidden_states) + hidden_states = resnet(hidden_states, temb) return hidden_states @@ -403,11 +467,19 @@ def __init__( resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, ): super().__init__() out_channels = out_channels or in_channels + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + self.conv_in = None if in_channels != out_channels: self.conv_in = LTXResnetBlock3d( @@ -417,11 +489,13 @@ def __init__( eps=resnet_eps, non_linearity=resnet_act_fn, is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, ) self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)]) + self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels * upscale_factor, stride=(2, 2, 2), is_causal=is_causal, residual=upsample_residual, upscale_factor=upscale_factor)]) resnets = [] for _ in range(num_layers): @@ -433,15 +507,27 @@ def __init__( eps=resnet_eps, non_linearity=resnet_act_fn, is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning ) ) self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: if self.conv_in is not None: hidden_states = self.conv_in(hidden_states) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -622,6 +708,8 @@ class LTXDecoder3d(nn.Module): Epsilon value for ResNet normalization layers. is_causal (`bool`, defaults to `False`): Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. 
""" def __init__( @@ -635,6 +723,10 @@ def __init__( patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, is_causal: bool = False, + inject_noise: Tuple[bool, ...] = (False, False, False, False), + timestep_conditioning: bool = False, + upsample_residual: Tuple[bool, ...] = (False, False, False, False), + upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1), ) -> None: super().__init__() @@ -652,15 +744,15 @@ def __init__( ) self.mid_block = LTXMidBlock3d( - in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal + in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal, inject_noise=inject_noise[0], timestep_conditioning=timestep_conditioning ) # up blocks num_block_out_channels = len(block_out_channels) self.up_blocks = nn.ModuleList([]) for i in range(num_block_out_channels): - input_channel = output_channel - output_channel = block_out_channels[i] + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] up_block = LTXUpBlock3d( in_channels=input_channel, @@ -669,6 +761,10 @@ def __init__( resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], is_causal=is_causal, + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], ) self.up_blocks.append(up_block) @@ -680,9 +776,16 @@ def __init__( in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal ) + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -704,6 +807,20 @@ def create_forward(*inputs): hidden_states = up_block(hidden_states) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.time_embedder is not None: + embedded_timestep = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + embedded_timestep = embedded_timestep.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + embedded_timestep = embedded_timestep + self.scale_shift_table[None, :, None, None, None] + shift, scale = embedded_timestep.unbind(dim=1) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states) @@ -766,8 +883,15 @@ def __init__( out_channels: int = 3, latent_channels: int = 128, block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), - spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + decoder_spatio_temporal_scaling: Tuple[bool, ...] 
= (True, True, True, False), + decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False), + upsample_residual: Tuple[bool, ...] = (False, False, False, False), + upsample_factor: Tuple[int, ...] = (1, 1, 1, 1), + timestep_conditioning: bool = False, patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, @@ -791,13 +915,17 @@ def __init__( self.decoder = LTXDecoder3d( in_channels=latent_channels, out_channels=out_channels, - block_out_channels=block_out_channels, - spatio_temporal_scaling=spatio_temporal_scaling, - layers_per_block=layers_per_block, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, ) latents_mean = torch.zeros((latent_channels,), requires_grad=False) From 58a51aa5e3060d587ba6746909a50bd698adf2d6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 21 Dec 2024 02:01:04 +0100 Subject: [PATCH 02/12] make style --- scripts/convert_ltx_to_diffusers.py | 56 ++----------------- .../models/autoencoders/autoencoder_kl_ltx.py | 30 +++++++--- 2 files changed, 27 insertions(+), 59 deletions(-) diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 3a9e861f4ee6..65ee4054af8f 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -1,6 +1,6 @@ import argparse -from typing import Any, Dict from pathlib import Path +from typing import Any, Dict import torch from accelerate import init_empty_weights @@ -133,7 +133,7 @@ def convert_transformer( def convert_vae(ckpt_path: str, config, dtype: torch.dtype): PREFIX_KEY = "vae." 
- + original_state_dict = get_state_dict(load_file(ckpt_path)) with init_empty_weights(): vae = AutoencoderKLLTXVideo(**config) @@ -155,54 +155,6 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype): vae.load_state_dict(original_state_dict, strict=True, assign=True) return vae -# OURS_VAE_CONFIG = { -# "_class_name": "CausalVideoAutoencoder", -# "dims": 3, -# "in_channels": 3, -# "out_channels": 3, -# "latent_channels": 128, -# "blocks": [ -# ["res_x", 4], -# ["compress_all", 1], -# ["res_x_y", 1], -# ["res_x", 3], -# ["compress_all", 1], -# ["res_x_y", 1], -# ["res_x", 3], -# ["compress_all", 1], -# ["res_x", 3], -# ["res_x", 4], -# ], -# "scaling_factor": 1.0, -# "norm_layer": "pixel_norm", -# "patch_size": 4, -# "latent_log_var": "uniform", -# "use_quant_conv": False, -# "causal_decoder": False, -# } - -# { -# "_class_name": "CausalVideoAutoencoder", -# "dims": 3, "in_channels": 3, "out_channels": 3, "latent_channels": 128, -# "encoder_blocks": [["res_x", {"num_layers": 4}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x_y", 1], ["res_x", {"num_layers": 3}], ["compress_all", {}], ["res_x", {"num_layers": 3}], ["res_x", {"num_layers": 4}]], - -# previous decoder -# mid: resx -# resx -# compress_all, resx -# resxy, compress_all, resx -# resxy, compress_all, resx - -# "decoder_blocks": [["res_x", {"num_layers": 5, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 6, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 7, "inject_noise": true}], ["compress_all", {"residual": true, "multiplier": 2}], ["res_x", {"num_layers": 8, "inject_noise": false}]], - -# current decoder -# mid: resx -# compress_all, resx -# compress_all, resx -# compress_all, resx - -# "scaling_factor": 1.0, "norm_layer": "pixel_norm", "patch_size": 4, "latent_log_var": "uniform", "use_quant_conv": false, "causal_decoder": false, "timestep_conditioning": true -# } def get_vae_config(version: str) -> Dict[str, Any]: if version == "0.9.0": @@ -272,7 +224,9 @@ def get_args(): parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") - parser.add_argument("--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model") + parser.add_argument( + "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model" + ) return parser.parse_args() diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 602e7886aa7d..fdea7a622568 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -137,13 +137,13 @@ def __init__( self.conv_shortcut = LTXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal ) - + self.scale1 = None self.scale2 = None if inject_noise: self.scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) self.scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) - + self.scale_shift_table = None if timestep_conditioning: self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) @@ -166,7 +166,7 @@ def forward(self, inputs: torch.Tensor, temb: 
Optional[torch.Tensor] = None) -> if self.scale_shift_table is not None: hidden_states = hidden_states * (1 + scale_1) + shift_1 - + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) @@ -211,7 +211,6 @@ def __init__( is_causal=is_causal, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape @@ -495,7 +494,17 @@ def __init__( self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels * upscale_factor, stride=(2, 2, 2), is_causal=is_causal, residual=upsample_residual, upscale_factor=upscale_factor)]) + self.upsamplers = nn.ModuleList( + [ + LTXUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + is_causal=is_causal, + residual=upsample_residual, + upscale_factor=upscale_factor, + ) + ] + ) resnets = [] for _ in range(num_layers): @@ -508,7 +517,7 @@ def __init__( non_linearity=resnet_act_fn, is_causal=is_causal, inject_noise=inject_noise, - timestep_conditioning=timestep_conditioning + timestep_conditioning=timestep_conditioning, ) ) self.resnets = nn.ModuleList(resnets) @@ -518,7 +527,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: if self.conv_in is not None: hidden_states = self.conv_in(hidden_states) - + if self.time_embedder is not None: temb = self.time_embedder( timestep=temb.flatten(), @@ -744,7 +753,12 @@ def __init__( ) self.mid_block = LTXMidBlock3d( - in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal, inject_noise=inject_noise[0], timestep_conditioning=timestep_conditioning + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, ) # up blocks From 5316f4b809bca719908cfd6cada9f4d5773395a9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 21 Dec 2024 04:41:01 +0100 Subject: [PATCH 03/12] update --- scripts/convert_ltx_to_diffusers.py | 6 +- .../models/autoencoders/autoencoder_kl_ltx.py | 134 +++++++++----- src/diffusers/pipelines/ltx/pipeline_ltx.py | 22 ++- .../test_models_autoencoder_ltx_video.py | 169 ++++++++++++++++++ tests/pipelines/ltx/test_ltx.py | 11 +- 5 files changed, 294 insertions(+), 48 deletions(-) create mode 100644 tests/models/autoencoders/test_models_autoencoder_ltx_video.py diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index 65ee4054af8f..7df0745fd98c 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -70,8 +70,6 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): "up_blocks.7": "up_blocks.3.upsamplers.0", "up_blocks.8": "up_blocks.3", # common - "per_channel_scale1": "scale1", - "per_channel_scale2": "scale2", "last_time_embedder": "time_embedder", "last_scale_shift_table": "scale_shift_table", } @@ -168,7 +166,7 @@ def get_vae_config(version: str) -> Dict[str, Any]: "decoder_layers_per_block": (4, 3, 3, 3, 4), "spatio_temporal_scaling": (True, True, True, False), "decoder_spatio_temporal_scaling": (True, True, True, False), - "decoder_inject_noise": (False, False, False, False), + "decoder_inject_noise": (False, False, False, False, False), "upsample_residual": (False, False, False, False), "upsample_factor": (1, 1, 1, 1), "patch_size": 4, @@ -190,7 +188,7 @@ def 
get_vae_config(version: str) -> Dict[str, Any]: "decoder_layers_per_block": (5, 6, 7, 8), "spatio_temporal_scaling": (True, True, True, False), "decoder_spatio_temporal_scaling": (True, True, True), - "decoder_inject_noise": (False, True, True, True), + "decoder_inject_noise": (True, True, True, False), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), "timestep_conditioning": True, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index fdea7a622568..d6cc60dd24c3 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -138,43 +138,53 @@ def __init__( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal ) - self.scale1 = None - self.scale2 = None + self.per_channel_scale1 = None + self.per_channel_scale2 = None if inject_noise: - self.scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) - self.scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) self.scale_shift_table = None if timestep_conditioning: self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) - def forward(self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + ) -> torch.Tensor: hidden_states = inputs hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) - scale_1, shift_1, scale_2, shift_2 = self.scale_shift_table.unbind(dim=0) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) - if self.scale1 is not None: + if self.per_channel_scale1 is not None: spatial_shape = hidden_states.shape[-2:] - spatial_noise = torch.randn(spatial_shape, device=hidden_states.device, dtype=hidden_states.dtype) - hidden_states = hidden_states + (spatial_noise * self.scale1)[None, :, None, :, :] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + ) + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, :, :] hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) if self.scale_shift_table is not None: - hidden_states = hidden_states * (1 + scale_1) + shift_1 + hidden_states = hidden_states * (1 + scale_2) + shift_2 hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) - if self.scale2 is not None: + if self.per_channel_scale2 is not None: spatial_shape = hidden_states.shape[-2:] - spatial_noise = torch.randn(spatial_shape, device=hidden_states.device, dtype=hidden_states.dtype) - hidden_states = hidden_states + (spatial_noise * self.scale2)[None, :, None, :, :] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + ) + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, :, :] if self.norm3 is not 
None: inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) @@ -318,7 +328,12 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: r"""Forward method of the `LTXDownBlock3D` class.""" for i, resnet in enumerate(self.resnets): @@ -330,16 +345,18 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, generator + ) else: - hidden_states = resnet(hidden_states) + hidden_states = resnet(hidden_states, temb, generator) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) if self.conv_out is not None: - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, temb, generator) return hidden_states @@ -401,7 +418,12 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: r"""Forward method of the `LTXMidBlock3D` class.""" if self.time_embedder is not None: @@ -423,9 +445,11 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, generator + ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, generator) return hidden_states @@ -524,9 +548,14 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: if self.conv_in is not None: - hidden_states = self.conv_in(hidden_states) + hidden_states = self.conv_in(hidden_states, temb, generator) if self.time_embedder is not None: temb = self.time_embedder( @@ -551,9 +580,11 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, generator + ) else: - hidden_states = resnet(hidden_states) + hidden_states = resnet(hidden_states, temb, generator) return hidden_states @@ -746,6 +777,9 @@ def __init__( block_out_channels = tuple(reversed(block_out_channels)) spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) output_channel = block_out_channels[0] self.conv_in = LTXCausalConv3d( @@ -810,29 +844,31 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) + 
hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb + ) for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb) else: - hidden_states = self.mid_block(hidden_states) + hidden_states = self.mid_block(hidden_states, temb) for up_block in self.up_blocks: - hidden_states = up_block(hidden_states) + hidden_states = up_block(hidden_states, temb) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) if self.time_embedder is not None: - embedded_timestep = self.time_embedder( + temb = self.time_embedder( timestep=temb.flatten(), resolution=None, aspect_ratio=None, batch_size=hidden_states.size(0), hidden_dtype=hidden_states.dtype, ) - embedded_timestep = embedded_timestep.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) - embedded_timestep = embedded_timestep + self.scale_shift_table[None, :, None, None, None] - shift, scale = embedded_timestep.unbind(dim=1) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.conv_act(hidden_states) @@ -902,7 +938,7 @@ def __init__( decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), - decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False), + decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False), upsample_residual: Tuple[bool, ...] = (False, False, False, False), upsample_factor: Tuple[int, ...] = (1, 1, 1, 1), timestep_conditioning: bool = False, @@ -1078,13 +1114,15 @@ def encode( return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) - def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + def _decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): - return self.tiled_decode(z, return_dict=return_dict) + return self.tiled_decode(z, temb, return_dict=return_dict) if self.use_framewise_decoding: # TODO(aryan): requires investigation @@ -1094,7 +1132,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." 
) else: - dec = self.decoder(z) + dec = self.decoder(z, temb) if not return_dict: return (dec,) @@ -1102,7 +1140,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut return DecoderOutput(sample=dec) @apply_forward_hook - def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + def decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images. @@ -1117,10 +1157,15 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp returned. """ if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: - decoded = self._decode(z).sample + decoded = self._decode(z, temb).sample if not return_dict: return (decoded,) @@ -1202,7 +1247,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] return enc - def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + def tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. @@ -1243,7 +1290,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." 
) else: - time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]) + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb + ) row.append(time) rows.append(row) @@ -1271,6 +1320,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod def forward( self, sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, sample_posterior: bool = False, return_dict: bool = True, generator: Optional[torch.Generator] = None, @@ -1281,7 +1331,7 @@ def forward( z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z) + dec = self.decode(z, temb) if not return_dict: return (dec,) return dec diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 7180601dad41..176a358e1e54 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -511,6 +511,8 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.05, + decode_noise_scale: Union[float, List[float]] = 0.025, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -753,7 +755,25 @@ def __call__( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) latents = latents.to(prompt_embeds.dtype) - video = self.vae.decode(latents, return_dict=False)[0] + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py new file mode 100644 index 000000000000..01a8c0e77806 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import AutoencoderKLLTXVideo +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTXVideo + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "decoder_block_out_channels": (8, 8, 8, 8), + "layers_per_block": (1, 1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, False, False), + "decoder_spatio_temporal_scaling": (True, True, False, False), + "decoder_inject_noise": (False, False, False, False, False), + "upsample_residual": (False, False, False, False), + "upsample_factor": (1, 1, 1, 1), + "timestep_conditioning": False, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": False, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "LTXEncoder3d", + "LTXDecoder3d", + "LTXDownBlock3D", + "LTXMidBlock3d", + "LTXUpBlock3d", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass + + +class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTXVideo + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "decoder_block_out_channels": (16, 32, 64), + "layers_per_block": (1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, True, False), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (True, True, True, False), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": True, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": False, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + timestep = torch.tensor([0.05] * batch_size, device=torch_device) + + return {"sample": image, "temb": timestep} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def 
prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "LTXEncoder3d", + "LTXDecoder3d", + "LTXDownBlock3D", + "LTXMidBlock3d", + "LTXUpBlock3d", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 0f9819bfd6d8..dd166c6242fc 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -63,10 +63,19 @@ def get_dummy_components(self): torch.manual_seed(0) vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, latent_channels=8, block_out_channels=(8, 8, 8, 8), - spatio_temporal_scaling=(True, True, False, False), + decoder_block_out_channels=(8, 8, 8, 8), layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, patch_size=1, patch_size_t=1, encoder_causal=True, From a6d990c9d00b0777ad1ff93bff4ef525d3482edc Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 21 Dec 2024 04:54:44 +0100 Subject: [PATCH 04/12] update --- src/diffusers/models/autoencoders/autoencoder_kl_ltx.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index d6cc60dd24c3..a1f8f5148f64 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -167,8 +167,8 @@ def forward( spatial_shape = hidden_states.shape[-2:] spatial_noise = torch.randn( spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype - ) - hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, :, :] + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) @@ -183,8 +183,8 @@ def forward( spatial_shape = hidden_states.shape[-2:] spatial_noise = torch.randn( spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype - ) - hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, :, :] + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] 
if self.norm3 is not None: inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) From 9d776e704fbadd978f2fdbc921f560214d0fc9fc Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 21 Dec 2024 12:50:43 +0100 Subject: [PATCH 05/12] update --- src/diffusers/pipelines/ltx/pipeline_ltx.py | 4 ++++ .../pipelines/ltx/pipeline_ltx_image2video.py | 24 +++++++++++++++++++ tests/pipelines/ltx/test_ltx_image2video.py | 11 ++++++++- 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 176a358e1e54..20627c0ab906 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -565,6 +565,10 @@ def __call__( provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.05`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `0.025`): + The interpolation factor between random noise and denoised latents at the decode timestep. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index fbb30e304d65..1aebb18f60cc 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -571,6 +571,8 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.05, + decode_noise_scale: Union[float, List[float]] = 0.025, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -625,6 +627,10 @@ def __call__( provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.05`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `0.025`): + The interpolation factor between random noise and denoised latents at the decode timestep. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
@@ -849,6 +855,24 @@ def __call__( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py index 40397e4c3619..1c3e018a8a4b 100644 --- a/tests/pipelines/ltx/test_ltx_image2video.py +++ b/tests/pipelines/ltx/test_ltx_image2video.py @@ -68,10 +68,19 @@ def get_dummy_components(self): torch.manual_seed(0) vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, latent_channels=8, block_out_channels=(8, 8, 8, 8), - spatio_temporal_scaling=(True, True, False, False), + decoder_block_out_channels=(8, 8, 8, 8), layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, patch_size=1, patch_size_t=1, encoder_causal=True, From 734fb71a860ac9889e33de5d73f3eb430870071a Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 21 Dec 2024 12:58:39 +0100 Subject: [PATCH 06/12] make style --- src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 1aebb18f60cc..5c9caa449ed9 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -872,7 +872,7 @@ def __call__( :, None, None, None, None ] latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise - + video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) From 8fc5cfc04eadd89c70d5e8d5d2e3c5c8cbc79771 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 22 Dec 2024 13:13:41 +0100 Subject: [PATCH 07/12] single file related changes --- docs/source/en/api/pipelines/ltx_video.md | 29 ++++++++++++++++++- src/diffusers/loaders/single_file_utils.py | 26 ++++++++++++++++- src/diffusers/pipelines/ltx/pipeline_ltx.py | 8 ++--- .../pipelines/ltx/pipeline_ltx_image2video.py | 8 ++--- 4 files changed, 61 insertions(+), 10 deletions(-) diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index a925b848706e..5b818107774b 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -12,7 +12,7 @@ # See the License for the 
specific language governing permissions and # limitations under the License. --> -# LTX +# LTX Video [LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases. @@ -30,6 +30,7 @@ Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.f import torch from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel +# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors" transformer = LTXVideoTransformer3DModel.from_single_file( single_file_url, torch_dtype=torch.bfloat16 @@ -99,6 +100,32 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24) Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support. + + +Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights. + +```python +import torch +from diffusers import LTXPipeline +from diffusers.utils import export_to_video + +pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=768, + height=512, + num_frames=161, + num_inference_steps=50, +).frames[0] +export_to_video(video, "output.mp4", fps=24) +``` + Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption. 
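For context on the new `decode_timestep` and `decode_noise_scale` pipeline arguments, the following is a simplified standalone sketch of the decode-time conditioning introduced in this patch series: when the VAE is timestep-conditioned (the 0.9.1 checkpoints), a small amount of fresh noise is blended into the latents before they are handed to the decoder. The helper name and tensor shapes below are illustrative, not the actual pipeline code.

```python
import torch


def prepare_latents_for_decode(latents, decode_timestep, decode_noise_scale, generator=None):
    # Blend a small amount of fresh noise into the latents and build the matching
    # per-sample timestep tensor for the timestep-conditioned VAE decoder.
    batch_size = latents.shape[0]
    noise = torch.randn(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)

    if not isinstance(decode_timestep, list):
        decode_timestep = [decode_timestep] * batch_size
    if decode_noise_scale is None:
        decode_noise_scale = decode_timestep
    elif not isinstance(decode_noise_scale, list):
        decode_noise_scale = [decode_noise_scale] * batch_size

    timestep = torch.tensor(decode_timestep, device=latents.device, dtype=latents.dtype)
    scale = torch.tensor(decode_noise_scale, device=latents.device, dtype=latents.dtype)[
        :, None, None, None, None
    ]
    latents = (1 - scale) * latents + scale * noise
    return latents, timestep


# Example: video latents shaped (batch, channels, frames, height, width)
latents = torch.randn(1, 128, 3, 8, 8)
latents, timestep = prepare_latents_for_decode(latents, decode_timestep=0.03, decode_noise_scale=0.025)
# The pipeline then calls `self.vae.decode(latents, timestep, return_dict=False)[0]`.
```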
## LTXPipeline diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index f1408c2c409b..a7a2ef646cc9 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -157,6 +157,7 @@ "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"}, + "ltx-video-0.9.1": {"pretrained_model_name_or_path": "a-r-r-o-w/LTX-Video-0.9.1-diffusers"}, "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"}, "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, @@ -603,7 +604,10 @@ def infer_diffusers_model_type(checkpoint): model_type = "flux-schnell" elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]): - model_type = "ltx-video" + if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint: + model_type = "ltx-video-0.9.1" + else: + model_type = "ltx-video" elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint: encoder_key = "encoder.project_in.conv.conv.bias" @@ -2333,12 +2337,32 @@ def remove_keys_(key: str, state_dict): "per_channel_statistics.std-of-means": "latents_std", } + VAE_091_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # common + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", + } + VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.channel": remove_keys_, "per_channel_statistics.mean-of-means": remove_keys_, "per_channel_statistics.mean-of-stds": remove_keys_, + "timestep_scale_multiplier": remove_keys_, } + if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict: + VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) + for key in list(converted_state_dict.keys()): new_key = key for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 20627c0ab906..96d41bb3224b 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -511,8 +511,8 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, - decode_timestep: Union[float, List[float]] = 0.05, - decode_noise_scale: Union[float, List[float]] = 0.025, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -565,9 +565,9 @@ def __call__( provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. 
- decode_timestep (`float`, defaults to `0.05`): + decode_timestep (`float`, defaults to `0.0`): The timestep at which generated video is decoded. - decode_noise_scale (`float`, defaults to `0.025`): + decode_noise_scale (`float`, defaults to `None`): The interpolation factor between random noise and denoised latents at the decode timestep. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 5c9caa449ed9..602dd4f23c87 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -571,8 +571,8 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, - decode_timestep: Union[float, List[float]] = 0.05, - decode_noise_scale: Union[float, List[float]] = 0.025, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -627,9 +627,9 @@ def __call__( provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. - decode_timestep (`float`, defaults to `0.05`): + decode_timestep (`float`, defaults to `0.0`): The timestep at which generated video is decoded. - decode_noise_scale (`float`, defaults to `0.025`): + decode_noise_scale (`float`, defaults to `None`): The interpolation factor between random noise and denoised latents at the decode timestep. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. 
Choose between From 65cc82d982287b4f65eda904b9a640b796dd924a Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 22 Dec 2024 23:56:46 +0100 Subject: [PATCH 08/12] update --- docs/source/en/api/pipelines/ltx_video.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index 5b818107774b..e5faef60b438 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -121,6 +121,8 @@ video = pipe( width=768, height=512, num_frames=161, + decode_timestep=0.03, + decode_noise_scale=0.025, num_inference_steps=50, ).frames[0] export_to_video(video, "output.mp4", fps=24) From 167df2cf8ab1f397c4baf914dee00199fe6a089f Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 02:35:38 +0100 Subject: [PATCH 09/12] fix --- src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 602dd4f23c87..71fd725c915b 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -873,7 +873,7 @@ def __call__( ] latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise - video = self.vae.decode(latents, return_dict=False)[0] + video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models From 60611892f1b4a6a91f9432feacb482a4eefa26aa Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 06:52:46 +0100 Subject: [PATCH 10/12] update single file urls and docs --- docs/source/en/api/pipelines/ltx_video.md | 11 ++++++++++- src/diffusers/loaders/single_file_utils.py | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index e5faef60b438..017a8ac49e53 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -22,9 +22,18 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m +Available models: + +| Model name | Recommended dtype | +|:-------------:|:-----------------:| +| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` | +| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` | + +Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository. + ## Loading Single Files -Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. +Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format. 
```python import torch diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index a7a2ef646cc9..bf5d15909825 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -156,8 +156,8 @@ "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"}, "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, - "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"}, - "ltx-video-0.9.1": {"pretrained_model_name_or_path": "a-r-r-o-w/LTX-Video-0.9.1-diffusers"}, + "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"}, + "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"}, "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"}, "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, From 178c22dc51dfa164e0f334a2c2beb87898eaca60 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 11:20:37 +0100 Subject: [PATCH 11/12] update --- .../test_models_autoencoder_ltx_video.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py index 01a8c0e77806..37f9837c8245 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py @@ -82,11 +82,11 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = { - "LTXEncoder3d", - "LTXDecoder3d", - "LTXDownBlock3D", - "LTXMidBlock3d", - "LTXUpBlock3d", + "LTXVideoEncoder3d", + "LTXVideoDecoder3d", + "LTXVideoDownBlock3D", + "LTXVideoMidBlock3d", + "LTXVideoUpBlock3d", } super().test_gradient_checkpointing_is_applied(expected_set=expected_set) @@ -152,11 +152,11 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = { - "LTXEncoder3d", - "LTXDecoder3d", - "LTXDownBlock3D", - "LTXMidBlock3d", - "LTXUpBlock3d", + "LTXVideoEncoder3d", + "LTXVideoDecoder3d", + "LTXVideoDownBlock3D", + "LTXVideoMidBlock3d", + "LTXVideoUpBlock3d", } super().test_gradient_checkpointing_is_applied(expected_set=expected_set) From a5e6c13fc6ecb1861ccd6ee3da7cdb8c7a1e0f68 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 14:30:23 +0100 Subject: [PATCH 12/12] fix --- tests/lora/test_lora_layers_ltx_video.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 1ed426f6e8dd..0eccaa73ad42 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -52,10 +52,19 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): } transformer_cls = LTXVideoTransformer3DModel vae_kwargs = { + "in_channels": 3, + "out_channels": 3, "latent_channels": 8, "block_out_channels": (8, 8, 8, 8), - "spatio_temporal_scaling": (True, True, False, False), + "decoder_block_out_channels": (8, 8, 8, 8), "layers_per_block": (1, 1, 1, 1, 1), + "decoder_layers_per_block": (1, 
1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, False, False), + "decoder_spatio_temporal_scaling": (True, True, False, False), + "decoder_inject_noise": (False, False, False, False, False), + "upsample_residual": (False, False, False, False), + "upsample_factor": (1, 1, 1, 1), + "timestep_conditioning": False, "patch_size": 1, "patch_size_t": 1, "encoder_causal": True,
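For reference, a minimal sketch of constructing a small `AutoencoderKLLTXVideo` with the new decoder-side arguments exercised by these tests. It assumes a diffusers build that includes this patch series; the sizes are illustrative test values, not the released configuration.

```python
import torch
from diffusers import AutoencoderKLLTXVideo

# Tiny illustrative configuration (not the released model sizes) exercising the
# new decoder-side arguments added in this patch series.
vae = AutoencoderKLLTXVideo(
    in_channels=3,
    out_channels=3,
    latent_channels=8,
    block_out_channels=(8, 8, 8, 8),
    decoder_block_out_channels=(8, 8, 8, 8),
    layers_per_block=(1, 1, 1, 1, 1),
    decoder_layers_per_block=(1, 1, 1, 1, 1),
    spatio_temporal_scaling=(True, True, False, False),
    decoder_spatio_temporal_scaling=(True, True, False, False),
    decoder_inject_noise=(False, False, False, False, False),
    upsample_residual=(False, False, False, False),
    upsample_factor=(1, 1, 1, 1),
    timestep_conditioning=False,
    patch_size=1,
    patch_size_t=1,
    encoder_causal=True,
)
print(f"parameters: {sum(p.numel() for p in vae.parameters())}")
```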