# coding: utf-8
from itertools import chain, starmap
from collections import Counter

import torch
from torchtext.data import Dataset as TorchtextDataset
from torchtext.data import Example
from torchtext.vocab import Vocab


def _join_dicts(*args):
    """
    Args:
        *args: dictionaries with disjoint keys.

    Returns:
        a single dictionary that has the union of these keys.
    """
    return dict(chain(*[d.items() for d in args]))


def _dynamic_dict(example, src_field, tgt_field):
    """Create copy-vocab and numericalize with it.

    In-place adds ``"src_map"`` to ``example``. That is the copy-vocab
    numericalization of the tokenized ``example["src"]``. If ``example``
    has a ``"tgt"`` key, adds ``"alignment"`` to example. That is the
    copy-vocab numericalization of the tokenized ``example["tgt"]``. The
    alignment has an initial and final UNK token to match the BOS and EOS
    tokens.

    Args:
        example (dict): An example dictionary with a ``"src"`` key and
            maybe a ``"tgt"`` key. (This argument changes in place!)
        src_field (torchtext.data.Field): Field object.
        tgt_field (torchtext.data.Field): Field object.

    Returns:
        ``example``, changed as described.
    """
    src = src_field.tokenize(example["src"]["src"])
    # make a small vocab containing just the tokens in the source sequence
    unk = src_field.unk_token
    pad = src_field.pad_token
    # add init_token and eos_token according to src construction
    if src_field.init_token:
        src = [src_field.init_token] + src
    if src_field.eos_token:
        src.append(src_field.eos_token)
    src_ex_vocab = Vocab(Counter(src), specials=[unk, pad])
    unk_idx = src_ex_vocab.stoi[unk]
    # Map source tokens to indices in the dynamic dict.
    src_map = torch.LongTensor([src_ex_vocab.stoi[w] for w in src])
    example["src_map"] = src_map
    example["src_ex_vocab"] = src_ex_vocab

    if "tgt" in example:
        tgt = tgt_field.tokenize(example["tgt"]["tgt"])
        mask = torch.LongTensor(
            [unk_idx] + [src_ex_vocab.stoi[w] for w in tgt] + [unk_idx])
        example["alignment"] = mask
    return example
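
# A minimal sketch of the copy-vocab behavior above, assuming legacy
# torchtext (<0.9) Fields with the default "<unk>"/"<pad>" specials and no
# BOS/EOS tokens (``demo_field`` and the sentences are illustrative only):
#
#     from torchtext.data import Field
#     demo_field = Field(tokenize=str.split)
#     ex = _dynamic_dict({"src": {"src": "the cat sat the"},
#                         "tgt": {"tgt": "a cat sat"}},
#                        demo_field, demo_field)
#     # copy-vocab stoi: {"<unk>": 0, "<pad>": 1, "the": 2, "cat": 3, "sat": 4}
#     # ex["src_map"]   -> tensor([2, 3, 4, 2])
#     # ex["alignment"] -> tensor([0, 0, 3, 4, 0]); "a" never occurs in the
#     # source, so it maps to UNK (0), as do the BOS/EOS placeholder slots.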
class Dataset(TorchtextDataset):
    """Contain data and process it.

    A dataset is an object that accepts sequences of raw data (sentence
    pairs in the case of machine translation) and fields which describe
    how this raw data should be processed to produce tensors. When a
    dataset is instantiated, it applies the fields' preprocessing pipeline
    (but not the bit that numericalizes it or turns it into batch tensors)
    to the raw data, producing a list of :class:`torchtext.data.Example`
    objects. torchtext's iterators then know how to use these examples to
    make batches.

    Args:
        fields (dict[str, Field]): a dict with the structure returned by
            :func:`onmt.inputters.get_fields()`. Keys are usually the
            dataset sides, ``"src"`` and ``"tgt"``, matching the keys of
            items yielded by the ``readers``; values are the corresponding
            Field objects. An attribute with each key's name will be
            created for each :class:`torchtext.data.Example` object and
            its value will be the result of applying the Field to the data
            that matches the key. The advantage of having a field for each
            piece of raw input is that it allows the dataset to store
            multiple "views" of each input, which allows for easy
            implementation of token-level features, mixed word- and
            character-level models, and so on. (See also
            :class:`onmt.inputters.TextMultiField`.)
        readers (Iterable[onmt.inputters.DataReaderBase]): Reader objects
            for disk-to-dict. The yielded dicts are then processed
            according to ``fields``.
        data (Iterable[Tuple[str, Any, Dict]]): (name, ``data_arg``,
            ``features``) triples, where ``data_arg`` and ``features`` are
            passed to the ``read()`` method of the reader in ``readers``
            at that position. (See the reader object for details on the
            ``Any`` type.)
        sort_key (Callable[[torchtext.data.Example], Any]): A function
            for determining the value on which data is sorted (i.e.
            length).
        filter_pred (Callable[[torchtext.data.Example], bool]): A
            function that accepts Example objects and returns a boolean
            value indicating whether to include that example in the
            dataset.

    Attributes:
        src_vocabs (List[torchtext.vocab.Vocab]): Used with dynamic dict/
            copy attention. There is a very short vocab for each src
            example. It contains just the source words, e.g. so that the
            generator can predict to copy them.
    """

    def __init__(self, fields, readers, data, sort_key, filter_pred=None):
        self.sort_key = sort_key
        can_copy = 'src_map' in fields and 'alignment' in fields

        read_iters = [r.read(dat, name, feats)
                      for r, (name, dat, feats) in zip(readers, data)]

        # self.src_vocabs is used in collapse_copy_scores and Translator.py
        self.src_vocabs = []
        examples = []
        for ex_dict in starmap(_join_dicts, zip(*read_iters)):
            if can_copy:
                src_field = fields['src']
                tgt_field = fields['tgt']
                # this assumes src_field and tgt_field are both text
                ex_dict = _dynamic_dict(
                    ex_dict, src_field.base_field, tgt_field.base_field)
                self.src_vocabs.append(ex_dict["src_ex_vocab"])
            ex_fields = {k: [(k, v)] for k, v in fields.items()
                         if k in ex_dict}
            ex = Example.fromdict(ex_dict, ex_fields)
            examples.append(ex)

        # fields needs to have only keys that examples have as attrs
        fields = []
        for _, nf_list in ex_fields.items():
            assert len(nf_list) == 1
            fields.append(nf_list[0])

        super(Dataset, self).__init__(examples, fields, filter_pred)

    def __getattr__(self, attr):
        # avoid infinite recursion when fields isn't defined
        if 'fields' not in vars(self):
            raise AttributeError
        if attr in self.fields:
            return (getattr(x, attr) for x in self.examples)
        else:
            raise AttributeError

    def save(self, path, remove_fields=True):
        if remove_fields:
            self.fields = []
        torch.save(self, path)

    @staticmethod
    def config(fields):
        readers, data = [], []
        for name, field in fields:
            if field["data"] is not None:
                readers.append(field["reader"])
                data.append((name, field["data"],
                             field.get("features", {})))
        return readers, data
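
# A hedged usage sketch (``make_text_reader``, ``str_len_key``, and the file
# names are hypothetical stand-ins; real fields come from
# onmt.inputters.get_fields and real readers are DataReaderBase subclasses):
#
#     fields = get_fields_somehow()       # dict[str, Field], incl. src/tgt
#     readers, data = Dataset.config(
#         [("src", {"reader": make_text_reader(), "data": "train.src.txt"}),
#          ("tgt", {"reader": make_text_reader(), "data": "train.tgt.txt"})])
#     dataset = Dataset(fields, readers, data, sort_key=str_len_key)
#     dataset.save("train.0.pt")          # strips fields before pickling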
class DynamicDataset(Dataset):
    """A :class:`Dataset` over already-read example dicts.

    Behaves like :class:`Dataset`, except that ``data`` is an iterable of
    example dicts (as yielded by the readers) rather than reader
    configuration, so no disk-to-dict reading happens here.
    """

    def __init__(self, fields, data, sort_key, filter_pred=None):
        self.sort_key = sort_key
        can_copy = 'src_map' in fields and 'alignment' in fields

        # self.src_vocabs is used in collapse_copy_scores and Translator.py
        self.src_vocabs = []
        examples = []
        for ex_dict in data:
            if can_copy:
                src_field = fields['src']
                tgt_field = fields['tgt']
                # this assumes src_field and tgt_field are both text
                ex_dict = _dynamic_dict(
                    ex_dict, src_field.base_field, tgt_field.base_field)
                self.src_vocabs.append(ex_dict["src_ex_vocab"])
            ex_fields = {k: [(k, v)] for k, v in fields.items()
                         if k in ex_dict}
            ex = Example.fromdict(ex_dict, ex_fields)
            examples.append(ex)

        # fields needs to have only keys that examples have as attrs
        fields = []
        for _, nf_list in ex_fields.items():
            assert len(nf_list) == 1
            fields.append(nf_list[0])

        # skip Dataset.__init__ (the example-building above replaces it)
        # and call TorchtextDataset.__init__ directly
        super(Dataset, self).__init__(examples, fields, filter_pred)

    # __getattr__, save, and config are inherited unchanged from Dataset.
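
# The corresponding sketch for DynamicDataset: no readers are involved and
# ``data`` is the already-read example dicts themselves (dict shapes are
# illustrative; ``fields`` and ``str_len_key`` as in the sketch above):
#
#     ex_dicts = [{"src": {"src": "a b c"}, "tgt": {"tgt": "x y"}}]
#     dataset = DynamicDataset(fields, ex_dicts, sort_key=str_len_key)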