 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
 from ...loaders import WanLoraLoaderMixin
-from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...models import AutoencoderKLWan, WanVACETransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
@@ -137,7 +137,7 @@ def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanTransformer3DModel,
+        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
     ):
@@ -421,6 +421,13 @@ def preprocess_conditions(
                 f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
             )
 
+        ref_images_lengths = [len(reference_images_batch) for reference_images_batch in reference_images]
+        if any(l != ref_images_lengths[0] for l in ref_images_lengths):
+            raise ValueError(
+                f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}. Support for this "
+                "may be added in the future."
+            )
+
         reference_images_preprocessed = []
         for i, reference_images_batch in enumerate(reference_images):
             preprocessed_images = []
@@ -449,7 +456,10 @@ def prepare_video_latents(
         mask: torch.Tensor,
         reference_images: Optional[List[List[torch.Tensor]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        device: Optional[torch.device] = None,
     ) -> torch.Tensor:
+        device = device or self._execution_device
+
         if isinstance(generator, list):
             # TODO: support this
             raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
@@ -473,15 +483,25 @@ def prepare_video_latents(
         vae_dtype = self.vae.dtype
         video = video.to(dtype=vae_dtype)
 
+        latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(
+            1, self.vae.config.z_dim, 1, 1, 1
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(
+            1, self.vae.config.z_dim, 1, 1, 1
+        )
+
         if mask is None:
             latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
+            latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
         else:
             mask = mask.to(dtype=vae_dtype)
             mask = torch.where(mask > 0.5, 1.0, 0.0)
             inactive = video * (1 - mask)
             reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
             reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
+            inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
+            reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
             latents = torch.cat([inactive, reactive], dim=1)
 
         latent_list = []
@@ -491,6 +511,7 @@ def prepare_video_latents(
                 reference_image = reference_image.to(dtype=vae_dtype)
                 reference_image = reference_image[None, :, None, :, :]  # [1, C, 1, H, W]
                 reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
+                reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype)
                 reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=1)
                 latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)  # Concat across frame dimension
             latent_list.append(latent)
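
For context, the lines added in the hunks above all apply the same per-channel shift-and-scale that Wan-style VAEs expect before latents are fed to the transformer. A minimal standalone sketch of that normalization follows; the helper name, tensor shapes, and statistics are illustrative only and not taken from the PR:

import torch

def normalize_wan_latents(latents: torch.Tensor, latents_mean, latents_std) -> torch.Tensor:
    # latents: [B, C, T, H, W]; latents_mean / latents_std: per-channel lists (e.g. from the VAE config).
    mean = torch.tensor(latents_mean, device=latents.device, dtype=torch.float32).view(1, -1, 1, 1, 1)
    inv_std = 1.0 / torch.tensor(latents_std, device=latents.device, dtype=torch.float32).view(1, -1, 1, 1, 1)
    # Compute in float32 for numerical stability, then cast back to the input dtype.
    return ((latents.float() - mean) * inv_std).to(latents.dtype)

# Illustrative call with a 16-channel latent and made-up statistics.
latents = torch.randn(1, 16, 4, 8, 8, dtype=torch.bfloat16)
normalized = normalize_wan_latents(latents, latents_mean=[0.0] * 16, latents_std=[1.0] * 16)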
@@ -790,7 +811,7 @@ def __call__(
             device,
         )
 
-        conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator)
+        conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
         mask = self.prepare_masks(mask, reference_images, generator)
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)
@@ -808,6 +829,11 @@ def __call__(
             latents,
         )
 
+        if conditioning_latents.shape[2] != latents.shape[2]:
+            logger.warning(
+                "The number of frames in the conditioning latents does not match the number of frames to be generated. Generation quality may be affected."
+            )
+
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
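
The changes above are exercised through the pipeline's __call__. A rough usage sketch, assuming the WanVACEPipeline export from diffusers and the Wan-AI/Wan2.1-VACE-1.3B-diffusers checkpoint id (substitute whichever VACE checkpoint you actually use):

import torch
from diffusers import WanVACEPipeline
from diffusers.utils import export_to_video

# Checkpoint id is an assumption for illustration.
pipe = WanVACEPipeline.from_pretrained("Wan-AI/Wan2.1-VACE-1.3B-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Without video/mask/reference_images the pipeline runs as plain text-to-video;
# pass those arguments to exercise prepare_video_latents from the hunks above.
output = pipe(
    prompt="A cat walking through tall grass, cinematic lighting",
    num_frames=81,
    height=480,
    width=832,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=16)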