|
|
|
import numpy as np |
|
from torch.utils.data.sampler import BatchSampler, Sampler |
|
|
|
|
|
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield mini-batches of indices.

    Every yielded batch contains only indices that share the same group id.
    Indices are consumed in the order produced by the wrapped sampler, and a
    batch is emitted as soon as one group has collected `batch_size` indices,
    so the batch ordering stays as close as possible to the ordering of the
    original sampler.

    Indices of a group that has not yet filled a batch stay buffered on the
    instance: they are not emitted when the wrapped sampler is exhausted, and
    they are still in the buffer on the next call to `__iter__`.
    """

    def __init__(self, sampler, group_ids, batch_size):
        """
        Args:
            sampler (Sampler): Base sampler.
            group_ids (list[int]): If the sampler produces indices in range [0, N),
                `group_ids` must be a list of `N` ints which contains the group id of each sample.
                The group ids must be a set of integers in the range [0, num_groups).
            batch_size (int): Size of mini-batch. Must be positive.

        Raises:
            ValueError: If `sampler` is not a `torch.utils.data.Sampler`, or
                if `batch_size` is not positive.
        """
        # Note: BatchSampler.__init__ is not called; every attribute this
        # class reads is assigned below.
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        # A batch is emitted only when a buffer's length equals `batch_size`,
        # which can never happen for a non-positive size: iteration would
        # consume the sampler without ever yielding. Fail early instead.
        if batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )

        self.sampler = sampler
        self.group_ids = np.asarray(group_ids)
        assert self.group_ids.ndim == 1
        self.batch_size = batch_size

        # One buffer per distinct group id; indices accumulate here until
        # the group has collected a full batch.
        groups = np.unique(self.group_ids).tolist()
        self.buffer_per_group = {k: [] for k in groups}

    def __iter__(self):
        """
        Yield lists of `batch_size` indices that all belong to the same group.

        Indices left in a partially filled buffer when the wrapped sampler
        stops are kept in `self.buffer_per_group` rather than yielded.
        """
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            group_buffer = self.buffer_per_group[group_id]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                # Yield a copy: the buffer itself is emptied in place and
                # reused for the next batch of this group.
                yield group_buffer[:]
                del group_buffer[:]

    def __len__(self):
        """
        Always raises, because the number of batches is not well-defined.

        Since this raises `NotImplementedError` (not `TypeError`),
        `list(sampler)` fails as well; iterate over the sampler explicitly
        to collect batches.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
|
|