import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from ...models.controlnets.controlnet import ControlNetOutput
from ...models.controlnets.controlnet_union import ControlNetUnionModel
from ...models.modeling_utils import ModelMixin
from ...utils import logging


logger = logging.get_logger(__name__)


class MultiControlNetUnionModel(ModelMixin):
    r"""
    Multiple `ControlNetUnionModel` wrapper class for Multi-ControlNet-Union.

    This module is a wrapper for multiple instances of the `ControlNetUnionModel`. The `forward()` API is designed to
    be compatible with `ControlNetUnionModel`.

    Args:
        controlnets (`List[ControlNetUnionModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetUnionModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetUnionModel], Tuple[ControlNetUnionModel]]):
        super().__init__()
        # Registered as a `ModuleList` so the wrapped ControlNets are proper submodules of this wrapper.
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        control_type: List[torch.Tensor],
        control_type_idx: List[List[int]],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        r"""
        Run every wrapped `ControlNetUnionModel` and sum their residuals.

        `controlnet_cond`, `control_type`, `control_type_idx` and `conditioning_scale` are per-ControlNet lists. They
        are consumed in lockstep with `self.nets`, so iteration stops at the shortest of them. All other arguments
        are forwarded unchanged to every wrapped ControlNet.

        Args:
            sample (`torch.Tensor`):
                Forwarded as `sample` to every ControlNet.
            timestep (`torch.Tensor` or `float` or `int`):
                Forwarded as `timestep` to every ControlNet.
            encoder_hidden_states (`torch.Tensor`):
                Forwarded as `encoder_hidden_states` to every ControlNet.
            controlnet_cond (`List`):
                One entry per ControlNet; entry `i` is passed as `controlnet_cond` to `self.nets[i]`.
            control_type (`List[torch.Tensor]`):
                One entry per ControlNet; entry `i` is passed as `control_type` to `self.nets[i]`.
            control_type_idx (`List[List[int]]`):
                One entry per ControlNet; entry `i` is passed as `control_type_idx` to `self.nets[i]`.
            conditioning_scale (`List[float]`):
                One entry per ControlNet; entry `i` is passed as `conditioning_scale` to `self.nets[i]`.
            class_labels, timestep_cond, attention_mask, added_cond_kwargs, cross_attention_kwargs, guess_mode:
                Forwarded unchanged to every ControlNet.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`ControlNetOutput`] instead of a plain tuple.

        Returns:
            [`ControlNetOutput`] if `return_dict` is `True`, otherwise the tuple
            `(down_block_res_samples, mid_block_res_sample)` holding the element-wise sums over all ControlNets.

        Raises:
            ValueError: If no ControlNet was run because one of the per-ControlNet lists or `self.nets` is empty.
        """
        down_block_res_samples = None
        mid_block_res_sample = None

        per_net_inputs = zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
        for image, ctype, ctype_idx, scale, controlnet in per_net_inputs:
            # Always ask the sub-model for a plain tuple. With `return_dict=True` it would hand back a
            # `ControlNetOutput` object, and tuple-unpacking that object does not yield the two tensors.
            down_samples, mid_sample = controlnet(
                sample=sample,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=image,
                control_type=ctype,
                control_type_idx=ctype_idx,
                conditioning_scale=scale,
                class_labels=class_labels,
                timestep_cond=timestep_cond,
                attention_mask=attention_mask,
                added_cond_kwargs=added_cond_kwargs,
                cross_attention_kwargs=cross_attention_kwargs,
                guess_mode=guess_mode,
                return_dict=False,
            )

            # merge samples
            if down_block_res_samples is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                # Out-of-place add so the first ControlNet's output tensor is not mutated.
                mid_block_res_sample = mid_block_res_sample + mid_sample

        if down_block_res_samples is None:
            raise ValueError(
                "No ControlNetUnion was run: `controlnet_cond`, `control_type`, `control_type_idx`,"
                " `conditioning_scale` and the wrapped controlnets must all be non-empty."
            )

        if not return_dict:
            return down_block_res_samples, mid_block_res_sample

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )

    # Adapted from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained.
    # No longer a verbatim copy (it accepts `os.PathLike`), so the `# Copied from` marker is intentionally not used.
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Optional[Callable] = None,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~models.controlnets.multicontrolnet_union.MultiControlNetUnionModel.from_pretrained`] class method.

        The first ControlNet is saved to `save_directory`; every following one is saved to `save_directory` with the
        suffix `_1`, `_2`, ... appended.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful when in distributed training like
                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
                the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                need to replace `torch.save` by another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            variant (`str`, *optional*):
                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
        """
        # Normalise first: a `pathlib.Path` cannot be concatenated with the string suffix below.
        base_directory = os.fspath(save_directory)

        for idx, controlnet in enumerate(self.nets):
            suffix = "" if idx == 0 else f"_{idx}"
            controlnet.save_pretrained(
                base_directory + suffix,
                is_main_process=is_main_process,
                save_function=save_function,
                safe_serialization=safe_serialization,
                variant=variant,
            )

    @classmethod
    # Adapted from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained.
    # No longer a verbatim copy (`os.PathLike` support, corrected error message), so no `# Copied from` marker.
    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pretrained MultiControlNetUnion model from multiple pre-trained controlnet models.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you should first set it back in training mode with `model.train()`.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_path (`str` or `os.PathLike`):
                A path to a *directory* containing model weights saved using
                [`~models.controlnets.multicontrolnet_union.MultiControlNetUnionModel.save_pretrained`], e.g.,
                `./my_model_directory/controlnet`.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
                will be automatically derived from the model's weights.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
                GPU and the available CPU RAM if unset.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                setting this argument to `True` will raise an error.
            variant (`str`, *optional*):
                If specified load weights from `variant` filename, *e.g.* `pytorch_model.<variant>.bin`. `variant` is
                ignored when using `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.

        Raises:
            ValueError: If `pretrained_model_path` is not an existing directory, so that no ControlNet could be loaded.
        """
        # Normalise first: a `pathlib.Path` cannot be concatenated with the `_{idx}` suffix below.
        base_path = os.fspath(pretrained_model_path)

        idx = 0
        controlnets = []

        # load controlnet and append to list until no controlnet directory exists anymore
        # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with
        # `DiffusionPipeline.from_pretrained`
        # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`,
        # `./mydirectory/controlnet_2`, ...
        model_path_to_load = base_path
        while os.path.isdir(model_path_to_load):
            controlnet = ControlNetUnionModel.from_pretrained(model_path_to_load, **kwargs)
            controlnets.append(controlnet)

            idx += 1
            model_path_to_load = base_path + f"_{idx}"

        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")

        if not controlnets:
            # The first ControlNet lives in the un-suffixed directory, so that is the one that must exist.
            raise ValueError(
                f"No ControlNetUnions found under {os.path.dirname(base_path)}. Expected at least {base_path}."
            )

        return cls(controlnets)
