@@ -144,14 +144,10 @@ def forward(self, x, sigma, uncond, cond, cond_scale):
144
144
def add_noise (sample : torch .Tensor , noise_amt : float ):
145
145
return sample + torch .randn (sample .shape , device = sample .device ) * noise_amt
146
146
147
- def get_output_folder (output_path ,batch_folder = None ):
148
- yearMonth = time .strftime ('%Y-%m/' )
149
- out_path = os .path .join (output_path ,yearMonth )
147
+ def get_output_folder (output_path , batch_folder ):
148
+ out_path = os .path .join (output_path ,time .strftime ('%Y-%m/' ))
150
149
if batch_folder != "" :
151
- out_path = os .path .join (out_path ,batch_folder )
152
- # we will also make sure the path suffix is a slash if linux and a backslash if windows
153
- if out_path [- 1 ] != os .path .sep :
154
- out_path += os .path .sep
150
+ out_path = os .path .join (out_path , batch_folder )
155
151
os .makedirs (out_path , exist_ok = True )
156
152
return out_path
157
153
@@ -184,10 +180,9 @@ def load_mask(path, shape):
184
180
return mask
185
181
186
182
def maintain_colors (prev_img , color_match_sample , mode ):
187
- color_coherence = 'Match Frame 0 RGB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
188
- if color_coherence == 'Match Frame 0 RGB' :
183
+ if mode == 'Match Frame 0 RGB' :
189
184
return match_histograms (prev_img , color_match_sample , multichannel = True )
190
- elif color_coherence == 'Match Frame 0 HSV' :
185
+ elif mode == 'Match Frame 0 HSV' :
191
186
prev_img_hsv = cv2 .cvtColor (prev_img , cv2 .COLOR_RGB2HSV )
192
187
color_match_hsv = cv2 .cvtColor (color_match_sample , cv2 .COLOR_RGB2HSV )
193
188
matched_hsv = match_histograms (prev_img_hsv , color_match_hsv , multichannel = True )
@@ -316,57 +311,56 @@ def generate(args, return_latent=False, return_sample=False, return_c=False):
316
311
with torch .no_grad ():
317
312
with precision_scope ("cuda" ):
318
313
with model .ema_scope ():
319
- for n in range (args .n_samples ):
320
- for prompts in data :
321
- uc = None
322
- if args .scale != 1.0 :
323
- uc = model .get_learned_conditioning (batch_size * ["" ])
324
- if isinstance (prompts , tuple ):
325
- prompts = list (prompts )
326
- c = model .get_learned_conditioning (prompts )
327
-
328
- if args .init_c != None :
329
- c = args .init_c
330
-
331
- if args .sampler in ["klms" ,"dpm2" ,"dpm2_ancestral" ,"heun" ,"euler" ,"euler_ancestral" ]:
332
- samples = sampler_fn (
333
- c = c ,
334
- uc = uc ,
335
- args = args ,
336
- model_wrap = model_wrap ,
337
- init_latent = init_latent ,
338
- t_enc = t_enc ,
339
- device = device ,
340
- cb = callback )
314
+ for prompts in data :
315
+ uc = None
316
+ if args .scale != 1.0 :
317
+ uc = model .get_learned_conditioning (batch_size * ["" ])
318
+ if isinstance (prompts , tuple ):
319
+ prompts = list (prompts )
320
+ c = model .get_learned_conditioning (prompts )
321
+
322
+ if args .init_c != None :
323
+ c = args .init_c
324
+
325
+ if args .sampler in ["klms" ,"dpm2" ,"dpm2_ancestral" ,"heun" ,"euler" ,"euler_ancestral" ]:
326
+ samples = sampler_fn (
327
+ c = c ,
328
+ uc = uc ,
329
+ args = args ,
330
+ model_wrap = model_wrap ,
331
+ init_latent = init_latent ,
332
+ t_enc = t_enc ,
333
+ device = device ,
334
+ cb = callback )
335
+ else :
336
+ # args.sampler == 'plms' or args.sampler == 'ddim':
337
+ if init_latent is not None and args .strength > 0 :
338
+ z_enc = sampler .stochastic_encode (init_latent , torch .tensor ([t_enc ]* batch_size ).to (device ))
341
339
else :
342
- # args.sampler == 'plms' or args.sampler == 'ddim':
343
- if init_latent is not None and args .strength > 0 :
344
- z_enc = sampler .stochastic_encode (init_latent , torch .tensor ([t_enc ]* batch_size ).to (device ))
345
- else :
346
- z_enc = torch .randn ([args .n_samples , args .C , args .H // args .f , args .W // args .f ], device = device )
347
- samples = sampler .decode (z_enc ,
348
- c ,
349
- t_enc ,
350
- unconditional_guidance_scale = args .scale ,
351
- unconditional_conditioning = uc ,
352
- img_callback = callback )
353
-
354
- if return_latent :
355
- results .append (samples .clone ())
356
-
357
- x_samples = model .decode_first_stage (samples )
358
- if return_sample :
359
- results .append (x_samples .clone ())
360
-
361
- x_samples = torch .clamp ((x_samples + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
362
-
363
- if return_c :
364
- results .append (c .clone ())
365
-
366
- for x_sample in x_samples :
367
- x_sample = 255. * rearrange (x_sample .cpu ().numpy (), 'c h w -> h w c' )
368
- image = Image .fromarray (x_sample .astype (np .uint8 ))
369
- results .append (image )
340
+ z_enc = torch .randn ([args .n_samples , args .C , args .H // args .f , args .W // args .f ], device = device )
341
+ samples = sampler .decode (z_enc ,
342
+ c ,
343
+ t_enc ,
344
+ unconditional_guidance_scale = args .scale ,
345
+ unconditional_conditioning = uc ,
346
+ img_callback = callback )
347
+
348
+ if return_latent :
349
+ results .append (samples .clone ())
350
+
351
+ x_samples = model .decode_first_stage (samples )
352
+ if return_sample :
353
+ results .append (x_samples .clone ())
354
+
355
+ x_samples = torch .clamp ((x_samples + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
356
+
357
+ if return_c :
358
+ results .append (c .clone ())
359
+
360
+ for x_sample in x_samples :
361
+ x_sample = 255. * rearrange (x_sample .cpu ().numpy (), 'c h w -> h w c' )
362
+ image = Image .fromarray (x_sample .astype (np .uint8 ))
363
+ results .append (image )
370
364
return results
371
365
372
366
def sample_from_cv2 (sample : np .ndarray ) -> torch .Tensor :
@@ -685,7 +679,7 @@ def DeforumArgs():
685
679
seed_behavior = "iter" #@param ["iter","fixed","random"]
686
680
687
681
#@markdown **Grid Settings**
688
- make_grid = True #@param {type:"boolean"}
682
+ make_grid = False #@param {type:"boolean"}
689
683
grid_rows = 2 #@param
690
684
691
685
precision = 'autocast'
0 commit comments