
Commit c517579

feat: add multiple input image support in Flux Kontext
1 parent 425a715 commit c517579

2 files changed (+125, -39 lines)

src/diffusers/image_processor.py

Lines changed: 21 additions & 0 deletions
@@ -35,6 +35,17 @@
     List[torch.Tensor],
 ]
 
+PipelineSeveralImagesInput = Union[
+    Tuple[PIL.Image.Image, ...],
+    Tuple[np.ndarray, ...],
+    Tuple[torch.Tensor, ...],
+    List[Tuple[PIL.Image.Image, ...]],
+    List[Tuple[np.ndarray, ...]],
+    List[Tuple[torch.Tensor, ...]],
+]
+
+
 PipelineDepthInput = PipelineImageInput
 
 
@@ -523,6 +534,16 @@ def resize(
                 size=(height, width),
             )
             image = self.pt_to_numpy(image)
+        elif isinstance(image, tuple):
+            image = tuple(
+                self.resize(
+                    img,
+                    height=height,
+                    width=width,
+                    resize_mode=resize_mode,
+                )
+                for img in image
+            )
         return image
 
     def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
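
For context, a value of the new `PipelineSeveralImagesInput` alias is either a single tuple of images (one entry per reference image) or a list of such tuples (one tuple per prompt in the batch). The sketch below is illustrative only: the file names are placeholders, and the tuple behavior of `resize` shown above only exists with this commit applied.

    import PIL.Image
    from diffusers.image_processor import VaeImageProcessor

    # placeholder files; any two RGB images work
    cat = PIL.Image.open("cat.png").convert("RGB")
    hat = PIL.Image.open("hat.png").convert("RGB")

    single_prompt_refs = (cat, hat)              # Tuple[PIL.Image.Image, ...]
    batched_refs = [(cat, hat), (hat, cat)]      # List[Tuple[PIL.Image.Image, ...]]

    # with this commit applied, a tuple is resized element-wise and returned as a tuple
    processor = VaeImageProcessor()
    resized = processor.resize(single_prompt_refs, height=512, width=512)
    assert isinstance(resized, tuple) and len(resized) == 2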

src/diffusers/pipelines/flux/pipeline_flux_kontext.py

Lines changed: 104 additions & 39 deletions
@@ -26,7 +26,7 @@
     T5TokenizerFast,
 )
 
-from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...image_processor import PipelineImageInput, PipelineSeveralImagesInput, VaeImageProcessor
 from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -641,9 +641,61 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()
 
+    def preprocess_image(self, image: PipelineImageInput, _auto_resize: bool, multiple_of: int) -> torch.Tensor:
+        img = image[0] if isinstance(image, list) else image
+        image_height, image_width = self.image_processor.get_default_height_width(img)
+        aspect_ratio = image_width / image_height
+        if _auto_resize:
+            # Kontext is trained on specific resolutions, using one of them is recommended
+            _, image_width, image_height = min(
+                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+            )
+        image_width = image_width // multiple_of * multiple_of
+        image_height = image_height // multiple_of * multiple_of
+        image = self.image_processor.resize(image, image_height, image_width)
+        image = self.image_processor.preprocess(image, image_height, image_width)
+        return image
+
+    def preprocess_images(
+        self,
+        images: PipelineSeveralImagesInput,
+        _auto_resize: bool,
+        multiple_of: int,
+    ) -> List[torch.Tensor]:
+        # TODO for reviewer: I'm not sure what the best way to implement this part is, given the philosophy of the repo.
+        # The solutions I thought about are:
+        # - make the `resize` and `preprocess` methods of `VaeImageProcessor` more generic (using TypeVar, for instance)
+        # - start by converting the image to a List[Tuple[ {image_format} ]] to unify the processing logic
+        # - or duplicate the code, as done here.
+        # What do you think?
+
+        # convert multiple_images to a list of tuples, to simplify the following logic
+        if not isinstance(images, list):
+            images = [images]
+        # now multiple_images is a list of tuples
+
+        img = images[0][0]
+        image_height, image_width = self.image_processor.get_default_height_width(img)
+        aspect_ratio = image_width / image_height
+        if _auto_resize:
+            # Kontext is trained on specific resolutions, using one of them is recommended
+            _, image_width, image_height = min(
+                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+            )
+        image_width = image_width // multiple_of * multiple_of
+        image_height = image_height // multiple_of * multiple_of
+        n_image_per_batch = len(images[0])
+        output_images = []
+        for i in range(n_image_per_batch):
+            image = [batch_images[i] for batch_images in images]
+            image = self.image_processor.resize(image, image_height, image_width)
+            image = self.image_processor.preprocess(image, image_height, image_width)
+            output_images.append(image)
+        return output_images
+
     def prepare_latents(
         self,
-        image: Optional[torch.Tensor],
+        images: Optional[List[torch.Tensor]],
         batch_size: int,
         num_channels_latents: int,
         height: int,
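
The core of `preprocess_images` is the regrouping step: the input arrives as one tuple of reference images per prompt, and the loop above repacks it so that each output entry batches the i-th reference image across all prompts before it is resized and preprocessed. A minimal stand-alone sketch of just that regrouping, with strings standing in for images:

    # two prompts, each with three reference images
    images = [("p0_ref0", "p0_ref1", "p0_ref2"),
              ("p1_ref0", "p1_ref1", "p1_ref2")]

    n_image_per_batch = len(images[0])
    grouped = [[batch_images[i] for batch_images in images] for i in range(n_image_per_batch)]
    # grouped[0] == ["p0_ref0", "p1_ref0"]; each group is then passed as a batch
    # through image_processor.resize and image_processor.preprocess
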
@@ -665,33 +717,45 @@ def prepare_latents(
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         shape = (batch_size, num_channels_latents, height, width)
 
-        image_latents = image_ids = None
-        if image is not None:
-            image = image.to(device=device, dtype=dtype)
-            if image.shape[1] != self.latent_channels:
-                image_latents = self._encode_vae_image(image=image, generator=generator)
-            else:
-                image_latents = image
-            if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
-                # expand init_latents for batch_size
-                additional_image_per_prompt = batch_size // image_latents.shape[0]
-                image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
-            elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
-                raise ValueError(
-                    f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
-                )
-            else:
-                image_latents = torch.cat([image_latents], dim=0)
-
-            image_latent_height, image_latent_width = image_latents.shape[2:]
-            image_latents = self._pack_latents(
-                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
-            )
-            image_ids = self._prepare_latent_image_ids(
-                batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
-            )
-            # image ids are the same as latent ids with the first dimension set to 1 instead of 0
-            image_ids[..., 0] = 1
+        all_image_latents = []
+        all_image_ids = []
+        image_latents = image_ids = None
+        if images is not None:
+            for i, image in enumerate(images):
+                image = image.to(device=device, dtype=dtype)
+                if image.shape[1] != self.latent_channels:
+                    image_latents = self._encode_vae_image(image=image, generator=generator)
+                else:
+                    image_latents = image
+                if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+                    # expand init_latents for batch_size
+                    additional_image_per_prompt = batch_size // image_latents.shape[0]
+                    image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+                elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+                    )
+                else:
+                    image_latents = torch.cat([image_latents], dim=0)
+
+                image_latent_height, image_latent_width = image_latents.shape[2:]
+                image_latents = self._pack_latents(
+                    image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+                )
+                image_ids = self._prepare_latent_image_ids(
+                    batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+                )
+                # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+                image_ids[..., 0] = 1
+
+                # set the image ids to the correct position in the latent grid
+                image_ids[..., 2] += i * (image_latent_height // 2)
+
+                all_image_ids.append(image_ids)
+                all_image_latents.append(image_latents)
+
+            image_latents = torch.cat(all_image_latents, dim=1)
+            image_ids = torch.cat(all_image_ids, dim=0)
 
         latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
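
To make the merging concrete: each reference image is packed into its own token sequence with its own position ids, the first id channel is set to 1 to mark reference tokens, and the i-th image's ids are shifted so it occupies a distinct region of the coordinate grid; the packed latents are then concatenated along the token dimension and the ids along the first dimension. The sketch below mirrors that logic with illustrative shapes and a simplified stand-in for `_prepare_latent_image_ids`; it is not the pipeline's actual code.

    import torch

    batch, channels, h, w = 1, 16, 64, 64        # illustrative latent shape
    tokens_per_image = (h // 2) * (w // 2)       # 2x2 latent patches are packed into one token

    def make_ids(h2, w2):
        # (h2 * w2, 3) grid of [marker, row, col] ids
        ids = torch.zeros(h2, w2, 3)
        ids[..., 1] += torch.arange(h2)[:, None]
        ids[..., 2] += torch.arange(w2)[None, :]
        return ids.reshape(-1, 3)

    all_latents, all_ids = [], []
    for i in range(2):                           # two reference images
        packed = torch.randn(batch, tokens_per_image, channels * 4)
        ids = make_ids(h // 2, w // 2)
        ids[..., 0] = 1                          # mark as reference-image tokens
        ids[..., 2] += i * (h // 2)              # offset the i-th image, as in the diff above
        all_latents.append(packed)
        all_ids.append(ids)

    image_latents = torch.cat(all_latents, dim=1)   # (batch, 2 * tokens_per_image, channels * 4)
    image_ids = torch.cat(all_ids, dim=0)           # (2 * tokens_per_image, 3)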

@@ -757,6 +821,7 @@ def __call__(
         max_sequence_length: int = 512,
         max_area: int = 1024**2,
         _auto_resize: bool = True,
+        multiple_images: Optional[PipelineSeveralImagesInput] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -858,6 +923,9 @@ def __call__(
             max_area (`int`, defaults to `1024 ** 2`):
                 The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
                 area while maintaining the aspect ratio.
+            multiple_images (`PipelineSeveralImagesInput`, *optional*):
+                A tuple of reference images, or a list of such tuples (one per prompt), used to condition the
+                generation. If provided, the pipeline merges the reference images in the latent space.
 
         Examples:
 
@@ -953,19 +1021,16 @@ def __call__(
         )
 
         # 3. Preprocess image
+        if image is not None and multiple_images is not None:
+            raise ValueError("Cannot pass both `image` and `multiple_images`. Please use only one of them.")
         if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
-            img = image[0] if isinstance(image, list) else image
-            image_height, image_width = self.image_processor.get_default_height_width(img)
-            aspect_ratio = image_width / image_height
-            if _auto_resize:
-                # Kontext is trained on specific resolutions, using one of them is recommended
-                _, image_width, image_height = min(
-                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
-                )
-            image_width = image_width // multiple_of * multiple_of
-            image_height = image_height // multiple_of * multiple_of
-            image = self.image_processor.resize(image, image_height, image_width)
-            image = self.image_processor.preprocess(image, image_height, image_width)
+            image = [self.preprocess_image(image, _auto_resize=_auto_resize, multiple_of=multiple_of)]
+        if multiple_images is not None:
+            image = self.preprocess_images(
+                multiple_images,
+                _auto_resize=_auto_resize,
+                multiple_of=multiple_of,
+            )
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
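
With the change applied, a call that conditions on several reference images could look like the sketch below. The checkpoint id and file names are placeholders and the sampler settings are arbitrary; the only new piece is the `multiple_images` argument, which replaces `image` and takes a tuple of references (or a list of tuples, one per prompt).

    import torch
    from diffusers import FluxKontextPipeline
    from diffusers.utils import load_image

    # any Flux Kontext checkpoint; this id is only an example
    pipe = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    ref_a = load_image("cat.png")    # placeholder reference images
    ref_b = load_image("hat.png")

    result = pipe(
        prompt="the cat wearing the hat",
        multiple_images=(ref_a, ref_b),    # one tuple of references for the single prompt
        num_inference_steps=28,
        guidance_scale=2.5,
    ).images[0]
    result.save("output.png")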
