from itertools import chain, starmap
from collections import Counter

import torch
from torchtext.data import Dataset as TorchtextDataset
from torchtext.data import Example
from torchtext.vocab import Vocab


def _join_dicts(*args):
    """
    Args:
        *args: dictionaries with disjoint keys.

    Returns:
        a single dictionary that has the union of these keys.
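
    Example:
        >>> _join_dicts({"src": "a b"}, {"indices": 0})
        {'src': 'a b', 'indices': 0}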
    """
    return dict(chain(*[d.items() for d in args]))


def _dynamic_dict(example, src_field, tgt_field):
    """Create copy-vocab and numericalize with it.

    In-place adds ``"src_map"`` to ``example``. That is the copy-vocab
    numericalization of the tokenized ``example["src"]``. If ``example``
    has a ``"tgt"`` key, adds ``"alignment"`` to example. That is the
    copy-vocab numericalization of the tokenized ``example["tgt"]``. The
    alignment has an initial and final UNK token to match the BOS and EOS
    tokens.

    Args:
        example (dict): An example dictionary with a ``"src"`` key and
            maybe a ``"tgt"`` key. (This argument changes in place!)
        src_field (torchtext.data.Field): Field object.
        tgt_field (torchtext.data.Field): Field object.

    Returns:
        ``example``, changed as described.
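
    Example (illustrative; assumes the field defines no init/eos tokens):
        With ``src`` tokens ``["the", "cat", "the"]``, the copy vocab is
        ``[<unk>, <pad>, "the", "cat"]`` (specials first, then tokens by
        descending frequency), so ``src_map`` is ``[2, 3, 2]`` and a
        ``tgt`` of ``["cat"]`` yields ``alignment`` ``[0, 3, 0]``.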
    """
    src = src_field.tokenize(example["src"]["src"])

    unk = src_field.unk_token
    pad = src_field.pad_token

    if src_field.init_token:
        src = [src_field.init_token] + src
    if src_field.eos_token:
        src.append(src_field.eos_token)

    # tiny per-example vocab over the source tokens, with the unk and pad
    # tokens prepended as specials
    src_ex_vocab = Vocab(Counter(src), specials=[unk, pad])
    unk_idx = src_ex_vocab.stoi[unk]

    # numericalize the source with the copy vocab
    src_map = torch.LongTensor([src_ex_vocab.stoi[w] for w in src])
    example["src_map"] = src_map
    example["src_ex_vocab"] = src_ex_vocab

    if "tgt" in example:
        # target tokens absent from the copy vocab map to unk_idx; the
        # flanking unk_idx entries line up with the target's BOS and EOS
        tgt = tgt_field.tokenize(example["tgt"]["tgt"])
        mask = torch.LongTensor(
            [unk_idx] + [src_ex_vocab.stoi[w] for w in tgt] + [unk_idx])
        example["alignment"] = mask
    return example


class Dataset(TorchtextDataset):
    """Contain data and process it.

    A dataset is an object that accepts sequences of raw data (sentence pairs
    in the case of machine translation) and fields which describe how this
    raw data should be processed to produce tensors. When a dataset is
    instantiated, it applies the fields' preprocessing pipeline (but not
    the bit that numericalizes it or turns it into batch tensors) to the raw
    data, producing a list of :class:`torchtext.data.Example` objects.
    torchtext's iterators then know how to use these examples to make batches.

    Args:
        fields (dict[str, Field]): a dict with the structure
            returned by :func:`onmt.inputters.get_fields()`. Usually
            that means the dataset side, ``"src"`` or ``"tgt"``. Keys match
            the keys of items yielded by the ``readers``, while values
            are lists of (name, Field) pairs. An attribute with this
            name will be created for each :class:`torchtext.data.Example`
            object and its value will be the result of applying the Field
            to the data that matches the key. The advantage of having
            sequences of fields for each piece of raw input is that it allows
            the dataset to store multiple "views" of each input, which allows
            for easy implementation of token-level features, mixed word-
            and character-level models, and so on. (See also
            :class:`onmt.inputters.TextMultiField`.)
        readers (Iterable[onmt.inputters.DataReaderBase]): Reader objects
            for disk-to-dict. The yielded dicts are then processed
            according to ``fields``.
        data (Iterable[Tuple[str, Any]]): (name, ``data_arg``) pairs
            where ``data_arg`` is passed to the ``read()`` method of the
            reader in ``readers`` at that position. (See the reader object for
            details on the ``Any`` type.)
        sort_key (Callable[[torchtext.data.Example], Any]): A function
            for determining the value on which data is sorted (i.e. length).
        filter_pred (Callable[[torchtext.data.Example], bool]): A function
            that accepts Example objects and returns a boolean value
            indicating whether to include that example in the dataset.

    Attributes:
        src_vocabs (List[torchtext.data.Vocab]): Used with dynamic dict/copy
            attention. There is a very short vocab for each src example.
            It contains just the source words, e.g. so that the generator can
            predict to copy them.
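
    Example:
        A sketch of typical construction (names here are hypothetical;
        ``fields`` is assumed to come from
        :func:`onmt.inputters.get_fields`)::

            readers, data = Dataset.config(
                [("src", {"reader": src_reader, "data": "train.src"}),
                 ("tgt", {"reader": tgt_reader, "data": "train.tgt"})])
            dataset = Dataset(fields, readers, data,
                              sort_key=lambda ex: len(ex.src[0]))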
    """

    def __init__(self, fields, readers, data, sort_key, filter_pred=None):
        self.sort_key = sort_key
        can_copy = 'src_map' in fields and 'alignment' in fields

        read_iters = [r.read(dat, name, feats)
                      for r, (name, dat, feats) in zip(readers, data)]

        # self.src_vocabs is used by the copy-attention scoring code
        self.src_vocabs = []
        examples = []
        for ex_dict in starmap(_join_dicts, zip(*read_iters)):
            if can_copy:
                src_field = fields['src']
                tgt_field = fields['tgt']
                # this assumes src_field and tgt_field are both text fields
                ex_dict = _dynamic_dict(
                    ex_dict, src_field.base_field, tgt_field.base_field)
                self.src_vocabs.append(ex_dict["src_ex_vocab"])
            ex_fields = {k: [(k, v)] for k, v in fields.items() if
                         k in ex_dict}
            ex = Example.fromdict(ex_dict, ex_fields)
            examples.append(ex)

        # the dataset's field list is rebuilt from the last example's
        # ex_fields, so every key is expected to occur in every example
        fields = []
        for _, nf_list in ex_fields.items():
            assert len(nf_list) == 1
            fields.append(nf_list[0])

        super(Dataset, self).__init__(examples, fields, filter_pred)

    def __getattr__(self, attr):
        # avoid infinite recursion when fields isn't defined yet
        # (e.g. while unpickling)
        if 'fields' not in vars(self):
            raise AttributeError
        if attr in self.fields:
            # a field name gives an iterator over that attribute of
            # every example
            return (getattr(x, attr) for x in self.examples)
        else:
            raise AttributeError

    def save(self, path, remove_fields=True):
        # fields can hold large vocabs and are typically saved separately,
        # so they are dropped by default
        if remove_fields:
            self.fields = []
        torch.save(self, path)

    @staticmethod
    def config(fields):
        """Collect the reader and data argument of each present input."""
        readers, data = [], []
        for name, field in fields:
            if field["data"] is not None:
                readers.append(field["reader"])
                data.append((name, field["data"], field.get("features", {})))
        return readers, data


class DynamicDataset(Dataset):
    """A :class:`Dataset` built from example dicts already in memory.

    Same as :class:`Dataset`, except that ``data`` is an iterable of
    already-read example dicts, so no readers are involved.
    """

    def __init__(self, fields, data, sort_key, filter_pred=None):
        self.sort_key = sort_key
        can_copy = 'src_map' in fields and 'alignment' in fields

        self.src_vocabs = []
        examples = []
        for ex_dict in data:
            if can_copy:
                src_field = fields['src']
                tgt_field = fields['tgt']
                ex_dict = _dynamic_dict(
                    ex_dict, src_field.base_field, tgt_field.base_field)
                self.src_vocabs.append(ex_dict["src_ex_vocab"])
            ex_fields = {k: [(k, v)] for k, v in fields.items() if
                         k in ex_dict}
            ex = Example.fromdict(ex_dict, ex_fields)
            examples.append(ex)

        fields = []
        for _, nf_list in ex_fields.items():
            assert len(nf_list) == 1
            fields.append(nf_list[0])

        # deliberately skip Dataset.__init__ (its signature expects readers)
        # and call torchtext's Dataset.__init__ directly
        super(Dataset, self).__init__(examples, fields, filter_pred)
|
|
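# A minimal usage sketch of DynamicDataset (names are hypothetical;
# ``fields`` is assumed to come from onmt.inputters.get_fields, and the
# sort_key mirrors onmt's text sort key):
#
#     examples = [{"src": {"src": "the cat"}, "tgt": {"tgt": "cat"}}]
#     dataset = DynamicDataset(fields, examples,
#                              sort_key=lambda ex: len(ex.src[0]))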