@@ -843,6 +843,8 @@ def __call__(
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        if prompt_embeds.ndim == 3:
+            prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d

         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
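
The hunk above hoists the `ndim == 3` unsqueeze out of the per-step denoising loop (the hunk below deletes it from the loop). A minimal standalone sketch, with made-up shapes, of why the hoist is behavior-preserving: after one `unsqueeze(1)` the tensor is 4-D, so the in-loop check was already a no-op on every iteration after the first.

import torch

prompt_embeds = torch.randn(2, 77, 1024)  # dummy (b, l, d) text embeddings

# Done once, before the loop, as in the hunk above.
if prompt_embeds.ndim == 3:
    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
assert prompt_embeds.shape == (2, 1, 77, 1024)

# Re-running the check changes nothing, which is why the per-step version
# only ever fired on the first iteration and can safely be moved out.
if prompt_embeds.ndim == 3:
    prompt_embeds = prompt_embeds.unsqueeze(1)
assert prompt_embeds.shape == (2, 1, 77, 1024)
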
@@ -884,17 +886,9 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

-                if prompt_embeds.ndim == 3:
-                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
-
-                # prepare attention_mask.
-                # b c t h w -> b t h w
-                attention_mask = torch.ones_like(latent_model_input)[:, 0]
-
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    attention_mask=attention_mask,
+                    hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
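
The second hunk drops the all-ones `attention_mask` built from the latents (a mask of ones attends to every position, so it presumably carried no information and dropping it is a no-op when the model treats a missing mask as full attention) and passes the latents by keyword as `hidden_states`. A toy sketch of the simplified call site follows; `DummyTransformer` is a made-up stand-in for `self.transformer`, and all shapes are illustrative, not taken from the pipeline.

import torch

class DummyTransformer(torch.nn.Module):
    # Stand-in model that only mirrors the keyword-based calling convention.
    def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask, timestep):
        # A real transformer would predict noise; here we just echo the latents.
        return hidden_states

transformer = DummyTransformer()
latent_model_input = torch.randn(2, 4, 8, 32, 32)  # dummy (b, c, t, h, w) latents
prompt_embeds = torch.randn(2, 1, 77, 1024)        # (b, 1, l, d), after the unsqueeze above
prompt_attention_mask = torch.ones(2, 77)          # dummy text-token mask
timestep = torch.tensor([10, 10])                  # one timestep per batch element

noise_pred = transformer(
    hidden_states=latent_model_input,
    encoder_hidden_states=prompt_embeds,
    encoder_attention_mask=prompt_attention_mask,
    timestep=timestep,
)
assert noise_pred.shape == latent_model_input.shape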