"""
Wrap any batch sampler for distributed training

Authors:
  * Leo 2022
"""

import logging
from copy import deepcopy
from typing import Iterator, Optional, TypeVar

import torch.distributed as dist
from torch.utils.data import BatchSampler

T_co = TypeVar("T_co", covariant=True)
logger = logging.getLogger(__name__)

__all__ = [
    "DistributedBatchSamplerWrapper",
]


class DistributedBatchSamplerWrapper:
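    """Wrap any batch sampler so it can be used for distributed training.

    Each rank iterates over an interleaved subset of the wrapped sampler's
    batches (``batches[rank::num_replicas]``). When the total number of batches
    is not divisible by ``num_replicas``, some batches are split in half so that
    every rank receives the same number of batches; ``allow_duplicates`` and
    ``allow_uneven`` control the fallback when splitting alone cannot even out
    the count.
    """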
|
    def __init__(
        self,
        batch_sampler: BatchSampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        allow_duplicates: bool = False,
        allow_uneven: bool = False,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.batch_sampler = batch_sampler
        self.num_replicas = num_replicas
        self.rank = rank
        self.allow_duplicates = allow_duplicates
        self.allow_uneven = allow_uneven
|
    def __iter__(self) -> Iterator[T_co]:
        logger.info(
            f"Building distributed batch sampler for rank={self.rank}, world_size={self.num_replicas}"
        )

        all_rank_batch_indices = list(iter(self.batch_sampler))
        if len(all_rank_batch_indices) % self.num_replicas == 0:
            target_batch_indices = all_rank_batch_indices
        else:
            # The total number of batches is not divisible by the world size.
            # Split existing batches in half until the count rounds out, so that
            # every rank can receive the same number of batches.
            num_to_halve = (
                self.num_replicas - len(all_rank_batch_indices) % self.num_replicas
            )
            flatten_batch_indices = deepcopy(all_rank_batch_indices)
            while num_to_halve > 0:
                newly_flatten = []
                all_cant_be_halved = True
                for indices in flatten_batch_indices:
                    if num_to_halve > 0 and len(indices) > 1:
                        indices1, indices2 = (
                            indices[: len(indices) // 2],
                            indices[len(indices) // 2 :],
                        )
                        newly_flatten += [indices1, indices2]
                        num_to_halve -= 1
                        all_cant_be_halved = False
                    else:
                        newly_flatten.append(indices)
                flatten_batch_indices = deepcopy(newly_flatten)

                if all_cant_be_halved:
                    # Every remaining batch holds a single index, so no further
                    # splitting is possible. Fall back according to the flags.
                    if self.allow_duplicates:
                        logger.warning(
                            "Some batches are duplicated to ensure that all the dataloaders "
                            "in different processes get the same number of batches. This "
                            "should not happen during the evaluation stage."
                        )
                        flatten_batch_indices = (
                            flatten_batch_indices
                            + all_rank_batch_indices[:num_to_halve]
                        )
                        break
                    elif self.allow_uneven:
                        logger.warning(
                            "The total batches will not be evenly distributed across the "
                            "dataloaders in different processes. This must not happen during "
                            "the training stage, as it can lead to hanging, but it might be "
                            "okay during the evaluation stage."
                        )
                        break
                    else:
                        raise ValueError(
                            "The provided batch sampler cannot be safely wrapped for "
                            "distributed training. Please try increasing the number of "
                            "indices in each batch, or allow duplicated batches or an "
                            "uneven number of batches across dataloaders."
                        )
            target_batch_indices = flatten_batch_indices

        if not self.allow_uneven:
            assert len(target_batch_indices) % self.num_replicas == 0

        # Each rank takes an interleaved slice of the batches.
        batch_indices = target_batch_indices[self.rank :: self.num_replicas]
        return iter(batch_indices)
|
    def __len__(self) -> int:
        # Splitting or duplicating batches can change the per-rank batch count,
        # so the length is computed by materializing the iterator.
        return len(list(iter(self)))
|
    def set_epoch(self, epoch: int) -> None:
        if hasattr(self.batch_sampler, "set_epoch"):
            self.batch_sampler.set_epoch(epoch)
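

# Minimal usage sketch: with explicit ``num_replicas`` and ``rank`` the wrapper
# can be exercised without initializing torch.distributed. In actual distributed
# training these arguments would normally be left as None and resolved from the
# default process group.
if __name__ == "__main__":
    from torch.utils.data import SequentialSampler

    base = BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)
    for rank in range(2):
        wrapper = DistributedBatchSamplerWrapper(base, num_replicas=2, rank=rank)
        # 4 batches of up to 3 indices each; every rank receives 2 of them.
        print(f"rank {rank}: {list(wrapper)}")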