|
|
|
import torch |
|
from torch import Tensor |
|
|
|
from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union |
|
|
|
# Public names of this module, i.e. what ``from <module> import *`` exposes.
# The entries are listed alphabetically.
__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]
|
|
|
# Covariant type variable for whatever a sampler yields: a single index
# (``Sampler[int]``) or a batch of indices (``Sampler[List[int]]``).
T_co = TypeVar('T_co', covariant=True)


class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices or lists of indices (batches) of dataset elements,
    and may provide a :meth:`__len__` method that returns the length of the returned iterators.

    Args:
        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
            You may still have custom implementation that utilizes it.

    Example:
        >>> # xdoctest: +SKIP
        >>> class AscendingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: List[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         yield from torch.argsort(sizes).tolist()
        >>>
        >>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[List[int]]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
        >>>             yield batch.tolist()

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        """Accept (and ignore) the legacy ``data_source`` argument.

        A warning is emitted whenever ``data_source`` is passed, because the
        base class does not store or use it.
        """
        if data_source is not None:
            # Imported lazily: only needed on the deprecated code path.
            import warnings

            # The trailing space after "2.2.0." matters: the two literals are
            # concatenated, and without it the sentences run together.
            warnings.warn("`data_source` argument is not used and will be removed in 2.2.0. "
                          "You may still have custom implementation that utilizes it.")

    def __iter__(self) -> Iterator[T_co]:
        """Return an iterator over sampled elements; subclasses must override.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SequentialSampler(Sampler[int]):
    r"""Yields the indices ``0 .. len(data_source) - 1`` in ascending order.

    The order is the same on every pass over the sampler.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        # Only the length of the data source is ever consulted.
        self.data_source = data_source

    def __len__(self) -> int:
        """Return the number of indices produced per pass."""
        return len(self.data_source)

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over ``range(len(data_source))``."""
        total = len(self.data_source)
        return iter(range(total))
|
|
|
|
|
class RandomSampler(Sampler[int]):
    r"""Draws dataset indices in random order.

    Without replacement, indices come from random permutations of the dataset.
    With replacement, the user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        replacement_is_bool = isinstance(self.replacement, bool)
        if not replacement_is_bool:
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        # Validate through the property so the "default to dataset length"
        # case is checked as well.
        requested = self.num_samples
        if not (isinstance(requested, int) and requested > 0):
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        """Number of indices produced per pass.

        Falls back to ``len(data_source)`` when no explicit count was given;
        the length is looked up on every access.
        """
        explicit = self._num_samples
        return len(self.data_source) if explicit is None else explicit

    def __len__(self) -> int:
        return self.num_samples

    def __iter__(self) -> Iterator[int]:
        size = len(self.data_source)

        rng = self.generator
        if rng is None:
            # No generator supplied: create a private one, seeded with a
            # random int64 drawn via ``Tensor.random_()``.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            rng = torch.Generator()
            rng.manual_seed(seed)

        if not self.replacement:
            # Whole permutations first, then a prefix of one extra permutation.
            # The extra permutation is generated even when the prefix is empty.
            full_passes, tail = divmod(self.num_samples, size)
            for _ in range(full_passes):
                yield from torch.randperm(size, generator=rng).tolist()
            last_permutation = torch.randperm(size, generator=rng).tolist()
            yield from last_permutation[:tail]
            return

        # With replacement: draw in chunks of 32, then one final short chunk.
        full_chunks, tail = divmod(self.num_samples, 32)
        for _ in range(full_chunks):
            chunk = torch.randint(high=size, size=(32,), dtype=torch.int64, generator=rng)
            yield from chunk.tolist()
        final_chunk = torch.randint(high=size, size=(tail,), dtype=torch.int64, generator=rng)
        yield from final_chunk.tolist()
|
|
|
|
|
class SubsetRandomSampler(Sampler[int]):
    r"""Yields the entries of a given index sequence in random order, each exactly once.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """

    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.generator = generator
        self.indices = indices

    def __len__(self) -> int:
        """Return the number of candidate indices."""
        return len(self.indices)

    def __iter__(self) -> Iterator[int]:
        """Yield ``indices`` following a fresh random permutation of its positions."""
        count = len(self.indices)
        shuffled_positions = torch.randperm(count, generator=self.generator)
        for position in shuffled_positions:
            yield self.indices[position]
|
|
|
|
|
class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence)   : a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        # ``bool`` is a subclass of ``int``, so it has to be excluded explicitly.
        count_is_valid = (
            isinstance(num_samples, int)
            and not isinstance(num_samples, bool)
            and num_samples > 0
        )
        if not count_is_valid:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
        if not isinstance(replacement, bool):
            raise ValueError(f"replacement should be a boolean value, but got replacement={replacement}")

        # Weights are stored as a double-precision tensor and must be 1-D.
        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if weights_tensor.dim() != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             f"weights have shape {tuple(weights_tensor.shape)}")

        self.generator = generator
        self.replacement = replacement
        self.num_samples = num_samples
        self.weights = weights_tensor

    def __len__(self) -> int:
        return self.num_samples

    def __iter__(self) -> Iterator[int]:
        """Draw all ``num_samples`` indices in one multinomial call, then yield them."""
        drawn = torch.multinomial(
            self.weights, self.num_samples, self.replacement, generator=self.generator
        )
        yield from drawn.tolist()
|
|
|
|
|
class BatchSampler(Sampler[List[int]]):
    r"""Groups the indices produced by another sampler into mini-batches.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # ``bool`` is a subclass of ``int``, so it has to be excluded explicitly.
        size_is_valid = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not size_is_valid:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        if not isinstance(drop_last, bool):
            raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")

        self.drop_last = drop_last
        self.batch_size = batch_size
        self.sampler = sampler

    def __iter__(self) -> Iterator[List[int]]:
        if not self.drop_last:
            # Keep-the-remainder path: emit every full batch, then whatever
            # indices are left over as a final shorter batch.
            pending: List[int] = []
            for index in self.sampler:
                pending.append(index)
                if len(pending) == self.batch_size:
                    yield pending
                    pending = []
            if pending:
                yield pending
            return

        # Drop-the-remainder path: pull exactly ``batch_size`` items at a time.
        # If the source runs dry mid-batch, the partial batch is discarded.
        source = iter(self.sampler)
        while True:
            try:
                full_batch = [next(source) for _ in range(self.batch_size)]
                yield full_batch
            except StopIteration:
                return

    def __len__(self) -> int:
        """Return the number of batches; requires the wrapped sampler to support ``len``."""
        complete, leftover = divmod(len(self.sampler), self.batch_size)
        if self.drop_last or leftover == 0:
            return complete
        # A trailing partial batch counts as one more batch.
        return complete + 1
|
|