import itertools
from typing import Optional

import torch
import torch.distributed as dist
from torch.utils.data.sampler import BatchSampler as torchBatchSampler
from torch.utils.data.sampler import Sampler


class YoloBatchSampler(torchBatchSampler):
""" |
|
This batch sampler will generate mini-batches of (mosaic, index) tuples from another sampler. |
|
It works just like the :class:`torch.utils.data.sampler.BatchSampler`, |
|
but it will turn on/off the mosaic aug. |
|
""" |
|
|
|
def __init__(self, *args, mosaic=True, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.mosaic = mosaic |
|
|
|
def __iter__(self): |
|
for batch in super().__iter__(): |
|
yield [(self.mosaic, idx) for idx in batch] |
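

# Illustrative usage sketch (assumes some `dataset` object is available): wrap an
# index sampler and consume (mosaic, index) tuples, e.g.
#
#   batch_sampler = YoloBatchSampler(
#       InfiniteSampler(len(dataset)), batch_size=8, drop_last=False, mosaic=True
#   )
#
# Each batch is a list like [(True, 17), (True, 4), ...]; flipping
# `batch_sampler.mosaic = False` (e.g. to disable mosaic for the last epochs)
# changes the flag on all later batches, since __iter__ reads self.mosaic lazily.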


class InfiniteSampler(Sampler):
""" |
|
In training, we only care about the "infinite stream" of training data. |
|
So this sampler produces an infinite stream of indices and |
|
all workers cooperate to correctly shuffle the indices and sample different indices. |
|
The samplers in each worker effectively produces `indices[worker_id::num_workers]` |
|
where `indices` is an infinite stream of indices consisting of |
|
`shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True) |
|
or `range(size) + range(size) + ...` (if shuffle is False) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
size: int, |
|
shuffle: bool = True, |
|
seed: Optional[int] = 0, |
|
rank=0, |
|
world_size=1, |
|
): |
|
""" |
|
Args: |
|
size (int): the total number of data of the underlying dataset to sample from |
|
shuffle (bool): whether to shuffle the indices or not |
|
seed (int): the initial seed of the shuffle. Must be the same |
|
across all workers. If None, will use a random seed shared |
|
among workers (require synchronization among all workers). |
|
""" |
|
self._size = size |
|
assert size > 0 |
|
self._shuffle = shuffle |
|
self._seed = int(seed) |
|
|
|
        if dist.is_available() and dist.is_initialized():
            self._rank = dist.get_rank()
            self._world_size = dist.get_world_size()
        else:
            self._rank = rank
            self._world_size = world_size

    def __iter__(self):
        # Each rank starts at its own offset and strides by world_size, so the
        # ranks jointly cover the infinite index stream without overlap.
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size
        )
    def _infinite_indices(self):
        # Seed a dedicated generator so every rank produces the identical stream,
        # which __iter__ then slices by rank to avoid overlap.
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g)
            else:
                yield from torch.arange(self._size)

    def __len__(self):
        return self._size // self._world_size
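

# A minimal, illustrative sketch: running the module directly prints a couple of
# batches, showing how InfiniteSampler shards the index stream for a given rank
# and how YoloBatchSampler tags every index with the mosaic flag. The toy size,
# batch_size, and the two simulated ranks below are arbitrary assumptions.
if __name__ == "__main__":
    # torch.distributed is not initialized here, so rank/world_size are explicit.
    sampler = InfiniteSampler(size=10, shuffle=True, seed=0, rank=0, world_size=2)
    batch_sampler = YoloBatchSampler(sampler, batch_size=4, drop_last=False, mosaic=True)

    batches = iter(batch_sampler)
    print(next(batches))  # e.g. [(True, tensor(3)), (True, tensor(7)), ...]

    batch_sampler.mosaic = False  # later batches carry the flag switched off
    print(next(batches))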