diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index e5816bd1b158..cf4cc2c71ee3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -824,6 +824,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index e9fec74e73b6..fac24a03df91 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -930,6 +930,13 @@ def prepare_latents( ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 4a35da47a50d..f2c73665e723 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -528,6 +528,13 @@ def prepare_latents( ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 4c40d2fd9e5b..87f84d716c58 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -520,6 +520,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index fb9938aa6a9d..2ce81f6765e1 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -719,6 +719,13 @@ def prepare_latents( ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 8e1e3ab31912..458ca09de608 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -494,6 +494,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 642a02fdf718..8abbc38db187 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -740,6 +740,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index ea22e4928e45..ebabfe26aae4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -710,6 +710,13 @@ def prepare_latents( ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)