import math

import torch
from torch.utils.data import DistributedSampler as _DistributedSampler


class DistributedSampler(_DistributedSampler):
    """DistributedSampler that supports a user-specified shuffling seed.

    The seed is added to the epoch counter, so every epoch (and every run
    with a different seed) produces a different but reproducible shuffle
    that is identical across all replicas.
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 seed=0):
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        # Fall back to 0 so that a `seed=None` caller still gets
        # deterministic shuffling.
        self.seed = seed if seed is not None else 0

    def __iter__(self):
        # Deterministically shuffle based on the current epoch and seed so
        # that all replicas draw from the same permutation.
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # Repeat the index list until it covers `total_size`, then truncate,
        # so the padded list is evenly divisible among replicas even when
        # the dataset is much shorter than `total_size`.
        indices = (indices *
                   math.ceil(self.total_size / len(indices)))[:self.total_size]
        assert len(indices) == self.total_size

        # Subsample this replica's shard: every `num_replicas`-th index,
        # offset by this process's rank.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
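
# ---------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the original module): in a
# typical DDP training loop the sampler is handed to a DataLoader and
# re-seeded every epoch via set_epoch(). The explicit `num_replicas=2,
# rank=0` below are hypothetical values standing in for the real process
# group so the example runs standalone, without dist.init_process_group().
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(10))
    sampler = DistributedSampler(
        dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    for epoch in range(2):
        # set_epoch() changes the generator seed, so each epoch gets a
        # different (but reproducible) shuffle order.
        sampler.set_epoch(epoch)
        print(epoch, [batch[0].tolist() for batch in loader])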