Skip to content

Commit 8eecf92

Browse files
committed
style
1 parent c928bbd commit 8eecf92

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

examples/community/masked_stable_diffusion_xl_img2img.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
1-
import inspect
21
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
32

4-
from PIL import Image, ImageFilter
5-
import torch
6-
73
import numpy as np
4+
import torch
5+
from PIL import Image, ImageFilter
86

7+
from diffusers.image_processor import PipelineImageInput
8+
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
9+
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (
10+
StableDiffusionXLImg2ImgPipeline,
11+
rescale_noise_cfg,
12+
retrieve_latents,
13+
retrieve_timesteps,
14+
)
915
from diffusers.utils import (
1016
deprecate,
1117
is_torch_xla_available,
1218
logging,
1319
)
14-
from diffusers.image_processor import PipelineImageInput
15-
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline, retrieve_latents, retrieve_timesteps, rescale_noise_cfg
16-
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
17-
1820
from diffusers.utils.torch_utils import randn_tensor
1921

22+
2023
if is_torch_xla_available():
2124
import torch_xla.core.xla_model as xm
2225

@@ -29,7 +32,7 @@
2932

3033

3134
class MaskedStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
32-
debug_save = 0
35+
debug_save = 0
3336

3437
@torch.no_grad()
3538
def __call__(
@@ -79,7 +82,7 @@ def __call__(
7982
List[Image.Image],
8083
List[np.ndarray],
8184
] = None,
82-
blur=24,
85+
blur=24,
8386
blur_compose=4,
8487
sample_mode='sample',
8588
**kwargs
@@ -283,7 +286,7 @@ def denoising_value_valid(dnv):
283286
)
284287

285288
# mean of the latent distribution
286-
# it is multiplied by self.vae.config.scaling_factor
289+
# it is multiplied by self.vae.config.scaling_factor
287290
non_paint_latents = self.prepare_latents(
288291
original_image,
289292
latent_timestep,
@@ -292,7 +295,7 @@ def denoising_value_valid(dnv):
292295
prompt_embeds.dtype,
293296
device,
294297
generator,
295-
add_noise=False,
298+
add_noise=False,
296299
sample_mode="argmax")
297300

298301
if self.debug_save:
@@ -398,7 +401,7 @@ def denoising_value_valid(dnv):
398401

399402
shape = non_paint_latents.shape
400403
noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
401-
# noisy latent code of input image at current step
404+
# noisy latent code of input image at current step
402405
orig_latents_t = non_paint_latents
403406
orig_latents_t = self.scheduler.add_noise(non_paint_latents, noise, t.unsqueeze(0))
404407

@@ -491,7 +494,7 @@ def denoising_value_valid(dnv):
491494
if self.debug_save:
492495
image_gen = self.latents_to_img(latents)
493496
image_gen[0].save("from_latent.png")
494-
497+
495498
if latent_mask is not None:
496499
# interpolate with latent mask
497500
latents = torch.lerp(non_paint_latents, latents, latent_mask)

0 commit comments

Comments
 (0)