Commit 3fdfa1d

SahilCarterr authored and sayakpaul committed
Added Lora Support to SD3 Img2Img Pipeline (#9659)
* add lora
1 parent bc2cbfb commit 3fdfa1d

2 files changed: 126 additions, 5 deletions

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 19 additions & 3 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.

 import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import PIL.Image
 import torch
@@ -25,7 +25,7 @@
 )

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -149,7 +149,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -680,6 +680,10 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
     def guidance_scale(self):
         return self._guidance_scale

+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
     @property
     def clip_skip(self):
         return self._clip_skip
@@ -723,6 +727,7 @@ def __call__(
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -797,6 +802,10 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -835,6 +844,7 @@ def __call__(

         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

         # 2. Define call parameters
@@ -847,6 +857,10 @@ def __call__(

         device = self._execution_device

+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -868,6 +882,7 @@ def __call__(
             clip_skip=self.clip_skip,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
         )

         if self.do_classifier_free_guidance:
@@ -912,6 +927,7 @@ def __call__(
             timestep=timestep,
             encoder_hidden_states=prompt_embeds,
             pooled_projections=pooled_prompt_embeds,
+            joint_attention_kwargs=self.joint_attention_kwargs,
             return_dict=False,
         )[0]
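With `SD3LoraLoaderMixin` and `FromSingleFileMixin` in the class bases, the img2img pipeline gains `load_lora_weights()`, and the new `joint_attention_kwargs` argument carries the LoRA scale. A minimal usage sketch (the checkpoint and LoRA repo ids mirror the integration test below; the step count and scale value are illustrative, not part of this commit):

import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights() comes from SD3LoraLoaderMixin, newly mixed into this pipeline.
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

# The "scale" entry is read as lora_scale for prompt encoding, and the whole
# dict is forwarded to the transformer's attention processors at each step.
image = pipe(
    prompt="corgi",
    image=init_image,
    num_inference_steps=28,
    guidance_scale=5.0,
    joint_attention_kwargs={"scale": 0.9},  # illustrative LoRA scale
).images[0]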

tests/lora/test_lora_layers_sd3.py

Lines changed: 107 additions & 2 deletions
@@ -12,13 +12,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import sys
 import unittest

+import numpy as np
+import torch
 from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

-from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
-from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
+from diffusers import (
+    FlowMatchEulerDiscreteScheduler,
+    SD3Transformer2DModel,
+    StableDiffusion3Img2ImgPipeline,
+    StableDiffusion3Pipeline,
+)
+from diffusers.utils import load_image
+from diffusers.utils.import_utils import is_accelerate_available
+from diffusers.utils.testing_utils import (
+    is_peft_available,
+    numpy_cosine_similarity_distance,
+    require_peft_backend,
+    require_torch_gpu,
+    torch_device,
+)


 if is_peft_available():
@@ -29,6 +45,10 @@
     from utils import PeftLoraLoaderMixinTests  # noqa: E402


+if is_accelerate_available():
+    from accelerate.utils import release_memory
+
+
 @require_peft_backend
 class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
@@ -108,3 +128,88 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
     @unittest.skip("Not supported in SD3.")
     def test_modify_padding_mode(self):
         pass
+
+
+@require_torch_gpu
+@require_peft_backend
+class LoraSD3IntegrationTests(unittest.TestCase):
+    pipeline_class = StableDiffusion3Img2ImgPipeline
+    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, seed=0):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device="cpu").manual_seed(seed)
+
+        return {
+            "prompt": "corgi",
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "output_type": "np",
+            "generator": generator,
+            "image": init_image,
+        }
+
+    def test_sd3_img2img_lora(self):
+        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
+        pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
+        pipe.enable_sequential_cpu_offload()
+
+        inputs = self.get_inputs(torch_device)
+
+        image = pipe(**inputs).images[0]
+        image_slice = image[0, :10, :10]
+        expected_slice = np.array(
+            [
+                0.47827148,
+                0.5,
+                0.71972656,
+                0.3955078,
+                0.4194336,
+                0.69628906,
+                0.37036133,
+                0.40820312,
+                0.6923828,
+                0.36450195,
+                0.40429688,
+                0.6904297,
+                0.35595703,
+                0.39257812,
+                0.68652344,
+                0.35498047,
+                0.3984375,
+                0.68310547,
+                0.34716797,
+                0.3996582,
+                0.6855469,
+                0.3388672,
+                0.3959961,
+                0.6816406,
+                0.34033203,
+                0.40429688,
+                0.6845703,
+                0.34228516,
+                0.4086914,
+                0.6870117,
+            ]
+        )
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+        assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"
+        pipe.unload_lora_weights()
+        release_memory(pipe)
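Note: `LoraSD3IntegrationTests` is gated behind `@require_torch_gpu` and `@require_peft_backend` and downloads the SD3 medium checkpoint, so it runs only on CUDA machines with peft installed; it can be selected with, e.g., `pytest tests/lora/test_lora_layers_sd3.py -k test_sd3_img2img_lora`.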
