import torch
from operator import itemgetter
from typing import Iterator, List, Optional

from torch.utils.data import Dataset, DistributedSampler, Sampler
class DatasetFromSampler(Dataset):
    """Dataset that exposes a `Sampler`'s indices by position. From the catalyst library.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets an element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            # Materialize the sampler lazily, on first access.
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
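
# Illustrative usage sketch (not part of the original file): DatasetFromSampler
# turns a sampler's stream of indices into something indexable by position.
# The SequentialSampler below is an arbitrary choice for the demo.
def _demo_dataset_from_sampler() -> None:
    from torch.utils.data import SequentialSampler

    source = SequentialSampler(range(5))  # yields 0, 1, 2, 3, 4 in order
    indexable = DatasetFromSampler(source)
    assert len(indexable) == 5
    assert indexable[2] == 2  # the third index the sampler would yield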

class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.
    From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of the subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        The wrapped sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over this rank's share of the wrapped sampler's indices.

        Returns:
            python iterator
        """
        # Re-materialize the wrapped sampler so a fresh draw is taken each epoch.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
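
# Illustrative usage sketch (assumes torch.distributed is already initialized
# and that `train_dataset` and `sample_weights` exist in your code): any
# single-process sampler, e.g. WeightedRandomSampler, becomes DDP-aware once wrapped.
def _demo_distributed_wrapper(train_dataset, sample_weights) -> None:
    from torch.utils.data import DataLoader, WeightedRandomSampler

    base = WeightedRandomSampler(sample_weights, num_samples=len(train_dataset))
    sampler = DistributedSamplerWrapper(base)  # rank/world size from the process group
    loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
    for epoch in range(3):
        sampler.set_epoch(epoch)  # reshuffle each rank's split every epoch
        for batch in loader:
            pass  # training step goes here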

class UnimaxSampler(Sampler):
    """Language sampler implementing the UNIMAX algorithm (Chung et al., 2023):
    low-resource languages are capped at `num_epochs` epochs and the remaining
    character budget is spread uniformly over the higher-resource languages."""

    # Initialize the sampler with the character counts for each language,
    # the total character budget, and the maximum number of epochs per language.
    def __init__(
        self,
        language_character_counts: List[int],
        total_character_budget: int,
        num_epochs: int,
    ) -> None:
        # Use a float tensor so per-language budgets are not truncated to integers.
        self.language_character_counts = torch.tensor(
            language_character_counts, dtype=torch.float
        )
        self.total_character_budget = total_character_budget
        self.num_epochs = num_epochs
        # Compute the sampling distribution p.
        self.p = self._unimax()

    # Iterate by drawing language indices i.i.d. from the distribution p with
    # PyTorch's multinomial function.
    def __iter__(self) -> Iterator[int]:
        return iter(torch.multinomial(self.p, len(self.p), replacement=True).tolist())

    # The length of the sampler is the number of languages.
    def __len__(self) -> int:
        return len(self.p)

    # Implement the UNIMAX algorithm to compute the sampling distribution p.
    def _unimax(self) -> torch.Tensor:
        # Sort languages by character count, lowest-resource first.
        L, indices = torch.sort(self.language_character_counts)
        # Initialize the remaining budget to the total character budget.
        B = float(self.total_character_budget)
        # Initialize the per-language budgets.
        U = torch.zeros_like(L)
        # Visit languages in ascending order of resource size.
        for i, idx in enumerate(indices):
            # Uniform share of the remaining budget over the remaining languages.
            bl = B / (len(L) - i)
            # Character count of the current language: index the sorted tensor
            # by sorted position i, not by the original index idx.
            cl = L[i].item()
            # If the uniform share exceeds `num_epochs` epochs of this language,
            # cap the language at `num_epochs` epochs.
            if bl > cl * self.num_epochs:
                Ul = cl * self.num_epochs
            # Otherwise give it the uniform per-language share.
            else:
                Ul = bl
            # Store the computed budget under the language's original index.
            U[idx] = Ul
            # Update the remaining budget and move to the next language.
            B -= Ul
        # Normalize the budgets into a sampling distribution.
        p = U / U.sum()
        # Return the computed distribution.
        return p
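
# Worked sketch with made-up numbers: three languages with 100, 1_000, and
# 10_000 characters, a 3_000-character budget, and at most 2 epochs per
# language. UNIMAX caps the smallest language at 2 epochs (200 characters)
# and splits the rest uniformly, so the budgets come out to [200, 1400, 1400]
# and p is approximately [0.067, 0.467, 0.467].
def _demo_unimax() -> None:
    sampler = UnimaxSampler(
        language_character_counts=[100, 1_000, 10_000],
        total_character_budget=3_000,
        num_epochs=2,
    )
    print(sampler.p)  # ~tensor([0.0667, 0.4667, 0.4667])
    language_ids = list(sampler)  # one language index per draw
    print(language_ids)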

class DistributedUnimaxSampler(UnimaxSampler):
    """UnimaxSampler combined with DistributedSamplerWrapper, so that each
    process draws an exclusive share of the UNIMAX samples."""

    def __init__(
        self,
        language_character_counts: List[int],
        total_character_budget: int,
        num_epochs: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ) -> None:
        super().__init__(language_character_counts, total_character_budget, num_epochs)
        # Wrap a standalone UnimaxSampler rather than `self`: wrapping `self`
        # would recurse through the overridden __len__/__iter__ below before
        # `distributed_sampler` is assigned.
        base_sampler = UnimaxSampler(
            language_character_counts, total_character_budget, num_epochs
        )
        self.distributed_sampler = DistributedSamplerWrapper(
            base_sampler, num_replicas, rank, shuffle
        )

    def __iter__(self):
        return iter(self.distributed_sampler)

    def __len__(self):
        return len(self.distributed_sampler)

    def set_epoch(self, epoch):
        self.distributed_sampler.set_epoch(epoch)
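
# Illustrative end-to-end sketch (assumes an initialized process group and a
# `language_dataset` whose item indices correspond to the language indices the
# sampler draws; both names are placeholders, not part of the original code).
def _demo_distributed_unimax(language_dataset) -> None:
    from torch.utils.data import DataLoader

    sampler = DistributedUnimaxSampler(
        language_character_counts=[100, 1_000, 10_000],
        total_character_budget=3_000,
        num_epochs=2,
    )
    loader = DataLoader(language_dataset, batch_size=8, sampler=sampler)
    for epoch in range(3):
        sampler.set_epoch(epoch)  # keep the ranks' shuffles in sync
        for batch in loader:
            pass  # training step goes here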