ControlNetUnionModel): - raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.") + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetUnionModel(controlnet) self.register_modules( vae=vae, @@ -664,6 +671,7 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + control_mode=None, callback_on_step_end_tensor_inputs=None, ): if callback_on_step_end_tensor_inputs is not None and not all( @@ -721,46 +729,102 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetUnionModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." 
+ ) + # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, ControlNetUnionModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetUnionModel) - ): - self.check_image(image, prompt, prompt_embeds) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + if isinstance(controlnet, ControlNetUnionModel): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + elif isinstance(controlnet, MultiControlNetUnionModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + elif not all(isinstance(i, list) for i in image): + raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." 
+ ) + + for images_ in image: + for image_ in images_: + self.check_image(image_, prompt, prompt_embeds) else: assert False - if not isinstance(control_guidance_start, (tuple, list)): - control_guidance_start = [control_guidance_start] - - if not isinstance(control_guidance_end, (tuple, list)): - control_guidance_end = [control_guidance_end] + # Check `controlnet_conditioning_scale` + # TODO Update for https://github.com/huggingface/diffusers/pull/10723 + if isinstance(controlnet, ControlNetUnionModel): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif isinstance(controlnet, MultiControlNetUnionModel): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings is not supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False if len(control_guidance_start) != len(control_guidance_end): raise ValueError( f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." ) + if isinstance(controlnet, MultiControlNetUnionModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." 
+ ) + for start, end in zip(control_guidance_start, control_guidance_end): if start >= end: raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + f"control_guidance_start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + raise ValueError(f"control_guidance_start: {start} can't be smaller than 0.") if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + raise ValueError(f"control_guidance_end: {end} can't be larger than 1.0.") + + # Check `control_mode` + if isinstance(controlnet, ControlNetUnionModel): + if max(control_mode) >= controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.") + elif isinstance(controlnet, MultiControlNetUnionModel): + for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets): + if max(_control_mode) >= _controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.") + else: + assert False + + # Equal number of `image` and `control_mode` elements + if isinstance(controlnet, ControlNetUnionModel): + if len(image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_mode)") + elif isinstance(controlnet, MultiControlNetUnionModel): + if not all(isinstance(i, list) for i in control_mode): + raise ValueError( + "For multiple controlnets: elements of control_mode must be lists representing conditioning mode." 
+ ) + + elif sum(len(x) for x in image) != sum(len(x) for x in control_mode): + raise ValueError("Expected len(control_image) == len(control_mode)") + else: + assert False if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( @@ -936,7 +1000,7 @@ def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, - control_image: PipelineImageInput = None, + control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -963,7 +1027,7 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, - control_mode: Optional[Union[int, List[int]]] = None, + control_mode: Optional[Union[int, List[int], List[List[int]]]] = None, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, @@ -985,7 +1049,7 @@ def __call__( prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. - control_image (`PipelineImageInput`): + control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or @@ -1077,6 +1141,11 @@ def __call__( The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. 
+ control_mode (`int` or `List[int]` or `List[List[int]], *optional*): + The control condition types for the ControlNet. See the ControlNet's model card forinformation on the + available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list + where each ControlNet should have its corresponding control mode list. Should reflect the order of + conditions in control_image. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as @@ -1137,6 +1206,12 @@ def __call__( control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) if not isinstance(control_image, list): control_image = [control_image] @@ -1146,35 +1221,36 @@ def __call__( if not isinstance(control_mode, list): control_mode = [control_mode] - if len(control_image) != len(control_mode): - raise ValueError("Expected len(control_image) == len(control_type)") - - num_control_type = controlnet.config.num_control_type + if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) # 1. Check inputs - control_type = [0 for _ in range(num_control_type)] - # 1. Check inputs. 
Raise error if not correct - for _image, control_idx in zip(control_image, control_mode): - control_type[control_idx] = 1 - self.check_inputs( - prompt, - prompt_2, - _image, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - negative_pooled_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - ) + self.check_inputs( + prompt, + prompt_2, + control_image, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + control_mode, + callback_on_step_end_tensor_inputs, + ) - control_type = torch.Tensor(control_type) + if isinstance(controlnet, ControlNetUnionModel): + control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1) + for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets) + ] self._guidance_scale = guidance_scale self._clip_skip = clip_skip @@ -1192,7 +1268,11 @@ def __call__( device = self._execution_device - global_pool_conditions = controlnet.config.global_pool_conditions + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetUnionModel) + else controlnet.nets[0].config.global_pool_conditions + ) guess_mode = guess_mode or global_pool_conditions # 3.1 Encode input prompt @@ -1231,19 +1311,54 @@ def __call__( ) # 4. 
Prepare image - for idx, _ in enumerate(control_image): - control_image[idx] = self.prepare_image( - image=control_image[idx], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = control_image[idx].shape[-2:] + if isinstance(controlnet, ControlNetUnionModel): + control_images = [] + + for image_ in control_image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + + elif isinstance(controlnet, MultiControlNetUnionModel): + control_images = [] + + for control_image_ in control_image: + images = [] + + for image_ in control_image_: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + control_images.append(images) + + control_image = control_images + height, width = control_image[0][0].shape[-2:] + + else: + assert False # 5. 
Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( @@ -1278,10 +1393,11 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): - controlnet_keep.append( - 1.0 - - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end) - ) + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps) # 7.2 Prepare added time ids & embeddings original_size = original_size or (height, width) @@ -1346,11 +1462,20 @@ def __call__( is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") - control_type = ( - control_type.reshape(1, -1) - .to(device, dtype=prompt_embeds.dtype) - .repeat(batch_size * num_images_per_prompt * 2, 1) - ) + if isinstance(controlnet, ControlNetUnionModel): + control_type = ( + control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(batch_size * num_images_per_prompt * 2, 1) + ) + if isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + _control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(batch_size * num_images_per_prompt * 2, 1) + for _control_type in control_type + ] + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: