import torch
from operator import itemgetter
from typing import Iterator, List, Optional

from torch.utils.data import Dataset, DistributedSampler, Sampler


class DatasetFromSampler(Dataset):
    """Dataset that exposes the indices produced by a `Sampler`. From the catalyst library.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets an element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            # Materialise the sampler lazily, on first access.
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
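

# Hedged sketch (not part of the original module): DatasetFromSampler turns the
# index stream of any sampler into an indexable dataset, which is what lets the
# DistributedSamplerWrapper below subsample it. The SequentialSampler here is
# purely illustrative.
def _example_dataset_from_sampler():
    from torch.utils.data import SequentialSampler

    view = DatasetFromSampler(SequentialSampler(range(5)))
    return len(view), view[0]  # (5, 0)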


class DistributedSamplerWrapper(DistributedSampler):
    """Wrapper over `Sampler` for distributed training.

    Allows you to use any sampler in distributed mode.
    From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over the wrapped sampler.

        Returns:
            python iterator
        """
        # Re-materialise the wrapped sampler so that a fresh draw is taken on
        # every epoch, then map the distributed indices back onto it.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
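

# Hedged usage sketch (not part of the original module): wrapping a standard
# WeightedRandomSampler so that each DDP rank sees its own slice of the
# weighted sample stream. The weights, dataset, and the explicit
# num_replicas/rank values below are illustrative; in real distributed
# training they would normally be inferred from the initialised process group.
def _example_distributed_sampler_wrapper():
    from torch.utils.data import DataLoader, WeightedRandomSampler

    weights = [0.1, 0.9, 0.4, 0.6]
    base_sampler = WeightedRandomSampler(weights, num_samples=len(weights))
    sampler = DistributedSamplerWrapper(base_sampler, num_replicas=2, rank=0)
    loader = DataLoader(range(len(weights)), sampler=sampler, batch_size=2)
    return [batch.tolist() for batch in loader]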


class UnimaxSampler(Sampler):
    # Initialize the sampler with the character counts for each language,
    # the total character budget, and the number of epochs per language.
    def __init__(
        self,
        language_character_counts: List[int],
        total_character_budget: int,
        num_epochs: int,
    ) -> None:
        self.language_character_counts = torch.tensor(language_character_counts)
        self.total_character_budget = total_character_budget
        self.num_epochs = num_epochs
        # Compute the sampling distribution p.
        self.p = self._unimax()

    # Define how to iterate over the data. We use PyTorch's multinomial
    # function to generate language indices according to the distribution p.
    def __iter__(self) -> Iterator[int]:
        return iter(torch.multinomial(self.p, len(self.p), replacement=True).tolist())

    # Define the length of the sampler as the number of languages.
    def __len__(self) -> int:
        return len(self.p)

    # Implement the UNIMAX algorithm to compute the sampling distribution p.
    def _unimax(self) -> torch.Tensor:
        # Sort languages by character count (ascending).
        L, indices = torch.sort(self.language_character_counts)
        # Initialize the remaining budget to the total character budget.
        B = float(self.total_character_budget)
        # Per-language budgets, stored in the original language order.
        # Use a float tensor so that fractional budgets are not truncated.
        U = torch.zeros(len(L), dtype=torch.float64)
        # Visit languages from the smallest corpus to the largest.
        for i, idx in enumerate(indices.tolist()):
            # Uniform share of the remaining budget over the languages left.
            bl = B / (len(L) - i)
            # Character count of the current (i-th smallest) language.
            cl = float(L[i])
            # If the uniform share exceeds N epochs of the language, cap it at N epochs.
            if bl > cl * self.num_epochs:
                Ul = cl * self.num_epochs
            # Otherwise use the uniform per-language budget.
            else:
                Ul = bl
            # Store the computed budget for this language.
            U[idx] = Ul
            # Update the remaining budget.
            B -= Ul
        # Normalize the budgets to create a distribution.
        p = U / U.sum()
        # Return the computed distribution.
        return p
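

# Hedged sketch (not part of the original module): a worked example of the
# UNIMAX budget computation with made-up character counts. With a budget of
# 100,000 characters and at most 2 epochs per language, the two smaller
# corpora are capped at 2 epochs of their data (2,000 and 20,000 characters),
# the largest corpus absorbs the remaining 78,000, and the resulting
# distribution is p = [0.02, 0.20, 0.78].
def _example_unimax_distribution():
    sampler = UnimaxSampler(
        language_character_counts=[1_000, 10_000, 1_000_000],
        total_character_budget=100_000,
        num_epochs=2,
    )
    return sampler.p  # tensor([0.0200, 0.2000, 0.7800], dtype=torch.float64)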


class DistributedUnimaxSampler(UnimaxSampler):
    def __init__(
        self,
        language_character_counts: List[int],
        total_character_budget: int,
        num_epochs: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ) -> None:
        super().__init__(language_character_counts, total_character_budget, num_epochs)
        # Wrap a plain UnimaxSampler rather than `self`: wrapping `self` would
        # route `__iter__`/`__len__` back through the wrapper and recurse.
        base_sampler = UnimaxSampler(
            language_character_counts, total_character_budget, num_epochs
        )
        self.distributed_sampler = DistributedSamplerWrapper(
            base_sampler, num_replicas=num_replicas, rank=rank, shuffle=shuffle
        )

    def __iter__(self) -> Iterator[int]:
        return iter(self.distributed_sampler)

    def __len__(self) -> int:
        return len(self.distributed_sampler)

    def set_epoch(self, epoch: int) -> None:
        # Forward the epoch to the underlying DistributedSampler so that each
        # epoch draws a different, rank-coordinated permutation.
        self.distributed_sampler.set_epoch(epoch)
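

# Hedged sketch of single-process usage (not part of the original module). The
# explicit num_replicas/rank values are illustrative; under real
# torch.distributed training they would normally be left as None so the
# wrapper reads them from the initialised process group.
def _example_distributed_unimax_sampler():
    sampler = DistributedUnimaxSampler(
        language_character_counts=[1_000, 10_000, 1_000_000],
        total_character_budget=100_000,
        num_epochs=2,
        num_replicas=2,
        rank=0,
    )
    for epoch in range(2):
        # Re-seed the underlying DistributedSampler before each epoch.
        sampler.set_epoch(epoch)
        language_indices = list(sampler)  # this rank's share of language draws
    return language_indices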