@@ -761,25 +761,28 @@ def __call__(
761761 latent_model_input = torch .cat ([latents , condition ], dim = 1 ).to (transformer_dtype )
762762 timestep = t .expand (latents .shape [0 ])
763763
764- noise_pred = current_model (
765- hidden_states = latent_model_input ,
766- timestep = timestep ,
767- encoder_hidden_states = prompt_embeds ,
768- encoder_hidden_states_image = image_embeds ,
769- attention_kwargs = attention_kwargs ,
770- return_dict = False ,
771- )[0 ]
772-
773- if self .do_classifier_free_guidance :
774- noise_uncond = current_model (
764+ with current_model .cache_context ("cond" ):
765+ noise_pred = current_model (
775766 hidden_states = latent_model_input ,
776767 timestep = timestep ,
777- encoder_hidden_states = negative_prompt_embeds ,
768+ encoder_hidden_states = prompt_embeds ,
778769 encoder_hidden_states_image = image_embeds ,
779770 attention_kwargs = attention_kwargs ,
780771 return_dict = False ,
781772 )[0 ]
782- noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond )
773+
774+
775+ if self .do_classifier_free_guidance :
776+ with current_model .cache_context ("uncond" ):
777+ noise_uncond = current_model (
778+ hidden_states = latent_model_input ,
779+ timestep = timestep ,
780+ encoder_hidden_states = negative_prompt_embeds ,
781+ encoder_hidden_states_image = image_embeds ,
782+ attention_kwargs = attention_kwargs ,
783+ return_dict = False ,
784+ )[0 ]
785+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond )
783786
784787 # compute the previous noisy sample x_t -> x_t-1
785788 latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
0 commit comments