Skip to content

Commit c4e77f1

Browse files
xduzhangjiayusayakpaulyiyixuxu
committed
fix bugs for sd3 controlnet training (#9489)
Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent 74c7462 commit c4e77f1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

examples/controlnet/train_controlnet_sd3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import transformers
3232
from accelerate import Accelerator
3333
from accelerate.logging import get_logger
34-
from accelerate.utils import ProjectConfiguration, set_seed
34+
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
3535
from datasets import load_dataset
3636
from huggingface_hub import create_repo, upload_folder
3737
from packaging import version
@@ -899,12 +899,13 @@ def main(args):
899899
logging_dir = Path(args.output_dir, args.logging_dir)
900900

901901
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
902-
902+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
903903
accelerator = Accelerator(
904904
gradient_accumulation_steps=args.gradient_accumulation_steps,
905905
mixed_precision=args.mixed_precision,
906906
log_with=args.report_to,
907907
project_config=accelerator_project_config,
908+
kwargs_handlers=[kwargs],
908909
)
909910

910911
# Disable AMP for MPS.

0 commit comments

Comments
 (0)