# limitations under the License.

import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

import PIL.Image
import torch
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
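
The two mixins added below give this pipeline single-file checkpoint loading and LoRA support. A minimal sketch of what that enables from the user side, assuming this change is merged; the checkpoint and LoRA file names are hypothetical placeholders:

```python
import torch

from diffusers import StableDiffusion3Img2ImgPipeline

# FromSingleFileMixin: construct the pipeline from one .safetensors checkpoint.
pipe = StableDiffusion3Img2ImgPipeline.from_single_file(
    "sd3_medium_incl_clips.safetensors",  # hypothetical local checkpoint path
    torch_dtype=torch.float16,
)

# SD3LoraLoaderMixin: attach LoRA weights to the loaded pipeline.
pipe.load_lora_weights("my-sd3-lora.safetensors")  # hypothetical LoRA file
```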
@@ -149,7 +149,7 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
    r"""
    Args:
        transformer ([`SD3Transformer2DModel`]):
@@ -680,6 +680,10 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
    def guidance_scale(self):
        return self._guidance_scale

+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
    @property
    def clip_skip(self):
        return self._clip_skip
@@ -723,6 +727,7 @@ def __call__(
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -797,6 +802,10 @@ def __call__(
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
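
A minimal end-to-end sketch of passing the new `joint_attention_kwargs` at call time; the LoRA file and input image paths are hypothetical, and the `"scale"` key doubles as the text encoders' LoRA scale (see the `lora_scale` extraction further down in this diff):

```python
import torch

from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("my-sd3-lora.safetensors")  # hypothetical LoRA file

init_image = load_image("input.png")  # hypothetical input image

image = pipe(
    prompt="a watercolor painting of a lighthouse",
    image=init_image,
    strength=0.6,
    # Forwarded to the transformer's attention processors; "scale" also
    # controls the LoRA strength applied during prompt encoding.
    joint_attention_kwargs={"scale": 0.8},
).images[0]
```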
@@ -835,6 +844,7 @@ def __call__(

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
@@ -847,6 +857,10 @@ def __call__(

        device = self._execution_device

+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+
        (
            prompt_embeds,
            negative_prompt_embeds,
@@ -868,6 +882,7 @@ def __call__(
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
@@ -912,6 +927,7 @@ def __call__(
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

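Taken together, these changes thread a single dictionary from the public `__call__` argument down to both the text encoders (as `lora_scale`) and the transformer's attention processors. A self-contained sketch of just the scale-resolution logic mirrored above; the helper name is illustrative, not part of the pipeline:

```python
from typing import Any, Dict, Optional


def resolve_lora_scale(joint_attention_kwargs: Optional[Dict[str, Any]]) -> Optional[float]:
    # Mirrors the pipeline logic: a missing dict and a missing "scale" key
    # both fall back to None, i.e. the attention processors' default scaling.
    if joint_attention_kwargs is None:
        return None
    return joint_attention_kwargs.get("scale", None)


assert resolve_lora_scale(None) is None
assert resolve_lora_scale({}) is None
assert resolve_lora_scale({"scale": 0.5}) == 0.5
```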