diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 4836de7e16ab..922fd15c08fb 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -1091,8 +1091,6 @@ def forward( sample_posterior: bool = False, return_dict: bool = True, generator: Optional[torch.Generator] = None, - encoder_local_batch_size: int = 2, - decoder_local_batch_size: int = 2, ) -> Union[DecoderOutput, torch.Tensor]: r""" Args: @@ -1103,18 +1101,14 @@ def forward( Whether or not to return a [`DecoderOutput`] instead of a plain tuple. generator (`torch.Generator`, *optional*): PyTorch random number generator. - encoder_local_batch_size (`int`, *optional*, defaults to 2): - Local batch size for the encoder's batch inference. - decoder_local_batch_size (`int`, *optional*, defaults to 2): - Local batch size for the decoder's batch inference. """ x = sample - posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist + posterior = self.encode(x).latent_dist if sample_posterior: z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample + dec = self.decode(z).sample if not return_dict: return (dec,)