# Author: Gosse Minnema
import logging
import random
from typing import Dict, List, Optional
from allennlp.data.samplers.batch_sampler import BatchSampler
from allennlp.data.samplers.max_tokens_batch_sampler import MaxTokensBatchSampler
logger = logging.getLogger('mix_sampler')


@BatchSampler.register('mix_sampler')
class MixSampler(MaxTokensBatchSampler):
    def __init__(
        self,
        max_tokens: int,
        sorting_keys: Optional[List[str]] = None,
        padding_noise: float = 0.1,
        sampling_ratios: Optional[Dict[str, float]] = None,
    ):
        super().__init__(max_tokens, sorting_keys, padding_noise)
        self.sampling_ratios = sampling_ratios or dict()
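
    # Illustrative note (assumption, not from the original file): with
    # sampling_ratios={'framenet': 0.25}, each pass over the data keeps an
    # instance whose 'meta' metadata has type 'framenet' with probability
    # 0.25, while instances of any type not listed are always kept.
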
    def __iter__(self):
        indices, lengths = self._argsort_by_padding(self.data_source)
        original_num = len(indices)

        # Look up each instance's dataset type from its 'meta' metadata field,
        # falling back to 'default' for instances without one.
        instance_types = [
            ins.fields['meta'].metadata.get('type', 'default') if 'meta' in ins.fields else 'default'
            for ins in self.data_source
        ]
        # Each instance is kept with probability equal to the ratio configured
        # for its type (1.0, i.e. always kept, for unlisted types).
        instance_thresholds = [
            self.sampling_ratios.get(ins_type, 1.0) for ins_type in instance_types
        ]
        for idx, threshold in enumerate(instance_thresholds):
            if random.random() > threshold:
                # Reject: drop this instance (and its length entry) from the
                # padding-sorted index list.
                list_idx = indices.index(idx)
                del indices[list_idx], lengths[list_idx]
        if original_num != len(indices):
            logger.info(f'#instances reduced from {original_num} to {len(indices)}.')

        # Greedily group the surviving instances into batches of at most
        # max_tokens tokens each, then shuffle the batch order.
        max_lengths = [max(length) for length in lengths]
        group_iterator = self._lazy_groups_of_max_size(indices, max_lengths)

        batches = [list(group) for group in group_iterator]
        random.shuffle(batches)
        for batch in batches:
            yield batch
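

# Usage sketch (assumption, not part of the original file): in an AllenNLP
# training config the sampler would be selected by its registered name, e.g.
#   "data_loader": {
#       "batch_sampler": {
#           "type": "mix_sampler",
#           "max_tokens": 1024,
#           "sampling_ratios": {"framenet": 0.25}
#       }
#   }
#
# Minimal sanity check of the per-type rejection rule that __iter__ applies,
# using hypothetical type names; running MixSampler itself needs a full
# AllenNLP data pipeline.
if __name__ == '__main__':
    random.seed(13)
    ratios = {'framenet': 0.25}
    types = ['framenet'] * 1000 + ['default'] * 10
    # Keep an instance when random.random() <= its threshold, mirroring the
    # rejection test above ('reject when random.random() > threshold').
    kept = [t for t in types if random.random() <= ratios.get(t, 1.0)]
    # Expect roughly 250 'framenet' instances and all 10 'default' ones.
    print(f"kept {kept.count('framenet')} 'framenet' and {kept.count('default')} 'default' instances")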