Commit 9214f4a

remove attention mask for self-attention
1 parent 2065adc commit 9214f4a

File tree: 1 file changed (+3, -9 lines)

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 3 additions & 9 deletions
@@ -843,6 +843,8 @@ def __call__(
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        if prompt_embeds.ndim == 3:
+            prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
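
For orientation (not part of the diff): a minimal sketch, with assumed tensor shapes, of the reshape this hunk hoists out of the denoising loop. After the classifier-free-guidance concatenation the prompt embeddings are (b, l, d); the transformer expects (b, 1, l, d), and the change performs the unsqueeze once up front instead of re-checking it on every denoising step.

    # Sketch only -- the shapes below are assumptions, not the pipeline's actual values.
    import torch

    prompt_embeds = torch.randn(2, 512, 4096)        # assumed (b, l, d) after the CFG torch.cat
    if prompt_embeds.ndim == 3:
        prompt_embeds = prompt_embeds.unsqueeze(1)   # b l d -> b 1 l d
    print(prompt_embeds.shape)                       # torch.Size([2, 1, 512, 4096])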
@@ -884,17 +886,9 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
-                if prompt_embeds.ndim == 3:
-                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
-
-                # prepare attention_mask.
-                # b c t h w -> b t h w
-                attention_mask = torch.ones_like(latent_model_input)[:, 0]
-
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    attention_mask=attention_mask,
+                    hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
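
As an aside (my own illustration, not from the commit): the removed self-attention mask was torch.ones_like(latent_model_input)[:, 0], i.e. every position unmasked. An all-True mask leaves scaled-dot-product attention unchanged, so calling the transformer without a self-attention mask gives the same result. A minimal sketch, with assumed shapes:

    # Sketch only -- checks that an all-True self-attention mask is a no-op.
    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 8, 16, 64)                           # (batch, heads, tokens, head_dim), assumed
    k, v = torch.randn_like(q), torch.randn_like(q)

    all_true = torch.ones(2, 1, 16, 16, dtype=torch.bool)   # every token may attend to every token
    with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=all_true)
    no_mask = F.scaled_dot_product_attention(q, k, v)

    print(torch.allclose(with_mask, no_mask, atol=1e-6))    # True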
