Spaces:
Build error
Build error
import logging | |
import random | |
from typing import * | |
from allennlp.data.samplers.batch_sampler import BatchSampler | |
from allennlp.data.samplers.max_tokens_batch_sampler import MaxTokensBatchSampler | |
from torch.utils import data | |
logger = logging.getLogger('mix_sampler') | |
class MixSampler(MaxTokensBatchSampler):
    """Max-tokens batch sampler that randomly sub-samples instances by type.

    An instance's type is read from ``ins.fields['meta'].metadata['type']``;
    it falls back to ``'default'`` when the instance has no ``meta`` field or
    the metadata has no ``type`` entry. On every pass over the sampler each
    instance is kept when a fresh ``random.random()`` draw does not exceed the
    ratio configured for its type. Types without a configured ratio use 1.0
    and are therefore always kept.
    """

    def __init__(
        self,
        max_tokens: int,
        sorting_keys: Optional[List[str]] = None,
        padding_noise: float = 0.1,
        sampling_ratios: Optional[Dict[str, float]] = None,
    ):
        """Configure the sampler.

        Args:
            max_tokens: Forwarded unchanged to ``MaxTokensBatchSampler``.
            sorting_keys: Forwarded unchanged to ``MaxTokensBatchSampler``.
            padding_noise: Forwarded unchanged to ``MaxTokensBatchSampler``.
            sampling_ratios: Maps an instance type to its keep threshold;
                presumably a probability in ``[0, 1]`` (not validated here).
                ``None`` or an empty mapping keeps every instance.
        """
        super().__init__(max_tokens, sorting_keys, padding_noise)
        self.sampling_ratios = sampling_ratios or {}

    def __iter__(self) -> Iterator[List[int]]:
        """Yield shuffled batches of instance indices after sub-sampling.

        The sub-sampling is redrawn on every call, so successive passes may
        see different subsets of ``self.data_source``.
        """
        # Ordering comes from the parent class; this runs before any
        # rejection draw, exactly as in the original statement order.
        indices, lengths = self._argsort_by_padding(self.data_source)
        original_num = len(indices)

        instance_types = [
            ins.fields['meta'].metadata.get('type', 'default') if 'meta' in ins.fields else 'default'
            for ins in self.data_source
        ]
        instance_thresholds = [
            self.sampling_ratios.get(ins_type, 1.0) for ins_type in instance_types
        ]

        # One draw per instance, in data-source order. Rejected positions are
        # collected in a set so the filtering below is a single linear pass
        # instead of a list.index()/del pair per rejected instance.
        rejected = {
            idx
            for idx, threshold in enumerate(instance_thresholds)
            if random.random() > threshold
        }

        if rejected:
            # Filter both parallel lists together so they stay aligned and
            # keep the order returned by _argsort_by_padding.
            kept = [
                (idx, length)
                for idx, length in zip(indices, lengths)
                if idx not in rejected
            ]
            indices = [idx for idx, _ in kept]
            lengths = [length for _, length in kept]

        if original_num != len(indices):
            logger.info(f'#instances reduced from {original_num} to {len(indices)}.')

        # Each entry of `lengths` holds several values per instance; the
        # grouping helper is given the largest of them.
        max_lengths = [max(length) for length in lengths]
        group_iterator = self._lazy_groups_of_max_size(indices, max_lengths)
        batches = [list(group) for group in group_iterator]
        random.shuffle(batches)
        yield from batches