Skip to content

Commit 5776e29

Browse files
Merge pull request CompVis#35 from enzymezoo-code/inpainting_1.0
Inpainting 1.0
2 parents 975ed37 + 15f1c48 commit 5776e29

File tree

1 file changed

+55
-61
lines changed

1 file changed

+55
-61
lines changed

Deforum_Stable_Diffusion.py

Lines changed: 55 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,10 @@ def forward(self, x, sigma, uncond, cond, cond_scale):
144144
def add_noise(sample: torch.Tensor, noise_amt: float):
145145
return sample + torch.randn(sample.shape, device=sample.device) * noise_amt
146146

147-
def get_output_folder(output_path,batch_folder=None):
148-
yearMonth = time.strftime('%Y-%m/')
149-
out_path = os.path.join(output_path,yearMonth)
147+
def get_output_folder(output_path, batch_folder):
148+
out_path = os.path.join(output_path,time.strftime('%Y-%m/'))
150149
if batch_folder != "":
151-
out_path = os.path.join(out_path,batch_folder)
152-
# we will also make sure the path suffix is a slash if linux and a backslash if windows
153-
if out_path[-1] != os.path.sep:
154-
out_path += os.path.sep
150+
out_path = os.path.join(out_path, batch_folder)
155151
os.makedirs(out_path, exist_ok=True)
156152
return out_path
157153

@@ -184,10 +180,9 @@ def load_mask(path, shape):
184180
return mask
185181

186182
def maintain_colors(prev_img, color_match_sample, mode):
187-
color_coherence = 'Match Frame 0 RGB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
188-
if color_coherence == 'Match Frame 0 RGB':
183+
if mode == 'Match Frame 0 RGB':
189184
return match_histograms(prev_img, color_match_sample, multichannel=True)
190-
elif color_coherence == 'Match Frame 0 HSV':
185+
elif mode == 'Match Frame 0 HSV':
191186
prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)
192187
color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)
193188
matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)
@@ -316,57 +311,56 @@ def generate(args, return_latent=False, return_sample=False, return_c=False):
316311
with torch.no_grad():
317312
with precision_scope("cuda"):
318313
with model.ema_scope():
319-
for n in range(args.n_samples):
320-
for prompts in data:
321-
uc = None
322-
if args.scale != 1.0:
323-
uc = model.get_learned_conditioning(batch_size * [""])
324-
if isinstance(prompts, tuple):
325-
prompts = list(prompts)
326-
c = model.get_learned_conditioning(prompts)
327-
328-
if args.init_c != None:
329-
c = args.init_c
330-
331-
if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]:
332-
samples = sampler_fn(
333-
c=c,
334-
uc=uc,
335-
args=args,
336-
model_wrap=model_wrap,
337-
init_latent=init_latent,
338-
t_enc=t_enc,
339-
device=device,
340-
cb=callback)
314+
for prompts in data:
315+
uc = None
316+
if args.scale != 1.0:
317+
uc = model.get_learned_conditioning(batch_size * [""])
318+
if isinstance(prompts, tuple):
319+
prompts = list(prompts)
320+
c = model.get_learned_conditioning(prompts)
321+
322+
if args.init_c != None:
323+
c = args.init_c
324+
325+
if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]:
326+
samples = sampler_fn(
327+
c=c,
328+
uc=uc,
329+
args=args,
330+
model_wrap=model_wrap,
331+
init_latent=init_latent,
332+
t_enc=t_enc,
333+
device=device,
334+
cb=callback)
335+
else:
336+
# args.sampler == 'plms' or args.sampler == 'ddim':
337+
if init_latent is not None and args.strength > 0:
338+
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
341339
else:
342-
# args.sampler == 'plms' or args.sampler == 'ddim':
343-
if init_latent is not None and args.strength > 0:
344-
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
345-
else:
346-
z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)
347-
samples = sampler.decode(z_enc,
348-
c,
349-
t_enc,
350-
unconditional_guidance_scale=args.scale,
351-
unconditional_conditioning=uc,
352-
img_callback=callback)
353-
354-
if return_latent:
355-
results.append(samples.clone())
356-
357-
x_samples = model.decode_first_stage(samples)
358-
if return_sample:
359-
results.append(x_samples.clone())
360-
361-
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
362-
363-
if return_c:
364-
results.append(c.clone())
365-
366-
for x_sample in x_samples:
367-
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
368-
image = Image.fromarray(x_sample.astype(np.uint8))
369-
results.append(image)
340+
z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)
341+
samples = sampler.decode(z_enc,
342+
c,
343+
t_enc,
344+
unconditional_guidance_scale=args.scale,
345+
unconditional_conditioning=uc,
346+
img_callback=callback)
347+
348+
if return_latent:
349+
results.append(samples.clone())
350+
351+
x_samples = model.decode_first_stage(samples)
352+
if return_sample:
353+
results.append(x_samples.clone())
354+
355+
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
356+
357+
if return_c:
358+
results.append(c.clone())
359+
360+
for x_sample in x_samples:
361+
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
362+
image = Image.fromarray(x_sample.astype(np.uint8))
363+
results.append(image)
370364
return results
371365

372366
def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:
@@ -685,7 +679,7 @@ def DeforumArgs():
685679
seed_behavior = "iter" #@param ["iter","fixed","random"]
686680

687681
#@markdown **Grid Settings**
688-
make_grid = True #@param {type:"boolean"}
682+
make_grid = False #@param {type:"boolean"}
689683
grid_rows = 2 #@param
690684

691685
precision = 'autocast'

0 commit comments

Comments
 (0)