# YourMT3: amt/src/extras/unimax_sampler/unimax_sampler.py
import torch
from operator import itemgetter
from typing import Iterator, List, Optional, Union

from torch.utils.data import Dataset, DistributedSampler, RandomSampler, Sampler


class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`. From the catalyst library.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialization for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets an element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
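
# Illustrative note (not part of the original module): wrapping e.g. a
# RandomSampler over a 5-element dataset, DatasetFromSampler freezes one pass
# of the sampler so its outputs can be indexed like a dataset:
#   ds_from_sampler = DatasetFromSampler(torch.utils.data.RandomSampler(range(5)))
#   ds_from_sampler[0]  # first index drawn by the sampler, e.g. 3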


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.
    From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over the wrapped sampler.

        Returns:
            python iterator
        """
        # Re-materialize the wrapped sampler so a fresh ordering is drawn each
        # epoch, then let DistributedSampler pick this rank's shard of
        # positions into that ordering and map them back to sampler outputs.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
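
# Illustrative sketch (not part of the original module): how this wrapper is
# typically combined with an arbitrary single-process sampler and a DataLoader
# under DistributedDataParallel. Names such as `my_dataset`, `my_sampler`,
# `world_size` and `rank` are placeholders.
#   my_sampler = torch.utils.data.RandomSampler(my_dataset)
#   sampler = DistributedSamplerWrapper(my_sampler, num_replicas=world_size, rank=rank)
#   loader = torch.utils.data.DataLoader(my_dataset, batch_size=32, sampler=sampler)
#   for epoch in range(num_epochs):
#       sampler.set_epoch(epoch)  # reshuffle the shard assignment each epoch
#       for batch in loader:
#           ...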


class UnimaxSampler(Sampler):
    # Initialize the sampler with the character counts for each language,
    # the total character budget, and the number of epochs per language.
    def __init__(self, language_character_counts: List[int], total_character_budget: int,
                 num_epochs: int) -> None:
        self.language_character_counts = torch.tensor(language_character_counts)
        self.total_character_budget = total_character_budget
        self.num_epochs = num_epochs
        # Compute the sampling distribution p.
        self.p = self._unimax()

    # Define how to iterate over the data. We use PyTorch's multinomial
    # function to generate language indices according to the distribution p.
    def __iter__(self) -> Iterator[int]:
        return iter(torch.multinomial(self.p, len(self.p), replacement=True).tolist())

    # Define the length of the sampler as the number of languages.
    def __len__(self) -> int:
        return len(self.p)

    # Implement the UniMax algorithm to compute the sampling distribution p.
    def _unimax(self) -> torch.Tensor:
        # Sort languages by character count (ascending).
        L, indices = torch.sort(self.language_character_counts)
        # Initialize the remaining budget to the total character budget.
        B = float(self.total_character_budget)
        i = 0
        # Initialize the per-language budgets (float, since shares may be fractional).
        U = torch.zeros(len(L), dtype=torch.float)
        # Visit languages from lowest- to highest-resource.
        for idx in indices:
            # Split the remaining budget uniformly over the remaining languages.
            bl = B / (len(L) - i)
            # Character count of the current language (at its original position).
            cl = float(self.language_character_counts[idx])
            # If the uniform share exceeds N epochs of the language, cap it at N epochs.
            if bl > cl * self.num_epochs:
                Ul = cl * self.num_epochs
            # Otherwise use the uniform per-language budget.
            else:
                Ul = bl
            # Store the computed budget at the language's original position.
            U[idx] = Ul
            # Update the remaining budget.
            B -= Ul
            # Move to the next language.
            i += 1
        # Normalize the budgets into a sampling distribution.
        p = U / U.sum()
        # Return the computed distribution.
        return p
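
# Worked example (comments only; the numbers below are made up for
# illustration): with language_character_counts=[100, 1000, 10000],
# total_character_budget=6000 and num_epochs=2, _unimax() visits languages
# low-resource first:
#   lang 100:   uniform share 6000/3 = 2000 > 2 epochs = 200   -> U = 200,  B = 5800
#   lang 1000:  uniform share 5800/2 = 2900 > 2 epochs = 2000  -> U = 2000, B = 3800
#   lang 10000: uniform share 3800/1 = 3800 < 2 epochs = 20000 -> U = 3800
# Normalizing U = [200, 2000, 3800] gives p ~= [0.033, 0.333, 0.633].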


class DistributedUnimaxSampler(UnimaxSampler):

    def __init__(self,
                 language_character_counts: List[int],
                 total_character_budget: int,
                 num_epochs: int,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True) -> None:
        super().__init__(language_character_counts, total_character_budget, num_epochs)
        self.distributed_sampler = DistributedSamplerWrapper(self, num_replicas, rank, shuffle)

    def __iter__(self):
        return iter(self.distributed_sampler)

    def __len__(self):
        return len(self.distributed_sampler)

    def set_epoch(self, epoch):
        self.distributed_sampler.set_epoch(epoch)
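

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original module); the
    # character counts and budget are made-up numbers.
    sampler = UnimaxSampler(language_character_counts=[100, 1000, 10000],
                            total_character_budget=6000,
                            num_epochs=2)
    print("sampling distribution p:", sampler.p)
    print("one draw of language indices:", list(iter(sampler)))
    # DistributedUnimaxSampler takes the same arguments plus num_replicas/rank,
    # but requires torch.distributed to be initialized (or explicit values for
    # num_replicas and rank) before it can be constructed.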