"""Module that contain iterator used for dynamic data."""
from itertools import cycle
from torchtext.data import batch as torchtext_batch
from onmt.inputters import str2sortkey, max_tok_len, OrderedIterator
from onmt.inputters.corpus import get_corpora, build_corpora_iters,\
DatasetAdapter
from onmt.transforms import make_transforms
from onmt.utils.logging import logger
class MixingStrategy(object):
"""Mixing strategy that should be used in Data Iterator."""
def __init__(self, iterables, weights):
"""Initilize neccessary attr."""
self._valid_iterable(iterables, weights)
self.iterables = iterables
self.weights = weights
def _valid_iterable(self, iterables, weights):
iter_keys = iterables.keys()
weight_keys = weights.keys()
if iter_keys != weight_keys:
raise ValueError(
f"keys in {iterables} & {iterables} should be equal.")
def __iter__(self):
raise NotImplementedError
class SequentialMixer(MixingStrategy):
"""Generate data sequentially from `iterables` which is exhaustible."""
def _iter_datasets(self):
for ds_name, ds_weight in self.weights.items():
for _ in range(ds_weight):
yield ds_name
def __iter__(self):
for ds_name in self._iter_datasets():
iterable = self.iterables[ds_name]
yield from iterable
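
# A minimal behaviour sketch for `SequentialMixer` (hypothetical toy data,
# not part of this module): each corpus is replayed `weight` times in order,
# after which the mixer is exhausted.
#
#     >>> mixer = SequentialMixer({"a": [1, 2], "b": [3]}, {"a": 2, "b": 1})
#     >>> list(mixer)
#     [1, 2, 1, 2, 3]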
class WeightedMixer(MixingStrategy):
"""A mixing strategy that mix data weightedly and iterate infinitely."""
def __init__(self, iterables, weights):
super().__init__(iterables, weights)
self._iterators = {}
self._counts = {}
for ds_name in self.iterables.keys():
self._reset_iter(ds_name)
def _logging(self):
"""Report corpora loading statistics."""
msgs = []
for ds_name, ds_count in self._counts.items():
msgs.append(f"\t\t\t* {ds_name}: {ds_count}")
logger.info("Weighted corpora loaded so far:\n"+"\n".join(msgs))
def _reset_iter(self, ds_name):
self._iterators[ds_name] = iter(self.iterables[ds_name])
self._counts[ds_name] = self._counts.get(ds_name, 0) + 1
self._logging()
def _iter_datasets(self):
for ds_name, ds_weight in self.weights.items():
for _ in range(ds_weight):
yield ds_name
def __iter__(self):
for ds_name in cycle(self._iter_datasets()):
iterator = self._iterators[ds_name]
try:
item = next(iterator)
except StopIteration:
self._reset_iter(ds_name)
iterator = self._iterators[ds_name]
item = next(iterator)
finally:
yield item
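
# A minimal behaviour sketch for `WeightedMixer` (hypothetical toy data, not
# part of this module): dataset names are cycled according to their weights,
# and an exhausted corpus is transparently restarted, so the stream is
# infinite. With weights {"a": 2, "b": 1}, two items from "a" are yielded for
# every item from "b":
#
#     >>> from itertools import islice
#     >>> mixer = WeightedMixer({"a": ["a1", "a2"], "b": ["b1"]},
#     ...                       {"a": 2, "b": 1})
#     >>> list(islice(iter(mixer), 6))
#     ['a1', 'a2', 'b1', 'a1', 'a2', 'b1']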
class DynamicDatasetIter(object):
"""Yield batch from (multiple) plain text corpus.
Args:
corpora (dict[str, ParallelCorpus]): collections of corpora to iterate;
corpora_info (dict[str, dict]): corpora infos correspond to corpora;
transforms (dict[str, Transform]): transforms may be used by corpora;
fields (dict[str, Field]): fields dict for convert corpora into Tensor;
is_train (bool): True when generate data for training;
batch_type (str): batching type to count on, choices=[tokens, sents];
batch_size (int): numbers of examples in a batch;
batch_size_multiple (int): make batch size multiply of this;
data_type (str): input data type, currently only text;
bucket_size (int): accum this number of examples in a dynamic dataset;
pool_factor (int): accum this number of batch before sorting;
skip_empty_level (str): security level when encouter empty line;
stride (int): iterate data files with this stride;
offset (int): iterate data files with this offset.
Attributes:
batch_size_fn (function): functions to calculate batch_size;
sort_key (function): functions define how to sort examples;
dataset_adapter (DatasetAdapter): organize raw corpus to tensor adapt;
mixer (MixingStrategy): the strategy to iterate corpora.
"""
def __init__(self, corpora, corpora_info, transforms, fields, is_train,
batch_type, batch_size, batch_size_multiple, data_type="text",
bucket_size=2048, pool_factor=8192,
skip_empty_level='warning', stride=1, offset=0):
self.corpora = corpora
self.transforms = transforms
self.fields = fields
self.corpora_info = corpora_info
self.is_train = is_train
self.init_iterators = False
self.batch_size = batch_size
self.batch_size_fn = max_tok_len if batch_type == "tokens" else None
self.batch_size_multiple = batch_size_multiple
self.device = 'cpu'
self.sort_key = str2sortkey[data_type]
self.bucket_size = bucket_size
self.pool_factor = pool_factor
if stride <= 0:
raise ValueError(f"Invalid argument for stride={stride}.")
self.stride = stride
self.offset = offset
if skip_empty_level not in ['silent', 'warning', 'error']:
raise ValueError(
f"Invalid argument skip_empty_level={skip_empty_level}")
self.skip_empty_level = skip_empty_level
@classmethod
def from_opts(cls, corpora, transforms, fields, opts, is_train,
stride=1, offset=0):
"""Initilize `DynamicDatasetIter` with options parsed from `opts`."""
batch_size = opts.batch_size if is_train else opts.valid_batch_size
if opts.batch_size_multiple is not None:
batch_size_multiple = opts.batch_size_multiple
else:
batch_size_multiple = 8 if opts.model_dtype == "fp16" else 1
return cls(
corpora, opts.data, transforms, fields, is_train, opts.batch_type,
batch_size, batch_size_multiple, data_type=opts.data_type,
bucket_size=opts.bucket_size, pool_factor=opts.pool_factor,
skip_empty_level=opts.skip_empty_level,
stride=stride, offset=offset
)
def _init_datasets(self):
datasets_iterables = build_corpora_iters(
self.corpora, self.transforms, self.corpora_info,
skip_empty_level=self.skip_empty_level,
stride=self.stride, offset=self.offset)
self.dataset_adapter = DatasetAdapter(self.fields, self.is_train)
datasets_weights = {
ds_name: int(self.corpora_info[ds_name]['weight'])
for ds_name in datasets_iterables.keys()
}
if self.is_train:
self.mixer = WeightedMixer(datasets_iterables, datasets_weights)
else:
self.mixer = SequentialMixer(datasets_iterables, datasets_weights)
self.init_iterators = True
def _bucketing(self):
buckets = torchtext_batch(
self.mixer,
batch_size=self.bucket_size,
batch_size_fn=None)
yield from buckets
def __iter__(self):
        if not self.init_iterators:
self._init_datasets()
for bucket in self._bucketing():
dataset = self.dataset_adapter(bucket)
train_iter = OrderedIterator(
dataset,
self.batch_size,
pool_factor=self.pool_factor,
batch_size_fn=self.batch_size_fn,
batch_size_multiple=self.batch_size_multiple,
device=self.device,
train=self.is_train,
sort=False,
sort_within_batch=True,
sort_key=self.sort_key,
repeat=False,
)
for batch in train_iter:
yield batch
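
# A hedged usage sketch for `DynamicDatasetIter` (the `corpora`,
# `corpora_info`, `transforms` and `fields` objects are assumed to come from
# the surrounding OpenNMT-py pipeline; the values below are placeholders):
#
#     >>> train_iter = DynamicDatasetIter(
#     ...     corpora, corpora_info, transforms, fields, is_train=True,
#     ...     batch_type="tokens", batch_size=4096, batch_size_multiple=8)
#     >>> for batch in train_iter:
#     ...     pass  # each `batch` is ready for the training loop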
def build_dynamic_dataset_iter(fields, transforms_cls, opts, is_train=True,
stride=1, offset=0):
"""Build `DynamicDatasetIter` from fields & opts."""
transforms = make_transforms(opts, transforms_cls, fields)
corpora = get_corpora(opts, is_train)
if corpora is None:
        assert not is_train, "only the valid corpus may be missing."
return None
return DynamicDatasetIter.from_opts(
corpora, transforms, fields, opts, is_train,
stride=stride, offset=offset)
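
# A hedged end-to-end sketch (assumes `fields`, `transforms_cls` and a parsed
# `opts` namespace from the usual OpenNMT-py training setup; not runnable on
# its own). `stride`/`offset` can shard the corpora across workers, each
# reading every `stride`-th example starting at `offset` (`num_workers` and
# `worker_rank` are hypothetical names):
#
#     >>> train_iter = build_dynamic_dataset_iter(
#     ...     fields, transforms_cls, opts, is_train=True,
#     ...     stride=num_workers, offset=worker_rank)
#     >>> batch = next(iter(train_iter))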