 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -120,7 +121,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin):
+class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
     r"""
     Pipeline for text-to-image generation using Kolors.
 
@@ -130,6 +131,7 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
     The pipeline also inherits the following loading methods:
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -148,7 +150,11 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
             `Kwai-Kolors/Kolors-diffusers`.
     """
 
-    model_cpu_offload_seq = "text_encoder->unet->vae"
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+    _optional_components = [
+        "image_encoder",
+        "feature_extractor",
+    ]
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
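
With `IPAdapterMixin` added to the class bases, `KolorsPipeline` picks up the standard `load_ip_adapter` entry point. A minimal loading sketch follows; the checkpoint repo id, weight file name, and subfolder layout are assumptions about the published Kolors IP-Adapter, not something this diff confirms:

    import torch
    from diffusers import KolorsPipeline

    pipe = KolorsPipeline.from_pretrained(
        "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")

    # Assumed checkpoint layout: the repo id, weight_name, and
    # image_encoder_folder below are placeholders, not confirmed by this diff.
    pipe.load_ip_adapter(
        "Kwai-Kolors/Kolors-IP-Adapter-Plus",
        subfolder="",
        weight_name="ip_adapter_plus_general.safetensors",
        image_encoder_folder="image_encoder",
    )
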
@@ -166,11 +172,21 @@ def __init__(
         tokenizer: ChatGLMTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
         force_zeros_for_empty_prompt: bool = False,
     ):
         super().__init__()
 
-        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+        )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -343,6 +359,77 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        image_embeds = []
+        if do_classifier_free_guidance:
+            negative_image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+
+                image_embeds.append(single_image_embeds[None, :])
+                if do_classifier_free_guidance:
+                    negative_image_embeds.append(single_negative_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    negative_image_embeds.append(single_negative_image_embeds)
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
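
Since `prepare_ip_adapter_image_embeds` is exposed on the pipeline, the image-encoder forward pass can be run once and its output reused across calls. A sketch, assuming `pipe` was set up as in the loading example above and using a placeholder image path:

    from diffusers.utils import load_image

    ref_image = load_image("path/to/reference.png")  # placeholder path

    # With classifier-free guidance enabled, the zero-image (negative) embeds
    # are concatenated in front of the positive ones, doubling the batch dim.
    cached_embeds = pipe.prepare_ip_adapter_image_embeds(
        ip_adapter_image=ref_image,
        ip_adapter_image_embeds=None,
        device="cuda",
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )
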
@@ -364,16 +451,25 @@ def prepare_extra_step_kwargs(self, generator, eta):
     def check_inputs(
         self,
         prompt,
+        num_inference_steps,
         height,
         width,
         negative_prompt=None,
         prompt_embeds=None,
         pooled_prompt_embeds=None,
         negative_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
+        if not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+            raise ValueError(
+                f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+                f" {type(num_inference_steps)}."
+            )
+
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
@@ -420,6 +516,21 @@ def check_inputs(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )
 
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
         if max_sequence_length is not None and max_sequence_length > 256:
             raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}")
 
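
The new checks make `ip_adapter_image` and `ip_adapter_image_embeds` mutually exclusive and require precomputed embeds to arrive as a list of 3D or 4D tensors. A sketch of the validation contract, with an illustrative embedding size (the real `emb_dim` depends on the loaded adapter):

    import torch

    embeds = [torch.randn(2, 1, 1280)]  # one 3D tensor per loaded IP-Adapter; CFG doubles dim 0
    pipe(prompt="a photo", ip_adapter_image_embeds=embeds)   # passes check_inputs
    pipe(prompt="a photo", ip_adapter_image=ref_image,
         ip_adapter_image_embeds=embeds)                     # raises ValueError (both given)
    pipe(prompt="a photo", num_inference_steps=0)            # raises ValueError (new steps check)
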
@@ -563,6 +674,8 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -649,6 +762,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a list with the same length as the number
+                of loaded IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`
+                and should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`.
+                If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
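
Putting the new arguments together, a complete generation call might look like the following sketch (prompt and sampler settings are arbitrary, and `ref_image`/`cached_embeds` come from the earlier examples):

    image = pipe(
        prompt="a cat wearing sunglasses, studio lighting",
        ip_adapter_image=ref_image,  # or: ip_adapter_image_embeds=cached_embeds
        num_inference_steps=25,
        guidance_scale=5.0,
    ).images[0]
    image.save("kolors_ip_adapter.png")
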
@@ -719,13 +838,16 @@ def __call__(
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
+            num_inference_steps,
             height,
             width,
             negative_prompt,
             prompt_embeds,
             pooled_prompt_embeds,
             negative_prompt_embeds,
             negative_pooled_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -815,6 +937,15 @@ def __call__(
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
@@ -856,6 +987,9 @@ def __call__(
                 # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
                 noise_pred = self.unet(
                     latent_model_input,
                     t,