From 1548b3b150f6d2d29d9f88a01ad5a1964302c08c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Fri, 12 Jul 2024 08:00:03 -0400 Subject: [PATCH 1/3] initial commit --- .../pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py | 3 +++ src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py | 3 +++ src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py | 3 +++ .../pipeline_stable_diffusion_xl_img2img.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index e9fec74e73b6..531f17d84a41 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -930,6 +930,9 @@ def prepare_latents( ) elif isinstance(generator, list): + if image.shape[0] < len(generator): + image = image.expand(len(generator), *image.shape[1:]) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 4a35da47a50d..f214b75be6ea 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -528,6 +528,9 @@ def prepare_latents( ) elif isinstance(generator, list): + if image.shape[0] < len(generator): + image = image.expand(len(generator), *image.shape[1:]) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index fb9938aa6a9d..287f21d2267f 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -719,6 +719,9 @@ def prepare_latents( ) elif isinstance(generator, list): + if image.shape[0] < len(generator): + image = image.expand(len(generator), *image.shape[1:]) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index ea22e4928e45..915bf0ce24af 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -710,6 +710,9 @@ def prepare_latents( ) elif isinstance(generator, list): + if image.shape[0] < len(generator): + image = image.expand(len(generator), *image.shape[1:]) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) From 9fc8dea4ba6b67b2d7457d4643d33c5ddfcb64d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Wed, 17 Jul 2024 17:06:46 -0400 Subject: [PATCH 2/3] apply suggestion to sdxl pipelines --- .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 8 ++++++-- src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py | 8 ++++++-- src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py | 8 ++++++-- .../pipeline_stable_diffusion_xl_img2img.py | 8 ++++++-- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 531f17d84a41..fac24a03df91 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -930,8 +930,12 @@ def prepare_latents( ) elif isinstance(generator, list): - if image.shape[0] < len(generator): - image = image.expand(len(generator), *image.shape[1:]) + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index f214b75be6ea..f2c73665e723 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -528,8 +528,12 @@ def prepare_latents( ) elif isinstance(generator, list): - if image.shape[0] < len(generator): - image = image.expand(len(generator), *image.shape[1:]) + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index 287f21d2267f..2ce81f6765e1 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -719,8 +719,12 @@ def prepare_latents( ) elif isinstance(generator, list): - if image.shape[0] < len(generator): - image = image.expand(len(generator), *image.shape[1:]) + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 915bf0ce24af..ebabfe26aae4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -710,8 +710,12 @@ def prepare_latents( ) elif isinstance(generator, list): - if image.shape[0] < len(generator): - image = image.expand(len(generator), *image.shape[1:]) + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) From 1594674237b1e1400184426f0a3ca80ac83d082e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Wed, 17 Jul 2024 18:18:42 -0400 Subject: [PATCH 3/3] apply fix to sd pipelines --- .../pipelines/controlnet/pipeline_controlnet_img2img.py | 7 +++++++ .../pipeline_latent_consistency_img2img.py | 7 +++++++ .../pipeline_stable_diffusion_depth2img.py | 7 +++++++ .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 7 +++++++ 4 files changed, 28 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index e5816bd1b158..cf4cc2c71ee3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -824,6 +824,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 4c40d2fd9e5b..87f84d716c58 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -520,6 +520,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 8e1e3ab31912..458ca09de608 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -494,6 +494,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 642a02fdf718..8abbc38db187 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -740,6 +740,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt ) elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)