Skip to content

Commit 15f95d5

Browse files
committed
enable caching for WanImageToVideoPipeline
1 parent e411d50 commit 15f95d5

File tree

1 file changed: +16 additions, −13 deletions

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -761,25 +761,28 @@ def __call__(
                 latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = current_model(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_hidden_states_image=image_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = current_model(
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         encoder_hidden_states_image=image_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_image=image_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

0 commit comments

Comments
 (0)