Merged
Changes from 1 commit
37 changes: 5 additions & 32 deletions catalyst/core/_misc.py
@@ -1,42 +1,13 @@
from typing import Dict, List, Union
from collections import OrderedDict
from copy import copy
import warnings

from torch.utils.data import DataLoader, DistributedSampler

from catalyst.core.callback import Callback, CallbackNode, CallbackWrapper
from catalyst.data.ddp_loader import prepare_ddp_loader
from catalyst.data.sampler import DistributedSamplerWrapper
from catalyst.utils.distributed import get_rank


def _force_make_distributed_loader(loader: DataLoader) -> DataLoader:
"""
Transfers loader to distributed mode. Experimental feature.

Args:
        loader: pytorch dataloader

Returns:
        DataLoader: pytorch dataloader with distributed sampler.
"""
sampler = (
DistributedSampler(dataset=loader.dataset)
if getattr(loader, "sampler", None) is not None
else DistributedSamplerWrapper(sampler=loader.sampler)
)
loader = DataLoader(
dataset=copy(loader.dataset),
batch_size=loader.batch_size,
# shuffle=loader.shuffle,
sampler=sampler,
# batch_sampler=loader.batch_sampler,
num_workers=loader.num_workers,
# collate_fn=loader.collate_fn,
pin_memory=loader.pin_memory,
drop_last=loader.drop_last,
)
return loader
from catalyst.utils.distributed import get_rank, get_world_size


def validate_loaders(loaders: Dict[str, DataLoader]) -> Dict[str, DataLoader]:
@@ -61,7 +32,9 @@ def validate_loaders(loaders: Dict[str, DataLoader]) -> Dict[str, DataLoader]:
"you need ``DistributedSampler`` for your ``DataLoader``."
"Transferring to distributed mode. (Experimental feature)"
)
loaders[key] = _force_make_distributed_loader(value)
loaders[key] = prepare_ddp_loader(
value, num_processes=get_world_size(), process_index=rank
)
return loaders


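For reference, a minimal sketch of the call that ``validate_loaders`` now makes (the toy dataset and loader below are hypothetical; only ``prepare_ddp_loader``, ``get_rank`` and ``get_world_size`` come from this change set):

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data.ddp_loader import prepare_ddp_loader
from catalyst.utils.distributed import get_rank, get_world_size

# Hypothetical map-style dataset and loader.
dataset = TensorDataset(torch.arange(64).float())
loader = DataLoader(dataset, batch_size=8, num_workers=0)

rank = get_rank()
if rank >= 0:  # a non-negative rank means we are running in distributed mode
    # Same call that validate_loaders now performs for each loader in the dict.
    loader = prepare_ddp_loader(
        loader, num_processes=get_world_size(), process_index=rank
    )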
1 change: 1 addition & 0 deletions catalyst/data/__init__.py
@@ -1,5 +1,6 @@
# flake8: noqa
from catalyst.data.dataset import DatasetFromSampler, SelfSupervisedDatasetWrapper
from catalyst.data.ddp_loader import BatchSamplerShard, prepare_ddp_loader
from catalyst.data.loader import (
ILoaderWrapper,
BatchLimitLoaderWrapper,
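With the export added above, both helpers become importable straight from the package; a small usage sketch (not part of the diff):

from catalyst.data import BatchSamplerShard, prepare_ddp_loader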
153 changes: 153 additions & 0 deletions catalyst/data/ddp_loader.py
@@ -0,0 +1,153 @@
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

# Default kwargs of ``DataLoader`` in the minimal supported PyTorch version (1.3.0).
_PYTORCH_DATALOADER_KWARGS = {
"batch_size": 1,
"shuffle": False,
"sampler": None,
"batch_sampler": None,
"num_workers": 0,
"collate_fn": None,
"pin_memory": False,
"drop_last": False,
"timeout": 0,
"worker_init_fn": None,
"multiprocessing_context": None,
}


# Heavily based on HuggingFace accelerate project internals.
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/data_loader.py
class BatchSamplerShard(BatchSampler):
"""
    Wraps a PyTorch :obj:`BatchSampler` to generate batches for one of the processes only.
    Instances of this class will always yield a number of batches that is a round multiple of
    :obj:`num_processes` and that all have the same size. Depending on the value of the
    :obj:`drop_last` attribute of the batch sampler passed, it will either stop the iteration
    at the first batch that would be too small / not present on all processes or loop with
    indices from the beginning.

Args:
batch_sampler (:obj:`torch.utils.data.sampler.BatchSampler`):
The batch sampler to split in several shards.
num_processes (:obj:`int`, `optional`, defaults to 1):
The number of processes running concurrently.
process_index (:obj:`int`, `optional`, defaults to 0):
The index of the current process.

.. warning::

This does not support :obj:`BatchSampler` with varying batch size yet.
"""

def __init__(
self,
batch_sampler: BatchSampler,
num_processes: int = 1,
process_index: int = 0,
):
"""Init."""
self.batch_sampler = batch_sampler
self.num_processes = num_processes
self.process_index = process_index
self.batch_size = batch_sampler.batch_size
self.drop_last = batch_sampler.drop_last

    def __len__(self):
        """Return the number of batches for the current process."""
if len(self.batch_sampler) % self.num_processes == 0:
return len(self.batch_sampler) // self.num_processes
length = len(self.batch_sampler) // self.num_processes
return length if self.drop_last else length + 1

    def __iter__(self):
        """Iterate over the batches assigned to the current process."""
initial_data = []
batch_to_yield = []
for idx, batch in enumerate(self.batch_sampler):
# We gather the initial indices in case we need to circle back at the end.
if not self.drop_last and idx < self.num_processes:
initial_data += batch
            # We identify the batch to yield but wait until we are sure every process
            # gets a full batch before actually yielding it.
if idx % self.num_processes == self.process_index:
batch_to_yield = batch
if (
idx % self.num_processes == self.num_processes - 1
and len(batch) == self.batch_size
):
yield batch_to_yield
batch_to_yield = []

# If drop_last is True, iteration is over, otherwise...
if not self.drop_last and len(initial_data) > 0:
# ... we yield the complete batch we had saved before if it has the proper length
if len(batch_to_yield) == self.batch_size:
yield batch_to_yield

# For degenerate cases where the dataset has less than num_process * batch_size samples
while len(initial_data) < self.num_processes * self.batch_size:
initial_data += initial_data

            # If the last batch seen was of the proper size, it has been yielded by its
            # process, so we move to the next one.
if len(batch) == self.batch_size:
batch = []
idx += 1

# Make sure we yield a multiple of self.num_processes batches
cycle_index = 0
while idx % self.num_processes != 0 or len(batch) > 0:
end_index = cycle_index + self.batch_size - len(batch)
batch += initial_data[cycle_index:end_index]
if idx % self.num_processes == self.process_index:
yield batch
cycle_index = end_index
batch = []
idx += 1


def prepare_ddp_loader(loader: DataLoader, num_processes: int, process_index: int) -> DataLoader:
"""
Transfers loader to distributed mode. Experimental feature.

Args:
        loader: pytorch dataloader
        num_processes (:obj:`int`):
            The number of processes running concurrently.
        process_index (:obj:`int`):
            The index of the current process.

    Returns:
        DataLoader: pytorch dataloader with a distributed batch sampler.
"""
ddp_dataset = loader.dataset
# Iterable dataset doesn't like batch_sampler, but DataLoader creates a default one for it
if isinstance(ddp_dataset, IterableDataset):
ddp_batch_sampler = None
else:
ddp_batch_sampler = BatchSamplerShard(
loader.batch_sampler,
num_processes=num_processes,
process_index=process_index,
)
    # We ignore all of these since they are all handled by our ``ddp_batch_sampler``
ignore_kwargs = [
"batch_size",
"shuffle",
"sampler",
"batch_sampler",
"drop_last",
"generator",
]
kwargs = {
k: getattr(loader, k, _PYTORCH_DATALOADER_KWARGS[k])
for k in _PYTORCH_DATALOADER_KWARGS
if k not in ignore_kwargs
}
# Need to provide batch_size as batch_sampler is None for Iterable dataset
if ddp_batch_sampler is None:
kwargs["drop_last"] = loader.drop_last
kwargs["batch_size"] = loader.batch_size

loader = DataLoader(dataset=ddp_dataset, batch_sampler=ddp_batch_sampler, **kwargs)
return loader


__all__ = ["BatchSamplerShard", "prepare_ddp_loader"]
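
To make the sharding behaviour concrete, here is a quick local sketch that iterates ``BatchSamplerShard`` for two hypothetical processes (no process group is needed; the ``range(10)`` dataset and batch size of 4 are arbitrary choices, and the expected output follows from the ``__iter__`` logic above):

from torch.utils.data import BatchSampler, SequentialSampler

from catalyst.data.ddp_loader import BatchSamplerShard

# Base sampler yields batches [0, 1, 2, 3], [4, 5, 6, 7], [8, 9].
base = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
for process_index in range(2):
    shard = BatchSamplerShard(base, num_processes=2, process_index=process_index)
    print(process_index, list(shard))
# Each process sees the same number of equally sized batches; the short tail batch
# is padded with indices taken from the beginning of the dataset:
#   0 [[0, 1, 2, 3], [8, 9, 0, 1]]
#   1 [[4, 5, 6, 7], [2, 3, 4, 5]]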