Skip to content

Commit 3f18d6b

Browse files
authored
Engines bn fix (#1310)
* fix * fix * version upadte
1 parent bd3a0d7 commit 3f18d6b

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

catalyst/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "21.09rc1"
1+
__version__ = "21.09"

catalyst/engines/apex.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ class DistributedDataParallelAPEXEngine(DistributedDataParallelEngine):
370370
address: address to use for backend.
371371
port: port to use for backend.
372372
sync_bn: boolean flag for batchnorm synchonization during disributed training.
373-
if True, applies PyTorch `convert_sync_batchnorm`_ to the model for native torch
373+
if True, applies Apex `convert_syncbn_model`_ to the model for native torch
374374
distributed only. Default, False.
375375
ddp_kwargs: parameters for `apex.parallel.DistributedDataParallel`.
376376
More info here:
@@ -439,9 +439,8 @@ def get_engine(self):
439439
stages:
440440
...
441441
442-
.. _convert_sync_batchnorm:
443-
https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
444-
torch.nn.SyncBatchNorm.convert_sync_batchnorm
442+
.. _`convert_syncbn_model`:
443+
https://nvidia.github.io/apex/parallel.html#apex.parallel.convert_syncbn_model
445444
"""
446445

447446
def __init__(
@@ -501,7 +500,7 @@ def init_components(
501500
model = model_fn()
502501
model = self.sync_device(model)
503502
if self._sync_bn:
504-
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
503+
model = apex.parallel.convert_syncbn_model(model)
505504

506505
criterion = criterion_fn()
507506
criterion = self.sync_device(criterion)

catalyst/engines/fairscale.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ def __init__(
340340
self,
341341
address: str = None,
342342
port: Union[str, int] = None,
343+
sync_bn: bool = False,
343344
ddp_kwargs: Dict[str, Any] = None,
344345
process_group_kwargs: Dict[str, Any] = None,
345346
scaler_kwargs: Dict[str, Any] = None,
@@ -348,6 +349,7 @@ def __init__(
348349
super().__init__(
349350
address=address,
350351
port=port,
352+
sync_bn=sync_bn,
351353
ddp_kwargs=ddp_kwargs,
352354
process_group_kwargs=process_group_kwargs,
353355
)

0 commit comments

Comments
 (0)