"""batch samplers that work with either random or sequential data samplers""" |
import math |
import os |
import sys |
import torch |
from torch.utils import data |
import numpy as np |
class RandomSampler(data.sampler.Sampler):
    r"""
    Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
    but this class lets the user set an epoch like DistributedSampler.
    Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify ``num_samples`` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        num_samples (int): number of samples to draw, default=len(dataset); may only be
            given when ``replacement=True``
        replacement (bool): samples are drawn with replacement if ``True``, default=False

    Raises:
        ValueError: if ``num_samples`` is given while ``replacement`` is False, if the
            effective ``num_samples`` is not a positive int, or if ``replacement`` is
            not a bool.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        super(RandomSampler, self).__init__(data_source)
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        # -1 means "no epoch set"; set_epoch() stores a non-negative seed.
        self.epoch = -1
        if self._num_samples is not None and replacement is False:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")
        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))
        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

    @property
    def num_samples(self):
        """Number of indices yielded per pass (defaults to ``len(data_source)``)."""
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        """Yield ``num_samples`` indices in ``[0, len(data_source))``.

        Without replacement this is a permutation of all indices; with replacement
        indices are drawn uniformly and independently.
        """
        n = len(self.data_source)
        # A fresh generator per pass; seeding it with the epoch makes the order
        # reproducible for a given epoch.
        # NOTE(review): if set_epoch() was never called, g is left unseeded, which
        # presumably means torch's fixed default seed (same order every pass) —
        # confirm this is intended.
        g = torch.Generator()
        if self.epoch >= 0:
            g.manual_seed(self.epoch)
        if self.replacement:
            # Draw in chunks of 32, then the remainder.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64,
                                     generator=g).tolist()
        else:
            # Use the same epoch-seeded generator as the replacement branch.
            yield from torch.randperm(n, generator=g).tolist()

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the seed used for subsequent passes (non-negative values seed the generator)."""
        self.epoch = epoch
class DistributedSequentialSampler(data.sampler.Sampler):
    """Deterministic strided sampler that yields ready-made per-rank batches.

    For every step ``idx`` in ``range(start_iter, train_iters * 10)`` a global batch of
    ``batch_size`` indices is built as
    ``[(idx + j * (num_samples // batch_size)) % num_samples for j in range(batch_size)]``,
    and this rank's contiguous share of it is yielded.

    Note that ``__len__`` reports ``train_iters`` although iteration produces up to
    ``train_iters * 10`` batches.

    Arguments:
        num_samples (int): size of the index space.
        train_iters (int): reported length; iteration runs to ``train_iters * 10``.
        batch_size (int): global batch size before splitting across ranks.
        rank (int): this worker's rank; ``-1`` means a single, non-distributed worker.
        world_size (int): number of workers sharing each global batch.
    """

    def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2):
        super().__init__(num_samples)
        # rank == -1 is treated as a lone worker that receives whole batches.
        if rank == -1:
            rank, world_size = 0, 1
        self.num_samples = num_samples
        self.rank = rank
        self.world_size = world_size
        self.start_iter = 0
        self.train_iters = train_iters
        self.batch_size = batch_size
        stride = num_samples // batch_size
        self.batch_bias = [slot * stride for slot in range(batch_size)]

    def __iter__(self):
        for step in range(self.start_iter, self.train_iters * 10):
            yield self._batch([(step + offset) % self.num_samples for offset in self.batch_bias])

    def __len__(self):
        return self.train_iters

    def _batch(self, batch):
        """Return the slice of ``batch`` that belongs to this worker."""
        lo, hi = (r * self.batch_size // self.world_size for r in (self.rank, self.rank + 1))
        return batch[lo:hi]
class DistributedBatchSampler(data.sampler.BatchSampler):
    """
    Similar to the normal implementation of a distributed sampler, except the implementation
    is at the batch sampler level instead of the sampler level. This allows wrapping of
    arbitrary data samplers (sequential, random, WeightedRandomSampler, etc.) with this
    batch sampler.

    Each global batch of ``batch_size`` indices is split across ``world_size`` workers;
    this worker (``rank``) receives the contiguous slice
    ``batch[rank * batch_size // world_size:(rank + 1) * batch_size // world_size]``.

    Arguments:
        sampler: iterable of indices to batch. A ``wrap_around`` attribute is set on it
            and adjusted when ``wrap_last`` is True.
        batch_size (int): size of the global (all-worker) batch.
        drop_last (bool): if True, a trailing incomplete batch is not yielded.
        rank (int): this worker's rank; must be given explicitly.
        world_size (int): number of workers sharing each global batch.
        wrap_last (bool): if True and the last batch is incomplete, remember how many
            samples it held so the next pass skips that many leading samples.
        gradient_accumulation_steps (int, optional): multiplies ``batch_size`` to form
            ``effective_batch_size``, the unit used when skipping ``start_iter`` steps.

    Raises:
        ValueError: if ``rank`` is -1.
    """

    def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None):
        super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
        # Raise rather than ``assert`` so the check is not stripped under ``python -O``.
        if rank == -1:
            raise ValueError('DistributedBatchSampler requires an explicit rank, got rank=-1')
        self.rank = rank
        self.world_size = world_size
        self.sampler.wrap_around = 0
        # Leading samples to skip on the next pass (kept modulo batch_size).
        self.wrap_around = 0
        self.wrap_last = wrap_last
        # Steps of effective_batch_size samples to skip at the start of the next pass;
        # reset to 0 once a batch has been yielded.
        self.start_iter = 0
        self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps

    def __iter__(self):
        """Yield this worker's slice of each global batch drawn from ``self.sampler``."""
        batch = []
        i = 0  # samples consumed so far in complete global batches
        for idx in self.data_iterator(self.sampler, wrap_around=False):
            batch.append(idx)
            if len(batch) == self.batch_size:
                tbatch = self._batch(batch)
                # Skip complete batches until start_iter * effective_batch_size samples
                # have been consumed; after the first yield, stop skipping.
                if i >= self.start_iter * self.effective_batch_size:
                    yield tbatch
                    self.start_iter = 0
                i += len(batch)
                batch = []
        batch_len = len(batch)
        if batch_len > 0 and not self.drop_last:
            if self.wrap_last:
                self.sampler.wrap_around -= (self.batch_size)
                # Remember how many samples the partial batch used so the next pass
                # starts after them.
                self.wrap_around += (len(batch))
                self.wrap_around %= self.batch_size
            yield self._batch(batch)
        if self.wrap_last:
            self.sampler.wrap_around += self.batch_size

    def data_iterator(self, _iter, wrap_around=False):
        """iterates through data and handles wrap around

        Skips the first ``self.wrap_around % batch_size`` items of ``_iter``. When
        ``wrap_around`` is True, each yielded item also advances ``self.wrap_around``
        (modulo ``batch_size``).
        """
        for i, idx in enumerate(_iter):
            if i < self.wrap_around%self.batch_size:
                continue
            if wrap_around:
                self.wrap_around += 1
                self.wrap_around %= self.batch_size
            yield idx

    def _batch(self, batch):
        """extracts samples only pertaining to this worker's batch"""
        start = self.rank*self.batch_size//self.world_size
        end = (self.rank+1)*self.batch_size//self.world_size
        return batch[start:end]
