|
17 | 17 | import comfy.patcher_extension
|
18 | 18 | import comfy.hooks
|
19 | 19 | import comfy.context_windows
|
| 20 | +import comfy.utils |
20 | 21 | import scipy.stats
|
21 | 22 | import numpy
|
22 | 23 |
|
@@ -61,15 +62,15 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
61 | 62 | if "mask_strength" in conds:
|
62 | 63 | mask_strength = conds["mask_strength"]
|
63 | 64 | mask = conds['mask']
|
64 |
| - assert (mask.shape[1:] == x_in.shape[2:]) |
| 65 | + # assert (mask.shape[1:] == x_in.shape[2:]) |
65 | 66 |
|
66 | 67 | mask = mask[:input_x.shape[0]]
|
67 | 68 | if area is not None:
|
68 | 69 | for i in range(len(dims)):
|
69 | 70 | mask = mask.narrow(i + 1, area[len(dims) + i], area[i])
|
70 | 71 |
|
71 | 72 | mask = mask * mask_strength
|
72 |
| - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) |
| 73 | + mask = mask.unsqueeze(1).repeat((input_x.shape[0] // mask.shape[0], input_x.shape[1]) + (1, ) * (mask.ndim - 1)) |
73 | 74 | else:
|
74 | 75 | mask = torch.ones_like(input_x)
|
75 | 76 | mult = mask * strength
|
@@ -553,7 +554,10 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
553 | 554 | if len(mask.shape) == len(dims):
|
554 | 555 | mask = mask.unsqueeze(0)
|
555 | 556 | if mask.shape[1:] != dims:
|
556 |
| - mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1) |
| 557 | + if mask.ndim < 4: |
| 558 | + mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1) |
| 559 | + else: |
| 560 | + mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none') |
557 | 561 |
|
558 | 562 | if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
|
559 | 563 | bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
|
0 commit comments