Skip to content

Commit 3f2f093

Browse files
Fix Conditioning masks on 3d latents. (comfyanonymous#9506)
1 parent cc6d764 commit 3f2f093

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

comfy/samplers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import comfy.patcher_extension
1818
import comfy.hooks
1919
import comfy.context_windows
20+
import comfy.utils
2021
import scipy.stats
2122
import numpy
2223

@@ -61,15 +62,15 @@ def get_area_and_mult(conds, x_in, timestep_in):
6162
if "mask_strength" in conds:
6263
mask_strength = conds["mask_strength"]
6364
mask = conds['mask']
64-
assert (mask.shape[1:] == x_in.shape[2:])
65+
# assert (mask.shape[1:] == x_in.shape[2:])
6566

6667
mask = mask[:input_x.shape[0]]
6768
if area is not None:
6869
for i in range(len(dims)):
6970
mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
7071

7172
mask = mask * mask_strength
72-
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
73+
mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1))
7374
else:
7475
mask = torch.ones_like(input_x)
7576
mult = mask * strength
@@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
553554
if len(mask.shape) == len(dims):
554555
mask = mask.unsqueeze(0)
555556
if mask.shape[1:] != dims:
556-
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
557+
if mask.ndim < 4:
558+
mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
559+
else:
560+
mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
557561

558562
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
559563
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)

0 commit comments

Comments
 (0)