@@ -35,6 +35,7 @@ def __init__(
35
35
device = None , dtype = None , operations = None
36
36
):
37
37
super ().__init__ ()
38
+ self .additional_in_dim = additional_in_dim
38
39
self .img_in = operations .Linear (in_dim + additional_in_dim , dim , device = device , dtype = dtype )
39
40
self .controlnet_blocks = torch .nn .ModuleList (
40
41
[
@@ -44,7 +45,7 @@ def __init__(
44
45
)
45
46
46
47
def process_input_latent_image (self , latent_image ):
47
- latent_image = comfy .latent_formats .Wan21 ().process_in (latent_image )
48
+ latent_image [:, : 16 ] = comfy .latent_formats .Wan21 ().process_in (latent_image [:, : 16 ] )
48
49
patch_size = 2
49
50
hidden_states = comfy .ldm .common_dit .pad_to_patch_size (latent_image , (1 , patch_size , patch_size ))
50
51
orig_shape = hidden_states .shape
@@ -73,19 +74,33 @@ def load_model_patch(self, name):
73
74
sd = comfy .utils .load_torch_file (model_patch_path , safe_load = True )
74
75
dtype = comfy .utils .weight_dtype (sd )
75
76
# TODO: this node will work with more types of model patches
76
- model = QwenImageBlockWiseControlNet (device = comfy .model_management .unet_offload_device (), dtype = dtype , operations = comfy .ops .manual_cast )
77
+ additional_in_dim = sd ["img_in.weight" ].shape [1 ] - 64
78
+ model = QwenImageBlockWiseControlNet (additional_in_dim = additional_in_dim , device = comfy .model_management .unet_offload_device (), dtype = dtype , operations = comfy .ops .manual_cast )
77
79
model .load_state_dict (sd )
78
80
model = comfy .model_patcher .ModelPatcher (model , load_device = comfy .model_management .get_torch_device (), offload_device = comfy .model_management .unet_offload_device ())
79
81
return (model ,)
80
82
81
83
82
84
class DiffSynthCnetPatch :
83
- def __init__ (self , model_patch , vae , image , strength ):
84
- self .encoded_image = model_patch .model .process_input_latent_image (vae .encode (image ))
85
+ def __init__ (self , model_patch , vae , image , strength , mask = None ):
85
86
self .model_patch = model_patch
86
87
self .vae = vae
87
88
self .image = image
88
89
self .strength = strength
90
+ self .mask = mask
91
+ self .encoded_image = model_patch .model .process_input_latent_image (self .encode_latent_cond (image ))
92
+
93
+ def encode_latent_cond (self , image ):
94
+ latent_image = self .vae .encode (image )
95
+ if self .model_patch .model .additional_in_dim > 0 :
96
+ if self .mask is None :
97
+ mask_ = torch .ones_like (latent_image )[:, :self .model_patch .model .additional_in_dim // 4 ]
98
+ else :
99
+ mask_ = comfy .utils .common_upscale (self .mask .mean (dim = 1 , keepdim = True ), latent_image .shape [- 1 ], latent_image .shape [- 2 ], "bilinear" , "none" )
100
+
101
+ return torch .cat ([latent_image , mask_ ], dim = 1 )
102
+ else :
103
+ return latent_image
89
104
90
105
def __call__ (self , kwargs ):
91
106
x = kwargs .get ("x" )
@@ -95,7 +110,7 @@ def __call__(self, kwargs):
95
110
spacial_compression = self .vae .spacial_compression_encode ()
96
111
image_scaled = comfy .utils .common_upscale (self .image .movedim (- 1 , 1 ), x .shape [- 1 ] * spacial_compression , x .shape [- 2 ] * spacial_compression , "area" , "center" )
97
112
loaded_models = comfy .model_management .loaded_models (only_currently_used = True )
98
- self .encoded_image = self .model_patch .model .process_input_latent_image (self .vae . encode (image_scaled .movedim (1 , - 1 )))
113
+ self .encoded_image = self .model_patch .model .process_input_latent_image (self .encode_latent_cond (image_scaled .movedim (1 , - 1 )))
99
114
comfy .model_management .load_models_gpu (loaded_models )
100
115
101
116
img = img + (self .model_patch .model .control_block (img , self .encoded_image .to (img .dtype ), block_index ) * self .strength )
@@ -118,17 +133,25 @@ def INPUT_TYPES(s):
118
133
"vae" : ("VAE" ,),
119
134
"image" : ("IMAGE" ,),
120
135
"strength" : ("FLOAT" , {"default" : 1.0 , "min" : - 10.0 , "max" : 10.0 , "step" : 0.01 }),
121
- }}
136
+ },
137
+ "optional" : {"mask" : ("MASK" ,)}}
122
138
RETURN_TYPES = ("MODEL" ,)
123
139
FUNCTION = "diffsynth_controlnet"
124
140
EXPERIMENTAL = True
125
141
126
142
CATEGORY = "advanced/loaders/qwen"
127
143
128
- def diffsynth_controlnet (self , model , model_patch , vae , image , strength ):
144
+ def diffsynth_controlnet (self , model , model_patch , vae , image , strength , mask = None ):
129
145
model_patched = model .clone ()
130
146
image = image [:, :, :, :3 ]
131
- model_patched .set_model_double_block_patch (DiffSynthCnetPatch (model_patch , vae , image , strength ))
147
+ if mask is not None :
148
+ if mask .ndim == 3 :
149
+ mask = mask .unsqueeze (1 )
150
+ if mask .ndim == 4 :
151
+ mask = mask .unsqueeze (2 )
152
+ mask = 1.0 - mask
153
+
154
+ model_patched .set_model_double_block_patch (DiffSynthCnetPatch (model_patch , vae , image , strength , mask ))
132
155
return (model_patched ,)
133
156
134
157
0 commit comments