File tree Expand file tree Collapse file tree 3 files changed +7
-6
lines changed Expand file tree Collapse file tree 3 files changed +7
-6
lines changed Original file line number Diff line number Diff line change 1
- __version__ = "21.09rc1 "
1
+ __version__ = "21.09 "
Original file line number Diff line number Diff line change @@ -370,7 +370,7 @@ class DistributedDataParallelAPEXEngine(DistributedDataParallelEngine):
370
370
address: address to use for backend.
371
371
port: port to use for backend.
372
372
sync_bn: boolean flag for batchnorm synchonization during disributed training.
373
- if True, applies PyTorch `convert_sync_batchnorm `_ to the model for native torch
373
+ if True, applies Apex `convert_syncbn_model `_ to the model for native torch
374
374
distributed only. Default, False.
375
375
ddp_kwargs: parameters for `apex.parallel.DistributedDataParallel`.
376
376
More info here:
@@ -439,9 +439,8 @@ def get_engine(self):
439
439
stages:
440
440
...
441
441
442
- .. _convert_sync_batchnorm:
443
- https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
444
- torch.nn.SyncBatchNorm.convert_sync_batchnorm
442
+ .. _`convert_syncbn_model`:
443
+ https://nvidia.github.io/apex/parallel.html#apex.parallel.convert_syncbn_model
445
444
"""
446
445
447
446
def __init__ (
@@ -501,7 +500,7 @@ def init_components(
501
500
model = model_fn ()
502
501
model = self .sync_device (model )
503
502
if self ._sync_bn :
504
- model = nn . SyncBatchNorm . convert_sync_batchnorm (model )
503
+ model = apex . parallel . convert_syncbn_model (model )
505
504
506
505
criterion = criterion_fn ()
507
506
criterion = self .sync_device (criterion )
Original file line number Diff line number Diff line change @@ -340,6 +340,7 @@ def __init__(
340
340
self ,
341
341
address : str = None ,
342
342
port : Union [str , int ] = None ,
343
+ sync_bn : bool = False ,
343
344
ddp_kwargs : Dict [str , Any ] = None ,
344
345
process_group_kwargs : Dict [str , Any ] = None ,
345
346
scaler_kwargs : Dict [str , Any ] = None ,
@@ -348,6 +349,7 @@ def __init__(
348
349
super ().__init__ (
349
350
address = address ,
350
351
port = port ,
352
+ sync_bn = sync_bn ,
351
353
ddp_kwargs = ddp_kwargs ,
352
354
process_group_kwargs = process_group_kwargs ,
353
355
)
You can’t perform that action at this time.
0 commit comments