1- import inspect
21from typing import Any , Callable , Dict , List , Optional , Tuple , Union
32
4- from PIL import Image , ImageFilter
5- import torch
6-
73import numpy as np
4+ import torch
5+ from PIL import Image , ImageFilter
86
7+ from diffusers .image_processor import PipelineImageInput
8+ from diffusers .pipelines .stable_diffusion_xl .pipeline_output import StableDiffusionXLPipelineOutput
9+ from diffusers .pipelines .stable_diffusion_xl .pipeline_stable_diffusion_xl_img2img import (
10+ StableDiffusionXLImg2ImgPipeline ,
11+ rescale_noise_cfg ,
12+ retrieve_latents ,
13+ retrieve_timesteps ,
14+ )
915from diffusers .utils import (
1016 deprecate ,
1117 is_torch_xla_available ,
1218 logging ,
1319)
14- from diffusers .image_processor import PipelineImageInput
15- from diffusers .pipelines .stable_diffusion_xl .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline , retrieve_latents , retrieve_timesteps , rescale_noise_cfg
16- from diffusers .pipelines .stable_diffusion_xl .pipeline_output import StableDiffusionXLPipelineOutput
17-
1820from diffusers .utils .torch_utils import randn_tensor
1921
22+
2023if is_torch_xla_available ():
2124 import torch_xla .core .xla_model as xm
2225
2932
3033
3134class MaskedStableDiffusionXLImg2ImgPipeline (StableDiffusionXLImg2ImgPipeline ):
32- debug_save = 0
35+ debug_save = 0
3336
3437 @torch .no_grad ()
3538 def __call__ (
@@ -79,7 +82,7 @@ def __call__(
7982 List [Image .Image ],
8083 List [np .ndarray ],
8184 ] = None ,
82- blur = 24 ,
85+ blur = 24 ,
8386 blur_compose = 4 ,
8487 sample_mode = 'sample' ,
8588 ** kwargs
@@ -283,7 +286,7 @@ def denoising_value_valid(dnv):
283286 )
284287
285288 # mean of the latent distribution
286- # it is multiplied by self.vae.config.scaling_factor
289+ # it is multiplied by self.vae.config.scaling_factor
287290 non_paint_latents = self .prepare_latents (
288291 original_image ,
289292 latent_timestep ,
@@ -292,7 +295,7 @@ def denoising_value_valid(dnv):
292295 prompt_embeds .dtype ,
293296 device ,
294297 generator ,
295- add_noise = False ,
298+ add_noise = False ,
296299 sample_mode = "argmax" )
297300
298301 if self .debug_save :
@@ -398,7 +401,7 @@ def denoising_value_valid(dnv):
398401
399402 shape = non_paint_latents .shape
400403 noise = randn_tensor (shape , generator = generator , device = device , dtype = latents .dtype )
401- # noisy latent code of input image at current step
404+ # noisy latent code of input image at current step
402405 orig_latents_t = non_paint_latents
403406 orig_latents_t = self .scheduler .add_noise (non_paint_latents , noise , t .unsqueeze (0 ))
404407
@@ -491,7 +494,7 @@ def denoising_value_valid(dnv):
491494 if self .debug_save :
492495 image_gen = self .latents_to_img (latents )
493496 image_gen [0 ].save ("from_latent.png" )
494-
497+
495498 if latent_mask is not None :
496499 # interpolate with latent mask
497500 latents = torch .lerp (non_paint_latents , latents , latent_mask )
0 commit comments