# coding: utf-8
from itertools import chain, starmap
from collections import Counter
import torch
from torchtext.data import Dataset as TorchtextDataset
from torchtext.data import Example
from torchtext.vocab import Vocab


def _join_dicts(*args):
    """
    Args:
        dictionaries with disjoint keys.

    Returns:
        a single dictionary that has the union of these keys.
    """
    return dict(chain(*[d.items() for d in args]))
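
# For instance (a quick sketch of the expected inputs and output):
#
#     _join_dicts({"src": "hello ."}, {"tgt": "bonjour ."})
#     # -> {"src": "hello .", "tgt": "bonjour ."}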


def _dynamic_dict(example, src_field, tgt_field):
    """Create copy-vocab and numericalize with it.

    In-place adds ``"src_map"`` to ``example``. That is the copy-vocab
    numericalization of the tokenized ``example["src"]``. If ``example``
    has a ``"tgt"`` key, adds ``"alignment"`` to example. That is the
    copy-vocab numericalization of the tokenized ``example["tgt"]``. The
    alignment has an initial and final UNK token to match the BOS and EOS
    tokens.

    Args:
        example (dict): An example dictionary with a ``"src"`` key and
            maybe a ``"tgt"`` key. (This argument changes in place!)
        src_field (torchtext.data.Field): Field object.
        tgt_field (torchtext.data.Field): Field object.

    Returns:
        ``example``, changed as described.
    """

    src = src_field.tokenize(example["src"]["src"])
    # make a small vocab containing just the tokens in the source sequence
    unk = src_field.unk_token
    pad = src_field.pad_token
    # add init_token and eos_token according to src construction
    if src_field.init_token:
        src = [src_field.init_token] + src
    if src_field.eos_token:
        src.append(src_field.eos_token)
    src_ex_vocab = Vocab(Counter(src), specials=[unk, pad])
    unk_idx = src_ex_vocab.stoi[unk]
    # Map source tokens to indices in the dynamic dict.
    src_map = torch.LongTensor([src_ex_vocab.stoi[w] for w in src])
    example["src_map"] = src_map
    example["src_ex_vocab"] = src_ex_vocab

    if "tgt" in example:
        tgt = tgt_field.tokenize(example["tgt"]["tgt"])
        mask = torch.LongTensor(
            [unk_idx] + [src_ex_vocab.stoi[w] for w in tgt] + [unk_idx])
        example["alignment"] = mask
    return example
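
# A toy illustration of ``_dynamic_dict`` (a sketch; the exact copy-vocab
# indices depend on torchtext's Vocab ordering, and the Field settings here
# are assumptions, not taken from this module):
#
#     from torchtext.data import Field
#     src_f = Field()  # whitespace tokenization, default unk/pad tokens
#     tgt_f = Field(init_token="<s>", eos_token="</s>")
#     ex = {"src": {"src": "hello world hello"},
#           "tgt": {"tgt": "hello world"}}
#     ex = _dynamic_dict(ex, src_f, tgt_f)
#     ex["src_map"]    # LongTensor: one copy-vocab index per source token
#     ex["alignment"]  # LongTensor: UNK + per-tgt-token indices + UNK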


class Dataset(TorchtextDataset):
    """Contain data and process it.

    A dataset is an object that accepts sequences of raw data (sentence pairs
    in the case of machine translation) and fields which describe how this
    raw data should be processed to produce tensors. When a dataset is
    instantiated, it applies the fields' preprocessing pipeline (but not
    the bit that numericalizes it or turns it into batch tensors) to the raw
    data, producing a list of :class:`torchtext.data.Example` objects.
    torchtext's iterators then know how to use these examples to make batches.

    Args:
        fields (dict[str, Field]): a dict with the structure
            returned by :func:`onmt.inputters.get_fields()`. Usually
            that means the dataset side, ``"src"`` or ``"tgt"``. Keys match
            the keys of items yielded by the ``readers``, while values
            are lists of (name, Field) pairs. An attribute with this
            name will be created for each :class:`torchtext.data.Example`
            object and its value will be the result of applying the Field
            to the data that matches the key. The advantage of having
            sequences of fields for each piece of raw input is that it allows
            the dataset to store multiple "views" of each input, which allows
            for easy implementation of token-level features, mixed word-
            and character-level models, and so on. (See also
            :class:`onmt.inputters.TextMultiField`.)
        readers (Iterable[onmt.inputters.DataReaderBase]): Reader objects
            for disk-to-dict. The yielded dicts are then processed
            according to ``fields``.
        data (Iterable[Tuple[str, Any]]): (name, ``data_arg``) pairs
            where ``data_arg`` is passed to the ``read()`` method of the
            reader in ``readers`` at that position. (See the reader object for
            details on the ``Any`` type.)
        sort_key (Callable[[torchtext.data.Example], Any]): A function
            for determining the value on which data is sorted (i.e. length).
        filter_pred (Callable[[torchtext.data.Example], bool]): A function
            that accepts Example objects and returns a boolean value
            indicating whether to include that example in the dataset.

    Attributes:
        src_vocabs (List[torchtext.data.Vocab]): Used with dynamic dict/copy
            attention. There is a very short vocab for each src example.
            It contains just the source words, e.g. so that the generator can
            predict to copy them.
    """

    def __init__(self, fields, readers, data, sort_key, filter_pred=None):
        self.sort_key = sort_key
        can_copy = 'src_map' in fields and 'alignment' in fields

        read_iters = [r.read(dat, name, feats)
                      for r, (name, dat, feats) in zip(readers, data)]

        # self.src_vocabs is used in collapse_copy_scores and Translator.py
        self.src_vocabs = []
        examples = []
        for ex_dict in starmap(_join_dicts, zip(*read_iters)):
            if can_copy:
                src_field = fields['src']
                tgt_field = fields['tgt']
                # this assumes src_field and tgt_field are both text
                ex_dict = _dynamic_dict(
                    ex_dict, src_field.base_field, tgt_field.base_field)
                self.src_vocabs.append(ex_dict["src_ex_vocab"])
            ex_fields = {k: [(k, v)] for k, v in fields.items() if
                         k in ex_dict}
            ex = Example.fromdict(ex_dict, ex_fields)
            examples.append(ex)

        # fields needs to have only keys that examples have as attrs
        fields = []
        for _, nf_list in ex_fields.items():
            assert len(nf_list) == 1
            fields.append(nf_list[0])

        super(Dataset, self).__init__(examples, fields, filter_pred)

    def __getattr__(self, attr):
        # avoid infinite recursion when fields isn't defined
        if 'fields' not in vars(self):
            raise AttributeError
        if attr in self.fields:
            return (getattr(x, attr) for x in self.examples)
        else:
            raise AttributeError

    def save(self, path, remove_fields=True):
        if remove_fields:
            self.fields = []
        torch.save(self, path)

    @staticmethod
    def config(fields):
        readers, data = [], []
        for name, field in fields:
            if field["data"] is not None:
                readers.append(field["reader"])
                data.append((name, field["data"], field.get("features", {})))
        return readers, data
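
# Rough wiring sketch (assumptions, not part of this module: ``fields`` as
# produced by onmt.inputters.get_fields, reader objects from onmt.inputters,
# and the file names and ``my_sort_key`` below are hypothetical):
#
#     spec = [
#         ("src", {"reader": src_reader, "data": "train.src", "features": {}}),
#         ("tgt", {"reader": tgt_reader, "data": "train.tgt", "features": {}}),
#     ]
#     readers, data = Dataset.config(spec)
#     dataset = Dataset(fields, readers, data, sort_key=my_sort_key)
#     dataset.save("train.0.pt")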


class DynamicDataset(Dataset):
    def __init__(self, fields, data, sort_key, filter_pred=None):
        self.sort_key = sort_key
        can_copy = 'src_map' in fields and 'alignment' in fields

        # self.src_vocabs is used in collapse_copy_scores and Translator.py
        self.src_vocabs = []
        examples = []
        for ex_dict in data:
            if can_copy:
                src_field = fields['src']
                tgt_field = fields['tgt']
                # this assumes src_field and tgt_field are both text
                ex_dict = _dynamic_dict(
                    ex_dict, src_field.base_field, tgt_field.base_field)
                self.src_vocabs.append(ex_dict["src_ex_vocab"])
            ex_fields = {k: [(k, v)] for k, v in fields.items() if
                         k in ex_dict}
            ex = Example.fromdict(ex_dict, ex_fields)
            examples.append(ex)

        # fields needs to have only keys that examples have as attrs
        fields = []
        for _, nf_list in ex_fields.items():
            assert len(nf_list) == 1
            fields.append(nf_list[0])

        # deliberately skip Dataset.__init__ (which expects readers) and
        # build the torchtext Dataset directly from the examples
        super(Dataset, self).__init__(examples, fields, filter_pred)

    def __getattr__(self, attr):
        # avoid infinite recursion when fields isn't defined
        if 'fields' not in vars(self):
            raise AttributeError
        if attr in self.fields:
            return (getattr(x, attr) for x in self.examples)
        else:
            raise AttributeError

    def save(self, path, remove_fields=True):
        if remove_fields:
            self.fields = []
        torch.save(self, path)

    @staticmethod
    def config(fields):
        readers, data = [], []
        for name, field in fields:
            if field["data"] is not None:
                readers.append(field["reader"])
                data.append((name, field["data"], field.get("features", {})))
        return readers, data
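
# DynamicDataset bypasses the reader stage: ``data`` is expected to be an
# iterable of example dicts already read/transformed upstream (a sketch
# under that assumption; ``preprocessed_dicts`` and ``my_sort_key`` are
# hypothetical names):
#
#     dataset = DynamicDataset(fields, preprocessed_dicts,
#                              sort_key=my_sort_key)
#     len(dataset.src_vocabs)  # one copy-vocab per example when the
#                              # 'src_map'/'alignment' fields are present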