@@ -318,7 +318,10 @@ def set_use_memory_efficient_attention_xformers(
                 XFormersAttnAddedKVProcessor,
             ),
         )
-
+        is_ip_adapter = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
+        )
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(
@@ -368,6 +371,19 @@ def set_use_memory_efficient_attention_xformers(
                     "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                 )
                 processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+            elif is_ip_adapter:
+                processor = IPAdapterXFormersAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -386,6 +402,18 @@ def set_use_memory_efficient_attention_xformers(
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_ip_adapter:
+                processor = IPAdapterAttnProcessor2_0(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
             else:
                 # set attention processor
                 # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
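Both `elif is_ip_adapter` branches above use the same swap pattern: rebuild the processor with the original configuration, copy the image-projection weights with `load_state_dict`, and move the new module to the device/dtype of the existing `to_k_ip` layers. A minimal sketch of that pattern in isolation (the helper name here is made up for illustration; it assumes both processor classes are importable from `diffusers.models.attention_processor`):

```python
# Sketch only: swap an IP-Adapter processor for its xFormers variant while
# keeping the trained to_k_ip / to_v_ip weights.
from diffusers.models.attention_processor import (
    IPAdapterAttnProcessor2_0,
    IPAdapterXFormersAttnProcessor,
)


def to_xformers_ip_adapter(old: IPAdapterAttnProcessor2_0) -> IPAdapterXFormersAttnProcessor:
    new = IPAdapterXFormersAttnProcessor(
        hidden_size=old.hidden_size,
        cross_attention_dim=old.cross_attention_dim,
        num_tokens=old.num_tokens,
        scale=old.scale,
    )
    # Both classes register to_k_ip / to_v_ip as nn.ModuleList, so their
    # state dicts line up key-for-key.
    new.load_state_dict(old.state_dict())
    # Match the device and dtype of the weights that were just copied.
    new.to(device=old.to_k_ip[0].weight.device, dtype=old.to_k_ip[0].weight.dtype)
    return new
```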
@@ -4542,6 +4570,238 @@ def __call__(
         return hidden_states
 
 
+class IPAdapterXFormersAttnProcessor(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapter using xFormers.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+            The context length of the image features.
+        scale (`float` or `List[float]`, defaults to 1.0):
+            the weight scale of image prompt.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        cross_attention_dim=None,
+        num_tokens=(4,),
+        scale=1.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+        self.num_tokens = num_tokens
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+        self.scale = scale
+
+        self.to_k_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.FloatTensor] = None,
+    ):
+        residual = hidden_states
+
+        # separate ip_hidden_states from encoder_hidden_states
+        if encoder_hidden_states is not None:
+            if isinstance(encoder_hidden_states, tuple):
+                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+            else:
+                deprecation_message = (
+                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+                )
+                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+                encoder_hidden_states, ip_hidden_states = (
+                    encoder_hidden_states[:, :end_pos, :],
+                    [encoder_hidden_states[:, end_pos:, :]],
+                )
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # expand our mask's singleton query_tokens dimension:
+            # [batch*heads, 1, key_tokens] ->
+            # [batch*heads, query_tokens, key_tokens]
+            # so that it can be added as a bias onto the attention scores that xformers computes:
+            # [batch*heads, query_tokens, key_tokens]
+            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+            _, query_tokens, _ = hidden_states.shape
+            attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if ip_hidden_states:
+            if ip_adapter_masks is not None:
+                if not isinstance(ip_adapter_masks, List):
+                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                    raise ValueError(
+                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                        f"({len(ip_hidden_states)})"
+                    )
+                else:
+                    for index, (mask, scale, ip_state) in enumerate(
+                        zip(ip_adapter_masks, self.scale, ip_hidden_states)
+                    ):
+                        if mask is None:
+                            continue
+                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                            raise ValueError(
+                                "Each element of the ip_adapter_masks array should be a tensor with shape "
+                                "[1, num_images_for_ip_adapter, height, width]."
+                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                            )
+                        if mask.shape[1] != ip_state.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                            )
+                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of scales ({len(scale)}) at index {index}"
+                            )
+            else:
+                ip_adapter_masks = [None] * len(self.scale)
+
+            # for ip-adapter
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                skip = False
+                if isinstance(scale, list):
+                    if all(s == 0 for s in scale):
+                        skip = True
+                elif scale == 0:
+                    skip = True
+                if not skip:
+                    if mask is not None:
+                        mask = mask.to(torch.float16)
+                        if not isinstance(scale, list):
+                            scale = [scale] * mask.shape[1]
+
+                        current_num_images = mask.shape[1]
+                        for i in range(current_num_images):
+                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                            ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+                            ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+                            _current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+                                query, ip_key, ip_value, op=self.attention_op
+                            )
+                            _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+                            _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+                            mask_downsample = IPAdapterMaskProcessor.downsample(
+                                mask[:, i, :, :],
+                                batch_size,
+                                _current_ip_hidden_states.shape[1],
+                                _current_ip_hidden_states.shape[2],
+                            )
+
+                            mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                    else:
+                        ip_key = to_k_ip(current_ip_hidden_states)
+                        ip_value = to_v_ip(current_ip_hidden_states)
+
+                        ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+                        ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+                        current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+                            query, ip_key, ip_value, op=self.attention_op
+                        )
+                        current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+                        current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                        hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class PAGIdentitySelfAttnProcessor2_0:
     r"""
     Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
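For an end-to-end check, a short usage sketch (the repo IDs are the usual documented examples, not part of this diff, and a CUDA device with `xformers` installed is assumed): with an IP-Adapter loaded, `enable_xformers_memory_efficient_attention()` should now swap in `IPAdapterXFormersAttnProcessor` instead of falling through to a plain `XFormersAttnProcessor` and dropping the `to_k_ip`/`to_v_ip` projections:

```python
# Sketch only: verify that enabling xFormers keeps the IP-Adapter projections.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

pipe.enable_xformers_memory_efficient_attention()

# Cross-attention layers that carried IP-Adapter weights should now use the
# xFormers variant of the processor.
assert any(
    type(p).__name__ == "IPAdapterXFormersAttnProcessor" for p in pipe.unet.attn_processors.values()
)
```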