class DistributedMultiDatasetBatchSampler(data.sampler.BatchSampler):
    """
    Modality-blended batch sampler that yields batches drawn alternately from the
    sub-datasets of a concatenated dataset.

    Each sub-dataset gets its own ``DistributedBatchSampler`` over a
    ``torch.utils.data.RandomSampler``. One pass runs ``ceil(largest_sub_dataset / batch_size)``
    rounds; each round yields one (per-rank) batch from every sub-dataset in order. A
    sub-dataset whose batches run out early is restarted from its batch sampler. Indices
    are shifted into the concatenated index space using ``dataset.datasets.cumulative_sizes``.

    ``dataset.datasets`` must expose ``.datasets`` (the sub-datasets) and
    ``.cumulative_sizes`` — presumably a ``torch.utils.data.ConcatDataset``; verify
    against callers.

    NOTE(review): ``__len__`` is inherited from ``BatchSampler`` and is based on
    ``len(sampler)``, which need not match the number of batches yielded here.

    Raises:
        ValueError: if ``rank`` is -1.
        RuntimeError: during iteration, if a sub-dataset yields no batch even right
            after being restarted (e.g. it is smaller than ``batch_size`` with
            ``drop_last=True``).
    """

    def __init__(self, sampler, batch_size, dataset, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None):
        super(DistributedMultiDatasetBatchSampler, self).__init__(sampler, batch_size, drop_last)
        # Raise rather than ``assert`` so the check is not stripped under ``python -O``.
        if rank == -1:
            raise ValueError('DistributedMultiDatasetBatchSampler requires an explicit rank, got rank=-1')
        self.rank = rank
        self.world_size = world_size
        self.wrap_last = wrap_last
        self.drop_last = drop_last
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.dataset = dataset
        self.batch_size = batch_size
        sub_datasets = dataset.datasets.datasets
        self.number_of_datasets = len(sub_datasets)
        self.largest_dataset_size = max(len(sub_dataset) for sub_dataset in sub_datasets)

    def __iter__(self):
        sub_datasets = self.dataset.datasets.datasets
        batch_samplers = []
        batch_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            sampler = torch.utils.data.RandomSampler(sub_datasets[dataset_idx])
            batch_sampler = DistributedBatchSampler(sampler, self.batch_size, self.drop_last, self.rank,
                                                    self.world_size, self.wrap_last, self.gradient_accumulation_steps)
            batch_samplers.append(batch_sampler)
            batch_iterators.append(iter(batch_sampler))
        # Offset of each sub-dataset's first sample in the concatenated index space.
        push_index_val = [0] + self.dataset.datasets.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                try:
                    local_batch = next(batch_iterators[i])
                except StopIteration:
                    # This sub-dataset is exhausted: restart it and draw again.
                    batch_iterators[i] = iter(batch_samplers[i])
                    try:
                        local_batch = next(batch_iterators[i])
                    except StopIteration:
                        raise RuntimeError(
                            'sub-dataset {} produced no batches; is it empty or smaller than '
                            'batch_size with drop_last=True?'.format(i)) from None
                # Keep the yield outside the try so consumer-side errors are not caught here.
                yield [x + push_index_val[i] for x in local_batch]