# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. # Code is copy-pasted exactly as in torch.utils.data.distributed. # FIXME remove this once c10d fixes the bug it has import os import math import torch import torch.distributed as dist from torch.utils.data.sampler import Sampler import random from uniperceiver.utils import comm import itertools class DistributedSampler(Sampler): """Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it. .. note:: Dataset is assumed to be of constant size. Arguments: dataset: Dataset used for sampling. num_replicas (optional): Number of processes participating in distributed training. rank (optional): Rank of the current process within num_replicas. """ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True, dataset_repeat=1): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.dataset_repeat = dataset_repeat def __iter__(self): if self.shuffle: # deterministically shuffle based on epoch g = torch.Generator() g.manual_seed(self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() else: indices = torch.arange(len(self.dataset)).tolist() # add extra samples to make it evenly divisible indices += indices[: (self.total_size - len(indices))] assert len(indices) == self.total_size # subsample offset = self.num_samples * self.rank indices = indices[offset : offset + self.num_samples] assert len(indices) == self.num_samples repeated_indices = [] for _ in range(self.dataset_repeat): repeated_indices += torch.tensor(indices)[torch.randperm(len(indices), generator=g)].tolist() return iter(repeated_indices) def __len__(self): return self.num_samples * self.dataset_repeat def set_epoch(self, epoch): self.epoch = epoch class TrainingSampler(Sampler): """ In training, we only care about the "infinite stream" of training data. So this sampler produces an infinite stream of indices and all workers cooperate to correctly shuffle the indices and sample different indices. The samplers in each worker effectively produces `indices[worker_id::num_workers]` where `indices` is an infinite stream of indices consisting of `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True) or `range(size) + range(size) + ...` (if shuffle is False) """ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True, seed = None): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) -1 self.total_size = len(dataset) self.shuffle = shuffle # self.dataset_repeat = dataset_repeat if seed is None: seed = comm.shared_random_seed() self.seed = int(seed) def __len__(self): return self.num_samples def __iter__(self): start = self.rank yield from itertools.islice(self._infinite_indices(), start, None, self.num_replicas) def _infinite_indices(self): g = torch.Generator() g.manual_seed(self.seed) while True: if self.shuffle: yield from torch.randperm(self.total_size, generator=g).tolist() else: yield from torch.arange(self.total_size).tolist() class NaiveSampler(Sampler): """ In training, we only care about the "infinite stream" of training data. So this sampler produces an infinite stream of indices and all workers cooperate to correctly shuffle the indices and sample different indices. The samplers in each worker effectively produces `indices[worker_id::num_workers]` where `indices` is an infinite stream of indices consisting of `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True) or `range(size) + range(size) + ...` (if shuffle is False) for bookswiki node-block cache """ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True, seed = None): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() // comm.get_local_size() if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = comm.get_rank() // comm.get_local_size() self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.num_samples =int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) -1 self.total_size = len(dataset) self.shuffle = shuffle # self.dataset_repeat = dataset_repeat if seed is None: seed = comm.shared_random_seed() self.seed = int(seed) def __len__(self): return self.num_samples def __iter__(self): start = self.rank yield from itertools.islice(self._infinite_indices(), start, None, self.num_replicas) def _infinite_indices(self): g = torch.Generator() g.manual_seed(self.seed) while True: if self.shuffle: yield from torch.randperm(self.total_size, generator=g).tolist() else: yield from torch.arange(self.total_size).tolist() class NodeDistributedSampler(Sampler): """Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it. .. note:: Dataset is assumed to be of constant size. Arguments: dataset: Dataset used for sampling. num_replicas (optional): Number of processes participating in distributed training. rank (optional): Rank of the current process within num_replicas. """ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() if local_rank is None: local_rank = int(os.environ.get('LOCAL_RANK', 0)) if local_size is None: local_size = int(os.environ.get('LOCAL_SIZE', 1)) self.dataset = dataset self.shuffle = shuffle self.num_replicas = num_replicas self.num_parts = local_size self.rank = rank self.local_rank = local_rank self.epoch = 0 self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) self.total_size = self.num_samples * self.num_replicas self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts self.indices = [i for i in range(len(self.dataset)) if i % self.num_parts == self.local_rank] seed = comm.shared_random_seed() self.seed = int(seed) def __iter__(self): start = self.rank // self.num_parts yield from itertools.islice(self._infinite_indices(), start, None, self.num_replicas // self.num_parts) def _infinite_indices(self): g = torch.Generator() g.manual_seed(self.seed) while True: if self.shuffle: yield from torch.tensor(self.indices)[torch.randperm(len(self.indices), generator=g)].tolist() else: yield from self.indices def __len__(self): return self.num_samples def set_epoch(self, epoch): self.epoch = epoch class NodeDistributedSampler_bak(Sampler): """Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it. .. note:: Dataset is assumed to be of constant size. Arguments: dataset: Dataset used for sampling. num_replicas (optional): Number of processes participating in distributed training. rank (optional): Rank of the current process within num_replicas. """ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() if local_rank is None: local_rank = int(os.environ.get('LOCAL_RANK', 0)) if local_size is None: local_size = int(os.environ.get('LOCAL_SIZE', 1)) self.dataset = dataset self.shuffle = shuffle self.num_replicas = num_replicas self.num_parts = local_size self.rank = rank self.local_rank = local_rank self.epoch = 0 self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) self.total_size = self.num_samples * self.num_replicas self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts def __iter__(self): if self.shuffle: # deterministically shuffle based on epoch g = torch.Generator() g.manual_seed(self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() else: indices = torch.arange(len(self.dataset)).tolist() indices = [i for i in indices if i % self.num_parts == self.local_rank] # add extra samples to make it evenly divisible indices += indices[:(self.total_size_parts - len(indices))] assert len(indices) == self.total_size_parts # subsample indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts] assert len(indices) == self.num_samples return iter(indices) def __len__(self): return self.num_samples def set_epoch(self, epoch): self.epoch = epoch