@@ -112,10 +112,20 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+            two-stage denoising, where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+            stages. If not provided, only `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """
 
-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2"]
 
     def __init__(
         self,
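Reviewer note: a minimal, self-contained sketch of the switching rule described in the `boundary_ratio` docstring above. The timestep count and ratio are toy values chosen for illustration, not values fixed by this PR; only the comparison mirrors the pipeline logic.

```python
# Toy illustration of the two-stage boundary rule (values are assumptions).
num_train_timesteps = 1000
boundary_ratio = 0.875  # hypothetical ratio for a two-expert checkpoint
boundary_timestep = boundary_ratio * num_train_timesteps  # 875.0

for t in (999, 900, 875, 874, 500, 10):
    # t >= boundary_timestep -> high-noise expert (`transformer`),
    # t <  boundary_timestep -> low-noise expert (`transformer_2`).
    expert = "transformer" if t >= boundary_timestep else "transformer_2"
    print(t, expert)
```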
@@ -124,6 +134,9 @@ def __init__(
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer_2: Optional[WanTransformer3DModel] = None,
+        boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,  # Wan2.2 ti2v
     ):
         super().__init__()
 
@@ -133,10 +146,12 @@ def __init__(
133146 tokenizer = tokenizer ,
134147 transformer = transformer ,
135148 scheduler = scheduler ,
149+ transformer_2 = transformer_2 ,
136150 )
137-
138- self .vae_scale_factor_temporal = 2 ** sum (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 4
139- self .vae_scale_factor_spatial = 2 ** len (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 8
151+ self .register_to_config (boundary_ratio = boundary_ratio )
152+ self .register_to_config (expand_timesteps = expand_timesteps )
153+ self .vae_scale_factor_temporal = self .vae .config .scale_factor_temporal if getattr (self , "vae" , None ) else 4
154+ self .vae_scale_factor_spatial = self .vae .config .scale_factor_spatial if getattr (self , "vae" , None ) else 8
140155 self .video_processor = VideoProcessor (vae_scale_factor = self .vae_scale_factor_spatial )
141156
142157 def _get_t5_prompt_embeds (
@@ -270,6 +285,7 @@ def check_inputs(
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
     ):
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -302,6 +318,9 @@ def check_inputs(
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
 
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
     def prepare_latents(
         self,
         batch_size: int,
@@ -369,6 +388,7 @@ def __call__(
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -407,6 +427,10 @@ def __call__(
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                 the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not `None`, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not `None`.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
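Reviewer note: a hedged usage sketch for the new `guidance_scale_2` argument. The repo id below is a placeholder for a two-expert (Wan2.2-style) checkpoint that ships `transformer_2` and a non-`None` `boundary_ratio` in its pipeline config; everything else is the public `WanPipeline` API.

```python
import torch
from diffusers import WanPipeline

# Placeholder repo id (assumption): substitute a real two-expert checkpoint
# whose pipeline config registers `transformer_2` and `boundary_ratio`.
pipe = WanPipeline.from_pretrained("<org>/<wan2.2-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A cat surfing a wave at sunset",
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=4.0,    # used while t >= boundary_timestep (`transformer`)
    guidance_scale_2=3.0,  # used while t < boundary_timestep (`transformer_2`)
).frames[0]
```

If `guidance_scale_2` is omitted on such a checkpoint, it falls back to `guidance_scale` (see the defaulting added in `__call__` below).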
@@ -461,6 +485,7 @@ def __call__(
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
         )
 
         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +495,11 @@ def __call__(
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)
 
+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -520,21 +549,44 @@ def __call__(
             latents,
         )
 
+        mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
 
-                with self.transformer.cache_context("cond"):
-                    noise_pred = self.transformer(
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
+                latent_model_input = latents.to(transformer_dtype)
+                if self.config.expand_timesteps:
+                    # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+                    temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    timestep = t.expand(latents.shape[0])
+
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
                         encoder_hidden_states=prompt_embeds,
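Reviewer note: a shape-only sketch of what the `expand_timesteps` branch above computes, with toy latent dimensions and an assumed `[batch, channels, frames, height, width]` layout. It only demonstrates that the `::2, ::2` subsampling (matching a 2x2 spatial patchify) makes the flattened per-token timestep length equal the transformer's sequence length.

```python
import torch

# Toy latent shape (assumption): [batch, channels, frames, height, width].
latents = torch.zeros(2, 16, 21, 30, 52)
mask = torch.ones(latents.shape, dtype=torch.float32)
t = torch.tensor(875.0)

# mask[0][0] is [frames, height, width]; subsampling H and W by 2 mirrors the
# 2x2 spatial patchify, so flatten() yields
# num_latent_frames * (latent_height // 2) * (latent_width // 2) entries.
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)

print(temp_ts.shape)   # torch.Size([8190])  -> 21 * 15 * 26
print(timestep.shape)  # torch.Size([2, 8190]) -> one timestep per token, per batch
```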
@@ -543,15 +595,15 @@ def __call__(
                     )[0]
 
                 if self.do_classifier_free_guidance:
-                    with self.transformer.cache_context("uncond"):
-                        noise_uncond = self.transformer(
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
                             hidden_states=latent_model_input,
                             timestep=timestep,
                             encoder_hidden_states=negative_prompt_embeds,
                             attention_kwargs=attention_kwargs,
                             return_dict=False,
                         )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]