Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np | |
from torch.utils.data.sampler import BatchSampler, Sampler | |
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield mini-batches of indices.

    It enforces that each batch only contains elements from the same group.
    It also tries to provide mini-batches which follow an ordering that is
    as close as possible to the ordering from the original sampler.

    Note:
        Only full batches are yielded. Indices that do not fill a complete
        batch remain in the per-group buffers once the base sampler is
        exhausted, and are still there if this sampler is iterated again.
    """

    def __init__(self, sampler, group_ids, batch_size):
        """
        Args:
            sampler (Sampler): Base sampler.
            group_ids (list[int]): If the sampler produces indices in range [0, N),
                `group_ids` must be a list of `N` ints which contains the group id of each sample.
                The group ids must be a set of integers in the range [0, num_groups).
            batch_size (int): Size of mini-batch. Must be positive.

        Raises:
            ValueError: If `sampler` is not a `torch.utils.data.Sampler`, if
                `group_ids` is not one-dimensional, or if `batch_size` is not
                positive.
        """
        # NOTE: BatchSampler.__init__ is not invoked; every attribute used by
        # this class is assigned below.
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler

        self.group_ids = np.asarray(group_ids)
        # Raise instead of assert so the check survives `python -O` and matches
        # the error style used for `sampler` above.
        if self.group_ids.ndim != 1:
            raise ValueError(
                "group_ids should be one-dimensional, but got ndim={}".format(
                    self.group_ids.ndim
                )
            )

        # A non-positive batch size could never be reached by a buffer length,
        # so iteration would yield nothing while the buffers keep growing.
        if batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive value, but got batch_size={}".format(
                    batch_size
                )
            )
        self.batch_size = batch_size

        groups = np.unique(self.group_ids).tolist()
        # buffer the indices of each group until batch size is reached
        self.buffer_per_group = {k: [] for k in groups}

    def __iter__(self):
        """
        Yield lists of `batch_size` indices that all share one group id.

        Indices are consumed from the base sampler in order and appended to
        the buffer of their group; a batch is emitted as soon as a buffer
        reaches `batch_size`.
        """
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            group_buffer = self.buffer_per_group[group_id]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                yield group_buffer[:]  # yield a copy of the list
                del group_buffer[:]  # empty the buffer in place for reuse

    def __len__(self):
        """
        Always raises, because the number of batches is not well-defined.

        Since this raises `NotImplementedError` rather than `TypeError`,
        `list(sampler)` propagates the error; use a `for` loop or
        `list(iter(sampler))` to collect batches.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")