"""Module that contain iterator used for dynamic data."""
from itertools import cycle
from torchtext.data import batch as torchtext_batch
from onmt.inputters import str2sortkey, max_tok_len, OrderedIterator
from onmt.inputters.corpus import get_corpora, build_corpora_iters,\
DatasetAdapter
from onmt.transforms import make_transforms
from onmt.utils.logging import logger


class MixingStrategy(object):
    """Mixing strategy that should be used in Data Iterator."""

    def __init__(self, iterables, weights):
        """Initialize necessary attributes."""
        self._valid_iterable(iterables, weights)
        self.iterables = iterables
        self.weights = weights

    def _valid_iterable(self, iterables, weights):
        iter_keys = iterables.keys()
        weight_keys = weights.keys()
        if iter_keys != weight_keys:
            raise ValueError(
                f"keys in {iterables} & {weights} should be equal.")

    def __iter__(self):
        raise NotImplementedError


class SequentialMixer(MixingStrategy):
    """Generate data sequentially from `iterables`, which are exhaustible."""

    def _iter_datasets(self):
        for ds_name, ds_weight in self.weights.items():
            for _ in range(ds_weight):
                yield ds_name

    def __iter__(self):
        for ds_name in self._iter_datasets():
            iterable = self.iterables[ds_name]
            yield from iterable
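
# A minimal sketch of how `SequentialMixer` behaves (illustrative only; the
# dataset names and weights below are hypothetical). Each weight acts as a
# repeat count: a corpus is iterated that many times, one corpus after
# another, and iteration then stops. Note the repeat only re-yields data for
# re-iterable inputs such as lists; a one-shot generator would be empty on
# its second pass.
#
#     mixer = SequentialMixer(
#         {"valid_a": ["x1", "x2"], "valid_b": ["y1"]},
#         {"valid_a": 1, "valid_b": 2})
#     list(mixer)  # -> ["x1", "x2", "y1", "y1"]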


class WeightedMixer(MixingStrategy):
    """A mixing strategy that mixes data by weight and iterates infinitely."""

    def __init__(self, iterables, weights):
        super().__init__(iterables, weights)
        self._iterators = {}
        self._counts = {}
        for ds_name in self.iterables.keys():
            self._reset_iter(ds_name)

    def _logging(self):
        """Report corpora loading statistics."""
        msgs = []
        for ds_name, ds_count in self._counts.items():
            msgs.append(f"\t\t\t* {ds_name}: {ds_count}")
        logger.info("Weighted corpora loaded so far:\n" + "\n".join(msgs))

    def _reset_iter(self, ds_name):
        self._iterators[ds_name] = iter(self.iterables[ds_name])
        self._counts[ds_name] = self._counts.get(ds_name, 0) + 1
        self._logging()

    def _iter_datasets(self):
        for ds_name, ds_weight in self.weights.items():
            for _ in range(ds_weight):
                yield ds_name

    def __iter__(self):
        for ds_name in cycle(self._iter_datasets()):
            iterator = self._iterators[ds_name]
            try:
                item = next(iterator)
            except StopIteration:
                # Restart the exhausted corpus and draw from the fresh pass.
                self._reset_iter(ds_name)
                iterator = self._iterators[ds_name]
                item = next(iterator)
            yield item
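
# A minimal sketch of how `WeightedMixer` interleaves corpora (illustrative
# only; the names and weights are hypothetical). The weights define a
# round-robin schedule that `cycle` repeats forever, and an exhausted corpus
# is restarted transparently by `_reset_iter`:
#
#     from itertools import islice
#     mixer = WeightedMixer(
#         {"big": ["b1", "b2", "b3"], "small": ["s1"]},
#         {"big": 2, "small": 1})
#     list(islice(mixer, 6))
#     # -> ["b1", "b2", "s1", "b3", "b1", "s1"]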


class DynamicDatasetIter(object):
    """Yield batches from (multiple) plain text corpora.

    Args:
        corpora (dict[str, ParallelCorpus]): collections of corpora to
            iterate;
        corpora_info (dict[str, dict]): corpora infos corresponding to
            corpora;
        transforms (dict[str, Transform]): transforms that may be used by
            corpora;
        fields (dict[str, Field]): fields dict for converting corpora into
            Tensors;
        is_train (bool): True when generating data for training;
        batch_type (str): batching type to count on, choices=[tokens, sents];
        batch_size (int): number of examples in a batch;
        batch_size_multiple (int): make batch size a multiple of this;
        data_type (str): input data type, currently only text;
        bucket_size (int): accumulate this number of examples in a dynamic
            dataset;
        pool_factor (int): accumulate this number of batches before sorting;
        skip_empty_level (str): security level when encountering an empty
            line;
        stride (int): iterate data files with this stride;
        offset (int): iterate data files with this offset.

    Attributes:
        batch_size_fn (function): function to calculate batch_size;
        sort_key (function): function that defines how to sort examples;
        dataset_adapter (DatasetAdapter): adapts raw corpora into tensor
            datasets;
        mixer (MixingStrategy): the strategy to iterate corpora.
    """

    def __init__(self, corpora, corpora_info, transforms, fields, is_train,
                 batch_type, batch_size, batch_size_multiple,
                 data_type="text", bucket_size=2048, pool_factor=8192,
                 skip_empty_level='warning', stride=1, offset=0):
        self.corpora = corpora
        self.transforms = transforms
        self.fields = fields
        self.corpora_info = corpora_info
        self.is_train = is_train
        self.init_iterators = False
        self.batch_size = batch_size
        self.batch_size_fn = max_tok_len if batch_type == "tokens" else None
        self.batch_size_multiple = batch_size_multiple
        self.device = 'cpu'
        self.sort_key = str2sortkey[data_type]
        self.bucket_size = bucket_size
        self.pool_factor = pool_factor
        if stride <= 0:
            raise ValueError(f"Invalid argument for stride={stride}.")
        self.stride = stride
        self.offset = offset
        if skip_empty_level not in ['silent', 'warning', 'error']:
            raise ValueError(
                f"Invalid argument skip_empty_level={skip_empty_level}")
        self.skip_empty_level = skip_empty_level

    @classmethod
    def from_opts(cls, corpora, transforms, fields, opts, is_train,
                  stride=1, offset=0):
        """Initialize `DynamicDatasetIter` with options parsed from `opts`."""
        batch_size = opts.batch_size if is_train else opts.valid_batch_size
        if opts.batch_size_multiple is not None:
            batch_size_multiple = opts.batch_size_multiple
        else:
            batch_size_multiple = 8 if opts.model_dtype == "fp16" else 1
        return cls(
            corpora, opts.data, transforms, fields, is_train, opts.batch_type,
            batch_size, batch_size_multiple, data_type=opts.data_type,
            bucket_size=opts.bucket_size, pool_factor=opts.pool_factor,
            skip_empty_level=opts.skip_empty_level,
            stride=stride, offset=offset
        )

    def _init_datasets(self):
        datasets_iterables = build_corpora_iters(
            self.corpora, self.transforms, self.corpora_info,
            skip_empty_level=self.skip_empty_level,
            stride=self.stride, offset=self.offset)
        self.dataset_adapter = DatasetAdapter(self.fields, self.is_train)
        datasets_weights = {
            ds_name: int(self.corpora_info[ds_name]['weight'])
            for ds_name in datasets_iterables.keys()
        }
        if self.is_train:
            self.mixer = WeightedMixer(datasets_iterables, datasets_weights)
        else:
            self.mixer = SequentialMixer(datasets_iterables, datasets_weights)
        self.init_iterators = True

    def _bucketing(self):
        buckets = torchtext_batch(
            self.mixer,
            batch_size=self.bucket_size,
            batch_size_fn=None)
        yield from buckets

    def __iter__(self):
        if not self.init_iterators:
            self._init_datasets()
        for bucket in self._bucketing():
            dataset = self.dataset_adapter(bucket)
            train_iter = OrderedIterator(
                dataset,
                self.batch_size,
                pool_factor=self.pool_factor,
                batch_size_fn=self.batch_size_fn,
                batch_size_multiple=self.batch_size_multiple,
                device=self.device,
                train=self.is_train,
                sort=False,
                sort_within_batch=True,
                sort_key=self.sort_key,
                repeat=False,
            )
            for batch in train_iter:
                yield batch
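
# The batching pipeline above, in summary (a sketch of the data flow, not
# additional API): the mixer streams transformed examples; `_bucketing`
# groups every `bucket_size` of them into a bucket; `dataset_adapter` turns
# a bucket into a torchtext Dataset; and `OrderedIterator` sorts within
# batches and yields tensor batches. Roughly:
#
#     mixer -> buckets of `bucket_size` examples
#           -> DatasetAdapter(fields, is_train)(bucket)
#           -> OrderedIterator(..., sort_within_batch=True) -> batches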


def build_dynamic_dataset_iter(fields, transforms_cls, opts, is_train=True,
                               stride=1, offset=0):
    """Build `DynamicDatasetIter` from fields & opts."""
    transforms = make_transforms(opts, transforms_cls, fields)
    corpora = get_corpora(opts, is_train)
    if corpora is None:
        assert not is_train, "only the validation corpus may be absent."
        return None
    return DynamicDatasetIter.from_opts(
        corpora, transforms, fields, opts, is_train,
        stride=stride, offset=offset)
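
# A minimal usage sketch, assuming `opts` comes from OpenNMT-py's training
# option parser and `fields`/`transforms_cls` were built during vocabulary
# preparation (the variable names here are hypothetical):
#
#     train_iter = build_dynamic_dataset_iter(
#         fields, transforms_cls, opts, is_train=True)
#     for batch in train_iter:
#         ...  # feed `batch` to the training loop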