diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 39ceadb5acef..f6186da260ad 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -29,7 +29,14 @@ StableDiffusionXLControlNetPipeline, ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline -from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline +from .flux import ( + FluxControlNetImg2ImgPipeline, + FluxControlNetInpaintPipeline, + FluxControlNetPipeline, + FluxImg2ImgPipeline, + FluxInpaintPipeline, + FluxPipeline, +) from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -128,6 +135,7 @@ ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), ("flux", FluxImg2ImgPipeline), + ("flux-controlnet", FluxControlNetImg2ImgPipeline), ] ) @@ -143,6 +151,7 @@ ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("flux", FluxInpaintPipeline), + ("flux-controlnet", FluxControlNetInpaintPipeline), ] ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 481994903d3f..11b71b1cbece 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -729,7 +729,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] @@ -763,7 +763,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image_.shape[-2:] @@ -840,12 +840,10 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.tensor([guidance_scale], device=device) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None + guidance = ( + torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None # controlnet controlnet_block_samples, controlnet_single_block_samples = self.controlnet( @@ -863,6 +861,11 @@ def __call__( return_dict=False, ) + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 72803b180c34..deeb9e3f546a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -767,7 +767,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] @@ -798,7 +798,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image_.shape[-2:] diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index d43acdf38ea5..e763200155f6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -899,7 +899,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] @@ -933,7 +933,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=dtype, + dtype=self.vae.dtype, ) height, width = control_image_.shape[-2:]