import math
from typing import Iterator, Optional, Sized

import torch
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler

from mmocr.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class BatchAugSampler(Sampler):
"""Sampler that repeats the same data elements for num_repeats times. The
batch size should be divisible by num_repeats.
It ensures that different each
augmented version of a sample will be visible to a different process (GPU).
Heavily based on torch.utils.data.DistributedSampler.
This sampler was modified from
https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
Used in
Copyright (c) 2015-present, Facebook, Inc.
Args:
dataset (Sized): The dataset.
        shuffle (bool): Whether to shuffle the dataset or not.
            Defaults to True.
num_repeats (int): The repeat times of every sample. Defaults to 3.
seed (int, optional): Random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Defaults to None.
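
    Example:
        A minimal sketch of config-style usage (assuming the sampler is
        selected through an MMEngine dataloader config; ``train_dataset``
        is a placeholder, and ``batch_size`` must be divisible by
        ``num_repeats``)::

            train_dataloader = dict(
                batch_size=9,
                sampler=dict(type='BatchAugSampler', shuffle=True,
                             num_repeats=3),
                dataset=train_dataset)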
"""
def __init__(self,
dataset: Sized,
shuffle: bool = True,
num_repeats: int = 3,
seed: Optional[int] = None):
rank, world_size = get_dist_info()
self.rank = rank
self.world_size = world_size
self.dataset = dataset
self.shuffle = shuffle
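        # all ranks must shuffle with the same seed; sync_random_seed()
        # broadcasts a randomly generated seed from rank 0 to every process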
if seed is None:
seed = sync_random_seed()
self.seed = seed
self.epoch = 0
self.num_repeats = num_repeats
# The number of repeated samples in the rank
self.num_samples = math.ceil(
len(self.dataset) * num_repeats / world_size)
# The total number of repeated samples in all ranks.
self.total_size = self.num_samples * world_size
        # The number of unique (non-repeated) samples in the rank
self.num_selected_samples = math.ceil(len(self.dataset) / world_size)
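        # e.g. len(dataset)=10, num_repeats=3, world_size=2 gives
        # num_samples=15, total_size=30 and num_selected_samples=5
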
def __iter__(self) -> Iterator[int]:
"""Iterate the indices."""
# deterministically shuffle based on epoch and seed
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
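            # the generator seed is identical on every rank, so all
            # processes draw the same permutation; rank-specific
            # subsampling happens below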
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
        # produce repeats, e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2, ...]
indices = [x for x in indices for _ in range(self.num_repeats)]
        # add extra wrapped-around samples so that the total length is
        # exactly total_size, i.e. evenly divisible across ranks
        indices = (indices *
                   int(self.total_size / len(indices) + 1))[:self.total_size]
assert len(indices) == self.total_size
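        # the num_repeats copies of a sample sit adjacently in ``indices``,
        # so striding by world_size routes different augmented copies to
        # different ranks; with num_repeats == world_size == 3, every rank
        # sees every sample exactly once per epoch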
# subsample per rank
indices = indices[self.rank:self.total_size:self.world_size]
assert len(indices) == self.num_samples
        # yield every repeated index assigned to this rank
return iter(indices)

    def __len__(self) -> int:
        """The number of indices yielded in this rank, kept consistent with
        :meth:`__iter__`."""
        return self.num_samples

def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler.
When :attr:`shuffle=True`, this ensures all replicas use a different
random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
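
# A minimal direct-usage sketch (assumptions: a single process, so
# ``get_dist_info()`` reports rank 0 / world size 1; ``my_dataset`` and
# ``max_epochs`` are placeholders, and batch_size=9 is chosen to be
# divisible by num_repeats=3):
#
#     from torch.utils.data import DataLoader
#     sampler = BatchAugSampler(my_dataset, num_repeats=3, seed=0)
#     loader = DataLoader(my_dataset, batch_size=9, sampler=sampler)
#     for epoch in range(max_epochs):
#         sampler.set_epoch(epoch)  # reshuffle deterministically each epoch
#         for batch in loader:
#             ...  # each sample appears num_repeats times per epoch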