"""
Wrap any batch sampler for distributed training

Authors:
  * Leo 2022
"""

import logging
from copy import deepcopy
from typing import Iterator, Optional, TypeVar

import torch.distributed as dist
from torch.utils.data import BatchSampler

T_co = TypeVar("T_co", covariant=True)
logger = logging.getLogger(__name__)

__all__ = [
    "DistributedBatchSamplerWrapper",
]


class DistributedBatchSamplerWrapper:
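    """Wrap any batch sampler so that each distributed process gets its own subset of batches.

    Batches produced by batch_sampler are assigned to processes in a round-robin
    fashion. When the total number of batches is not divisible by num_replicas,
    some batches are split in half to even out the count; if no batch can be split
    further, the wrapper either duplicates a few batches (allow_duplicates=True),
    tolerates an uneven split (allow_uneven=True), or raises an error.
    """
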
    def __init__(
        self,
        batch_sampler: BatchSampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        allow_duplicates: bool = False,
        allow_uneven: bool = False,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.batch_sampler = batch_sampler
        self.num_replicas = num_replicas
        self.rank = rank
        self.allow_duplicates = allow_duplicates
        self.allow_uneven = allow_uneven

    def __iter__(self) -> Iterator[T_co]:
        logger.info(
            f"Building distributed batch sampler for rank={self.rank}, world_size={self.num_replicas}"
        )

        all_rank_batch_indices = list(iter(self.batch_sampler))
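
        # If the number of batches is already divisible by the number of replicas,
        # every process can get the same number of batches without adjustment.
        # Otherwise, split ("halve") some batches in two until the total count is a
        # multiple of num_replicas, so the round-robin slicing below stays even.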
        if len(all_rank_batch_indices) % self.num_replicas == 0:
            target_batch_indices = all_rank_batch_indices
        else:
            num_to_halve = (
                self.num_replicas - len(all_rank_batch_indices) % self.num_replicas
            )
            flatten_batch_indices = deepcopy(all_rank_batch_indices)
            while num_to_halve > 0:
                newly_flatten = []
                all_cant_be_halved = True
                for indices in flatten_batch_indices:
                    if num_to_halve > 0 and len(indices) > 1:
                        indices1, indices2 = (
                            indices[: len(indices) // 2],
                            indices[len(indices) // 2 :],
                        )
                        newly_flatten += [indices1, indices2]
                        num_to_halve -= 1
                        all_cant_be_halved = False
                    else:
                        newly_flatten.append(indices)
                flatten_batch_indices = deepcopy(newly_flatten)

                if all_cant_be_halved:
                    # Every remaining batch holds a single index, so no further
                    # splitting is possible; fall back according to the flags.
                    if self.allow_duplicates:
                        logger.warning(
                            "Some batches are duplicated so that the dataloaders in all "
                            "processes get the same number of batches. This must not "
                            "happen during the evaluation stage, since duplicated batches "
                            "would skew the results."
                        )
                        flatten_batch_indices = (
                            flatten_batch_indices
                            + all_rank_batch_indices[:num_to_halve]
                        )
                        break
                    elif self.allow_uneven:
                        logger.warning(
                            "The total number of batches will not be evenly distributed "
                            "across the dataloaders in different processes. This must not "
                            "happen during the training stage, where it can cause processes "
                            "to hang, but it might be acceptable during the evaluation stage."
                        )
                        break
                    else:
                        raise ValueError(
                            "The provided batch sampler cannot be safely wrapped for distributed "
                            "training. Please try increasing the number of indices in each batch, "
                            "or allow duplicated batches (allow_duplicates=True) or an uneven "
                            "number of batches across dataloaders (allow_uneven=True)."
                        )
            target_batch_indices = flatten_batch_indices

        if not self.allow_uneven:
            assert len(target_batch_indices) % self.num_replicas == 0

        # Round-robin assignment: this process takes every num_replicas-th batch,
        # starting at its own rank.
        batch_indices = target_batch_indices[self.rank :: self.num_replicas]
        return iter(batch_indices)

    def __len__(self) -> int:
        # The total number of batches can change from epoch to epoch, so instead of
        # pre-computing it (which would duplicate the batch-counting logic in
        # __iter__), simply re-run __iter__ on every call. __len__ is rarely called,
        # so this is not a performance bottleneck.
        return len(list(iter(self)))

    def set_epoch(self, epoch: int) -> None:
        if hasattr(self.batch_sampler, "set_epoch"):
            self.batch_sampler.set_epoch(epoch)
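

if __name__ == "__main__":
    # A minimal usage sketch. It avoids initializing torch.distributed by passing
    # num_replicas and rank explicitly; the dataset size, batch size, and world
    # size below are illustrative assumptions, not values from any real setup.
    from torch.utils.data import SequentialSampler

    base_sampler = BatchSampler(
        SequentialSampler(range(10)), batch_size=3, drop_last=False
    )
    for rank in range(3):
        wrapped = DistributedBatchSamplerWrapper(
            base_sampler, num_replicas=3, rank=rank, allow_duplicates=True
        )
        # Each rank should receive the same number of batches; some batches may be
        # split in half so the total batch count divides evenly by num_replicas.
        print(f"rank {rank}: {list(wrapped)}")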