from operator import itemgetter
from typing import Iterator, List, Optional

import torch
from torch.utils.data import Dataset, DistributedSampler, Sampler


class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`. From catalyst library.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.
    From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of the subsampled data of the original
    dataset that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """

        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
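

# Illustrative usage sketch (not part of the original module): in a DDP job an
# arbitrary single-process sampler can be wrapped and passed to a DataLoader.
# `weights`, `my_dataset` and the batch size below are placeholders:
#
#   base = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(my_dataset))
#   sampler = DistributedSamplerWrapper(base)  # num_replicas/rank taken from the process group
#   loader = torch.utils.data.DataLoader(my_dataset, batch_size=32, sampler=sampler)
#   ...
#   sampler.set_epoch(epoch)  # call once per epoch so the shuffling differs across epochs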


class UnimaxSampler(Sampler):
    """Sampler over language indices following the UniMax distribution.

    The distribution is computed from per-language character counts, a total
    character budget, and a cap on the number of epochs per language.
    """

    # Initialize the sampler with the character counts for each language,
    # the total character budget, and the number of epochs per language.
    def __init__(self, language_character_counts: List[int], total_character_budget: int,
                 num_epochs: int) -> None:
        # Store the counts as a float tensor so the budget arithmetic below is not truncated.
        self.language_character_counts = torch.tensor(language_character_counts, dtype=torch.float)
        self.total_character_budget = total_character_budget
        self.num_epochs = num_epochs
        # Compute the sampling distribution p.
        self.p = self._unimax()

    # Define how to iterate over the data. We'll use PyTorch's multinomial
    # function to generate indices according to the distribution p.
    def __iter__(self) -> Iterator[int]:
        return iter(torch.multinomial(self.p, len(self.p), replacement=True).tolist())

    # Define the length of the sampler as the number of languages.
    def __len__(self) -> int:
        return len(self.p)

    # Implement the UniMax algorithm to compute the sampling distribution p.
    def _unimax(self) -> torch.Tensor:
        # Sort languages by character count (ascending); `indices` maps the
        # sorted order back to the original language positions.
        L, indices = torch.sort(self.language_character_counts)
        # Initialize the remaining budget to the total character budget.
        B = float(self.total_character_budget)
        # Initialize the budget per language.
        U = torch.zeros_like(L)
        # For each language, from smallest to largest...
        for i, idx in enumerate(indices):
            # Compute the uniform per-language share of the remaining budget.
            bl = B / (len(L) - i)
            # Character count of the i-th smallest language (index into the
            # sorted tensor by position, not by the original language index).
            cl = L[i]
            # If the per-language budget exceeds N epochs of the language, use N epochs.
            if bl > cl * self.num_epochs:
                Ul = cl * self.num_epochs
            # Otherwise use the uniform per-language budget.
            else:
                Ul = bl
            # Store the computed budget at the language's original position.
            U[idx] = Ul
            # Update the remaining budget.
            B -= Ul
        # Normalize the budgets to create a distribution.
        p = U / U.sum()
        # Return the computed distribution.
        return p
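

# Worked example of the budget computation in `UnimaxSampler._unimax`
# (illustrative numbers, not taken from the original project): with
# language_character_counts = [100, 1000, 10000], total_character_budget = 6000
# and num_epochs = 2, the ascending pass assigns
#   min(6000 / 3, 100 * 2)   = 200   to the smallest language,
#   min(5800 / 2, 1000 * 2)  = 2000  to the next language,
#   min(3800 / 1, 10000 * 2) = 3800  to the largest language,
# so p = [200, 2000, 3800] / 6000 ~= [0.033, 0.333, 0.633]: small languages are
# upsampled up to the epoch cap and the leftover budget flows to larger ones.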


class DistributedUnimaxSampler(UnimaxSampler):
    """Distributed version of `UnimaxSampler`.

    Wraps a base `UnimaxSampler` in a `DistributedSamplerWrapper` so that each
    process draws its own exclusive subset of the sampled language indices.
    """

    def __init__(self,
                 language_character_counts: List[int],
                 total_character_budget: int,
                 num_epochs: int,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True) -> None:

        super().__init__(language_character_counts, total_character_budget, num_epochs)
        # Wrap a standalone base sampler rather than `self`: wrapping `self`
        # would make the wrapper's iteration call back into the overridden
        # `__iter__`/`__len__` below and recurse without terminating.
        base_sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs)
        self.distributed_sampler = DistributedSamplerWrapper(base_sampler, num_replicas, rank, shuffle)

    def __iter__(self) -> Iterator[int]:
        return iter(self.distributed_sampler)

    def __len__(self) -> int:
        return len(self.distributed_sampler)

    def set_epoch(self, epoch: int) -> None:
        self.distributed_sampler.set_epoch(epoch)
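

if __name__ == "__main__":
    # Minimal single-process sketch (illustrative numbers; the counts, budget
    # and replica settings below are made up, not taken from the original project).
    counts = [100, 1000, 10000]
    sampler = UnimaxSampler(counts, total_character_budget=6000, num_epochs=2)
    print("UniMax sampling distribution:", sampler.p)  # roughly [0.033, 0.333, 0.633]
    print("Sampled language indices:", list(sampler))

    # Emulate rank 0 of a two-replica job by passing num_replicas/rank explicitly,
    # which avoids the need for an initialized process group.
    dist_sampler = DistributedUnimaxSampler(
        counts, total_character_budget=6000, num_epochs=2, num_replicas=2, rank=0)
    dist_sampler.set_epoch(0)
    print("Rank-0 language indices:", list(dist_sampler))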