     T5TokenizerFast,
 )
 
-from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...image_processor import PipelineImageInput, PipelineSeveralImagesInput, VaeImageProcessor
 from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
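Note: `PipelineSeveralImagesInput` is introduced by this diff and its definition is not shown here. Inferred purely from how `preprocess_images` below consumes it, a plausible shape is:

```python
from typing import List, Tuple, Union

import PIL.Image

# Hypothetical alias, inferred from this diff only -- the real definition lives
# in ...image_processor and may differ:
PipelineSeveralImagesInput = Union[
    Tuple[PIL.Image.Image, ...],        # one batch entry with several reference images
    List[Tuple[PIL.Image.Image, ...]],  # one tuple of reference images per batch entry
]
```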
@@ -641,9 +641,61 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()
 
+    def preprocess_image(self, image: PipelineImageInput, _auto_resize: bool, multiple_of: int) -> torch.Tensor:
+        img = image[0] if isinstance(image, list) else image
+        image_height, image_width = self.image_processor.get_default_height_width(img)
+        aspect_ratio = image_width / image_height
+        if _auto_resize:
+            # Kontext is trained on specific resolutions, using one of them is recommended
+            _, image_width, image_height = min(
+                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+            )
+        image_width = image_width // multiple_of * multiple_of
+        image_height = image_height // multiple_of * multiple_of
+        image = self.image_processor.resize(image, image_height, image_width)
+        image = self.image_processor.preprocess(image, image_height, image_width)
+        return image
+
+    def preprocess_images(
+        self,
+        images: PipelineSeveralImagesInput,
+        _auto_resize: bool,
+        multiple_of: int,
+    ) -> list[torch.Tensor]:
+        # TODO for reviewer: I'm not sure what the best way to implement this part is, given the
+        # philosophy of the repo. The solutions I thought about are:
+        # - Make the `resize` and `preprocess` methods of `VaeImageProcessor` more generic (using TypeVar, for instance)
+        # - Start by converting the input to a List[Tuple[{image_format}]] to unify the processing logic
+        # - Or duplicate the code, as done here.
+        # What do you think?
+
+        # Convert `images` to a list of tuples to simplify the following logic.
+        if not isinstance(images, list):
+            images = [images]
+        # `images` is now a list of tuples, one tuple of reference images per batch entry.
+
+        img = images[0][0]
+        image_height, image_width = self.image_processor.get_default_height_width(img)
+        aspect_ratio = image_width / image_height
+        if _auto_resize:
+            # Kontext is trained on specific resolutions, using one of them is recommended
+            _, image_width, image_height = min(
+                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+            )
+        image_width = image_width // multiple_of * multiple_of
+        image_height = image_height // multiple_of * multiple_of
+        n_image_per_batch = len(images[0])
+        output_images = []
+        for i in range(n_image_per_batch):
+            # Group the i-th reference image across the batch so each group is resized and preprocessed together.
+            image = [batch_images[i] for batch_images in images]
+            image = self.image_processor.resize(image, image_height, image_width)
+            image = self.image_processor.preprocess(image, image_height, image_width)
+            output_images.append(image)
+        return output_images
+
     def prepare_latents(
         self,
-        image: Optional[torch.Tensor],
+        images: Optional[list[torch.Tensor]],
         batch_size: int,
         num_channels_latents: int,
         height: int,
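Before the `prepare_latents` changes below, a pure-Python sketch of the list-of-tuples normalization that `preprocess_images` above performs (strings stand in for PIL images):

```python
# Sketch of the input normalization in `preprocess_images`: a bare tuple is one
# batch entry; a list already holds one tuple of reference images per entry.
images = ("ref_a", "ref_b")  # stand-ins for two PIL reference images

if not isinstance(images, list):
    images = [images]  # -> [("ref_a", "ref_b")]

n_image_per_batch = len(images[0])  # reference images per batch entry
for i in range(n_image_per_batch):
    # Gather the i-th reference image across all batch entries, so each group
    # can be resized and VAE-preprocessed together.
    group = [batch_images[i] for batch_images in images]
    print(i, group)  # 0 ['ref_a'] then 1 ['ref_b']
```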
@@ -665,33 +717,45 @@ def prepare_latents(
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         shape = (batch_size, num_channels_latents, height, width)
 
-        image_latents = image_ids = None
-        if image is not None:
-            image = image.to(device=device, dtype=dtype)
-            if image.shape[1] != self.latent_channels:
-                image_latents = self._encode_vae_image(image=image, generator=generator)
-            else:
-                image_latents = image
-            if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
-                # expand init_latents for batch_size
-                additional_image_per_prompt = batch_size // image_latents.shape[0]
-                image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
-            elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
-                raise ValueError(
-                    f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+        all_image_latents = []
+        all_image_ids = []
+        image_latents = image_ids = None
+        if images is not None:
+            for i, image in enumerate(images):
+                image = image.to(device=device, dtype=dtype)
+                if image.shape[1] != self.latent_channels:
+                    image_latents = self._encode_vae_image(image=image, generator=generator)
+                else:
+                    image_latents = image
+                if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+                    # expand init_latents for batch_size
+                    additional_image_per_prompt = batch_size // image_latents.shape[0]
+                    image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+                elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+                    raise ValueError(
+                        f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+                    )
+                else:
+                    image_latents = torch.cat([image_latents], dim=0)
+
+                image_latent_height, image_latent_width = image_latents.shape[2:]
+                image_latents = self._pack_latents(
+                    image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+                )
+                image_ids = self._prepare_latent_image_ids(
+                    batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
                 )
-            else:
-                image_latents = torch.cat([image_latents], dim=0)
+                # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+                image_ids[..., 0] = 1
 
-            image_latent_height, image_latent_width = image_latents.shape[2:]
-            image_latents = self._pack_latents(
-                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
-            )
-            image_ids = self._prepare_latent_image_ids(
-                batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
-            )
-            # image ids are the same as latent ids with the first dimension set to 1 instead of 0
-            image_ids[..., 0] = 1
+                # set the image ids to the correct position in the latent grid
+                image_ids[..., 2] += i * (image_latent_height // 2)
+
+                all_image_ids.append(image_ids)
+                all_image_latents.append(image_latents)
+
+            image_latents = torch.cat(all_image_latents, dim=1)
+            image_ids = torch.cat(all_image_ids, dim=0)
 
         latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
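To illustrate why the `image_ids[..., 2] += i * (image_latent_height // 2)` shift above keeps tokens from different reference images at distinct positions, here is a minimal standalone sketch; `make_image_ids` is a hypothetical stand-in for the pipeline's private `_prepare_latent_image_ids`, not the actual helper:

```python
import torch

def make_image_ids(height: int, width: int) -> torch.Tensor:
    # Hypothetical stand-in for `_prepare_latent_image_ids`:
    # one (type, y, x) triple per latent patch position.
    ids = torch.zeros(height, width, 3)
    ids[..., 1] = torch.arange(height)[:, None]
    ids[..., 2] = torch.arange(width)[None, :]
    return ids.reshape(-1, 3)

all_ids = []
for i in range(2):  # two reference images, each a 4x4 grid of latent patches
    ids = make_image_ids(4, 4)
    ids[..., 0] = 1       # mark as image tokens (the target latents keep 0)
    ids[..., 2] += i * 4  # the shift from the diff, with height // 2 == 4 here
    all_ids.append(ids)

image_ids = torch.cat(all_ids, dim=0)
# The positions no longer collide: x spans 0..3 for image 0 and 4..7 for image 1.
assert image_ids[:16, 2].max() == 3 and image_ids[16:, 2].min() == 4
```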
@@ -757,6 +821,7 @@ def __call__(
         max_sequence_length: int = 512,
         max_area: int = 1024**2,
         _auto_resize: bool = True,
+        multiple_images: Optional[PipelineSeveralImagesInput] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -858,6 +923,9 @@ def __call__(
             max_area (`int`, defaults to `1024 ** 2`):
                 The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
                 area while maintaining the aspect ratio.
+            multiple_images (`PipelineSeveralImagesInput`, *optional*):
+                A list of images to be used as reference images for the generation. If provided, the pipeline will
+                merge the reference images in the latent space.
 
         Examples:
 
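Following the docstring's `Examples:` hook, a hedged usage sketch of the new argument; the checkpoint id is the public Kontext dev repo, the file names are placeholders, and the exact nesting of `multiple_images` is inferred from `preprocess_images` above:

```python
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder files; any PIL-compatible inputs should work.
cat = load_image("cat.png")
hat = load_image("hat.png")

# One tuple of reference images per prompt in the batch (inferred input format).
result = pipe(
    prompt="The cat wearing the hat",
    multiple_images=[(cat, hat)],
    guidance_scale=2.5,
).images[0]
result.save("cat_with_hat.png")
```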
@@ -953,19 +1021,16 @@ def __call__(
         )
 
         # 3. Preprocess image
+        if image is not None and multiple_images is not None:
+            raise ValueError("Cannot pass both `image` and `multiple_images`. Please use only one of them.")
         if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
-            img = image[0] if isinstance(image, list) else image
-            image_height, image_width = self.image_processor.get_default_height_width(img)
-            aspect_ratio = image_width / image_height
-            if _auto_resize:
-                # Kontext is trained on specific resolutions, using one of them is recommended
-                _, image_width, image_height = min(
-                    (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
-                )
-            image_width = image_width // multiple_of * multiple_of
-            image_height = image_height // multiple_of * multiple_of
-            image = self.image_processor.resize(image, image_height, image_width)
-            image = self.image_processor.preprocess(image, image_height, image_width)
+            image = [self.preprocess_image(image, _auto_resize=_auto_resize, multiple_of=multiple_of)]
+        if multiple_images is not None:
+            image = self.preprocess_images(
+                multiple_images,
+                _auto_resize=_auto_resize,
+                multiple_of=multiple_of,
+            )
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
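After this step, both paths hand `prepare_latents` the same kind of input: a list of image tensors, one per reference image. A minimal shape sketch (dimensions are illustrative, not taken from a real checkpoint):

```python
import torch

batch_size, channels, height, width = 1, 3, 1024, 1024

# Single-image path: the `preprocess_image` output wrapped in a one-element list.
single = [torch.randn(batch_size, channels, height, width)]

# Multi-image path: `preprocess_images` returns one tensor per reference image.
multiple = [torch.randn(batch_size, channels, height, width) for _ in range(2)]

# Either way, `prepare_latents` iterates `for i, image in enumerate(images)`.
for images in (single, multiple):
    print(len(images), images[0].shape)
```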