@@ -84,8 +84,8 @@ def __call__(
8484 ] = None ,
8585 blur = 24 ,
8686 blur_compose = 4 ,
87- sample_mode = ' sample' ,
88- ** kwargs
87+ sample_mode = " sample" ,
88+ ** kwargs ,
8989 ):
9090 r"""
9191 The call function to the pipeline for generation.
@@ -174,7 +174,6 @@ def __call__(
174174 "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`" ,
175175 )
176176
177-
178177 # 0. Check inputs. Raise error if not correct
179178 self .check_inputs (
180179 prompt ,
@@ -249,7 +248,6 @@ def __call__(
249248 clip_skip = self .clip_skip ,
250249 )
251250
252-
253251 # 3. Preprocess image
254252 input_image = image if image is not None else original_image
255253 image = self .image_processor .preprocess (input_image )
@@ -282,25 +280,26 @@ def denoising_value_valid(dnv):
282280 device ,
283281 generator ,
284282 add_noise ,
285- sample_mode = sample_mode
283+ sample_mode = sample_mode ,
286284 )
287285
288286 # mean of the latent distribution
289287 # it is multiplied by self.vae.config.scaling_factor
290288 non_paint_latents = self .prepare_latents (
291- original_image ,
292- latent_timestep ,
293- batch_size ,
294- num_images_per_prompt ,
295- prompt_embeds .dtype ,
296- device ,
297- generator ,
298- add_noise = False ,
299- sample_mode = "argmax" )
289+ original_image ,
290+ latent_timestep ,
291+ batch_size ,
292+ num_images_per_prompt ,
293+ prompt_embeds .dtype ,
294+ device ,
295+ generator ,
296+ add_noise = False ,
297+ sample_mode = "argmax" ,
298+ )
300299
301300 if self .debug_save :
302301 init_img_from_latents = self .latents_to_img (non_paint_latents )
303- init_img_from_latents [0 ].save (' non_paint_latents.png' )
302+ init_img_from_latents [0 ].save (" non_paint_latents.png" )
304303 # 6. create latent mask
305304 latent_mask = self ._make_latent_mask (latents , mask )
306305
@@ -359,7 +358,6 @@ def denoising_value_valid(dnv):
359358 self .do_classifier_free_guidance ,
360359 )
361360
362-
363361 # 10. Denoising loop
364362 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
365363
@@ -406,15 +404,14 @@ def denoising_value_valid(dnv):
406404 orig_latents_t = self .scheduler .add_noise (non_paint_latents , noise , t .unsqueeze (0 ))
407405
408406 # orig_latents_t (1 - latent_mask) + latents * latent_mask
409- latents = torch .lerp (orig_latents_t , latents , latent_mask )
407+ latents = torch .lerp (orig_latents_t , latents , latent_mask )
410408
411409 if self .debug_save :
412410 img1 = self .latents_to_img (latents )
413411 t_str = str (t .int ().item ())
414412 for i in range (3 - len (t_str )):
415- t_str = '0' + t_str
416- img1 [0 ].save (f'step{ t_str } .png' )
417-
413+ t_str = "0" + t_str
414+ img1 [0 ].save (f"step{ t_str } .png" )
418415
419416 # expand the latents if we are doing classifier free guidance
420417 latent_model_input = torch .cat ([latents ] * 2 ) if self .do_classifier_free_guidance else latents
@@ -444,7 +441,6 @@ def denoising_value_valid(dnv):
444441 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
445442 noise_pred = rescale_noise_cfg (noise_pred , noise_pred_text , guidance_rescale = self .guidance_rescale )
446443
447-
448444 # compute the previous noisy sample x_t -> x_t-1
449445 latents_dtype = latents .dtype
450446 latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs , return_dict = False )[0 ]
@@ -501,7 +497,7 @@ def denoising_value_valid(dnv):
501497
502498 latents = self .denormalize (latents )
503499 image = self .vae .decode (latents , return_dict = False )[0 ]
504- m = mask_compose .permute (2 ,0 , 1 ).unsqueeze (0 ).to (image )
500+ m = mask_compose .permute (2 , 0 , 1 ).unsqueeze (0 ).to (image )
505501 img_compose = m * image + (1 - m ) * original_image .to (image )
506502 image = img_compose
507503 # cast back to fp16 if needed
@@ -519,7 +515,6 @@ def denoising_value_valid(dnv):
519515 # Offload all models
520516 self .maybe_free_model_hooks ()
521517
522-
523518 if not return_dict :
524519 return (image ,)
525520
@@ -551,12 +546,17 @@ def _make_latent_mask(self, latents, mask):
551546 return latent_mask
552547
553548 def prepare_latents (
554- self , image , timestep , batch_size , num_images_per_prompt , dtype , device ,
549+ self ,
550+ image ,
551+ timestep ,
552+ batch_size ,
553+ num_images_per_prompt ,
554+ dtype ,
555+ device ,
555556 generator = None ,
556557 add_noise = True ,
557- sample_mode : str = "sample"
558+ sample_mode : str = "sample" ,
558559 ):
559-
560560 if not isinstance (image , (torch .Tensor , Image .Image , list )):
561561 raise ValueError (
562562 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is { type (image )} "
@@ -573,7 +573,7 @@ def prepare_latents(
573573
574574 if image .shape [1 ] == 4 :
575575 init_latents = image
576- elif sample_mode == ' random' :
576+ elif sample_mode == " random" :
577577 height , width = image .shape [- 2 :]
578578 num_channels_latents = self .unet .config .in_channels
579579 latents = self .random_latents (
@@ -600,7 +600,9 @@ def prepare_latents(
600600
601601 elif isinstance (generator , list ):
602602 init_latents = [
603- retrieve_latents (self .vae .encode (image [i : i + 1 ]), generator = generator [i ], sample_mode = sample_mode )
603+ retrieve_latents (
604+ self .vae .encode (image [i : i + 1 ]), generator = generator [i ], sample_mode = sample_mode
605+ )
604606 for i in range (batch_size )
605607 ]
606608 init_latents = torch .cat (init_latents , dim = 0 )
@@ -661,9 +663,7 @@ def denormalize(self, latents):
661663 latents_mean = (
662664 torch .tensor (self .vae .config .latents_mean ).view (1 , 4 , 1 , 1 ).to (latents .device , latents .dtype )
663665 )
664- latents_std = (
665- torch .tensor (self .vae .config .latents_std ).view (1 , 4 , 1 , 1 ).to (latents .device , latents .dtype )
666- )
666+ latents_std = torch .tensor (self .vae .config .latents_std ).view (1 , 4 , 1 , 1 ).to (latents .device , latents .dtype )
667667 latents = latents * latents_std / self .vae .config .scaling_factor + latents_mean
668668 else :
669669 latents = latents / self .vae .config .scaling_factor
@@ -673,10 +673,10 @@ def denormalize(self, latents):
673673 def latents_to_img (self , latents ):
674674 l1 = self .denormalize (latents )
675675 img1 = self .vae .decode (l1 , return_dict = False )[0 ]
676- img1 = self .image_processor .postprocess (img1 , output_type = ' pil' , do_denormalize = [True ])
676+ img1 = self .image_processor .postprocess (img1 , output_type = " pil" , do_denormalize = [True ])
677677 return img1
678678
679679 def blur_mask (self , pil_mask , blur ):
680680 mask_blur = pil_mask .filter (ImageFilter .GaussianBlur (radius = blur ))
681681 mask_blur = np .array (mask_blur )
682- return torch .from_numpy (np .tile (mask_blur / mask_blur .max (), (3 , 1 , 1 )).transpose (1 ,2 , 0 ))
682+ return torch .from_numpy (np .tile (mask_blur / mask_blur .max (), (3 , 1 , 1 )).transpose (1 , 2 , 0 ))
0 commit comments