File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change 3131import transformers
3232from accelerate import Accelerator
3333from accelerate .logging import get_logger
34- from accelerate .utils import ProjectConfiguration , set_seed
34+ from accelerate .utils import DistributedDataParallelKwargs , ProjectConfiguration , set_seed
3535from datasets import load_dataset
3636from huggingface_hub import create_repo , upload_folder
3737from packaging import version
@@ -899,12 +899,13 @@ def main(args):
899899 logging_dir = Path (args .output_dir , args .logging_dir )
900900
901901 accelerator_project_config = ProjectConfiguration (project_dir = args .output_dir , logging_dir = logging_dir )
902-
902+ kwargs = DistributedDataParallelKwargs ( find_unused_parameters = True )
903903 accelerator = Accelerator (
904904 gradient_accumulation_steps = args .gradient_accumulation_steps ,
905905 mixed_precision = args .mixed_precision ,
906906 log_with = args .report_to ,
907907 project_config = accelerator_project_config ,
908+ kwargs_handlers = [kwargs ],
908909 )
909910
910911 # Disable AMP for MPS.
You can’t perform that action at this time.
0 commit comments