# -*- coding: utf-8 -*- from functools import partial from itertools import repeat import torch from torchtext.data import Field, RawField from onmt.constants import DefaultTokens from onmt.inputters.datareader_base import DataReaderBase from onmt.utils.misc import split_corpus class TextDataReader(DataReaderBase): def read(self, sequences, side, features={}): """Read text data from disk. Args: sequences (str or Iterable[str]): path to text file or iterable of the actual text data. side (str): Prefix used in return dict. Usually ``"src"`` or ``"tgt"``. features: (Dict[str or Iterable[str]]): dictionary mapping feature names with the path to feature file or iterable of the actual feature data. Yields: dictionaries whose keys are the names of fields and whose values are more or less the result of tokenizing with those fields. """ if isinstance(sequences, str): sequences = DataReaderBase._read_file(sequences) features_names = [] features_values = [] for feat_name, v in features.items(): features_names.append(feat_name) if isinstance(v, str): features_values.append(DataReaderBase._read_file(features)) else: features_values.append(v) for i, (seq, *feats) in enumerate(zip(sequences, *features_values)): ex_dict = {} if isinstance(seq, bytes): seq = seq.decode("utf-8") ex_dict[side] = seq for j, f in enumerate(feats): if isinstance(f, bytes): f = f.decode("utf-8") ex_dict[features_names[j]] = f yield {side: ex_dict, "indices": i} class InferenceDataReader(object): """It handles inference data reading from disk in shards. Args: src (str): path to the source file tgt (str or NoneType): path to the target file src_feats (Dict[str]): paths to the features files shard_size (int): divides files into smaller files of size shard_size Returns: Tuple[List[str], List[str], Dict[List[str]]] """ def __init__(self, src, tgt, src_feats={}, shard_size=10000): self.src = src self.tgt = tgt self.src_feats = src_feats self.shard_size = shard_size def __iter__(self): src_shards = split_corpus(self.src, self.shard_size) tgt_shards = split_corpus(self.tgt, self.shard_size) if not self.src_feats: features_shards = [repeat(None)] else: features_shards = [] features_names = [] for feat_name, feat_path in self.src_feats.items(): features_shards.append( split_corpus(feat_path, self.shard_size)) features_names.append(feat_name) shard_pairs = zip(src_shards, tgt_shards, *features_shards) for i, shard in enumerate(shard_pairs): src_shard, tgt_shard, *features_shard = shard if features_shard[0] is not None: features_shard_ = dict() for j, x in enumerate(features_shard): features_shard_[features_names[j]] = x else: features_shard_ = None yield src_shard, tgt_shard, features_shard_ class InferenceDataIterator(object): def __init__(self, src, tgt, src_feats, transform): self.src = src self.tgt = tgt self.src_feats = src_feats self.transform = transform def _tokenize(self, example): example['src'] = example['src'].decode("utf-8").strip('\n').split() example['tgt'] = example['tgt'].decode("utf-8").strip('\n').split() \ if example["tgt"] is not None else None example['src_original'] = example['src'] example['tgt_original'] = example['tgt'] if 'src_feats' in example: for k in example['src_feats'].keys(): example['src_feats'][k] = example['src_feats'][k] \ .decode("utf-8").strip('\n').split() \ if example['src_feats'][k] is not None else None return example def _transform(self, example, remove_tgt=False): maybe_example = self.transform.apply( example, is_train=False, corpus_name="translate") assert maybe_example is not None, \ "Transformation on example skipped the example. " \ "Please check the transforms." return maybe_example def _process(self, example, remove_tgt=False): example['src'] = {"src": ' '.join(example['src'])} example['tgt'] = {"tgt": ' '.join(example['tgt'])} # Make features part of src as in TextMultiField # {'src': {'src': ..., 'feat1': ...., 'feat2': ....}} if 'src_feats' in example: for feat_name, feat_value in example['src_feats'].items(): example['src'][feat_name] = ' '.join(feat_value) del example["src_feats"] # Cleanup if remove_tgt: del example["tgt"] del example["tgt_original"] del example["src_original"] return example def __iter__(self): tgt = self.tgt if self.tgt is not None else repeat(None) if self.src_feats is not None: features_names = [] features_values = [] for feat_name, values in self.src_feats.items(): features_names.append(feat_name) features_values.append(values) else: features_values = [repeat(None)] for i, (src, tgt, *src_feats) in enumerate(zip( self.src, tgt, *features_values)): ex = { "src": src, "tgt": tgt if tgt is not None else b"" } if src_feats[0] is not None: src_feats_ = {} for j, x in enumerate(src_feats): src_feats_[features_names[j]] = x ex["src_feats"] = src_feats_ ex = self._tokenize(ex) ex = self._transform(ex) ex = self._process(ex, remove_tgt=self.tgt is None) ex["indices"] = i yield ex def text_sort_key(ex): """Sort using the number of tokens in the sequence.""" if hasattr(ex, "tgt"): return len(ex.src[0]), len(ex.tgt[0]) return len(ex.src[0]) # Legacy function. Currently it only truncates input if truncate is set. # mix this with partial def _feature_tokenize( string, layer=0, tok_delim=None, feat_delim=None, truncate=None): """Split apart word features (like POS/NER tags) from the tokens. Args: string (str): A string with ``tok_delim`` joining tokens and features joined by ``feat_delim``. For example, ``"hello|NOUN|'' Earth|NOUN|PLANET"``. layer (int): Which feature to extract. (Not used if there are no features, indicated by ``feat_delim is None``). In the example above, layer 2 is ``'' PLANET``. truncate (int or NoneType): Restrict sequences to this length of tokens. Returns: List[str] of tokens. """ tokens = string.split(tok_delim) if truncate is not None: tokens = tokens[:truncate] if feat_delim is not None: tokens = [t.split(feat_delim)[layer] for t in tokens] return tokens class TextMultiField(RawField): """Container for subfields. Text data might use POS/NER/etc labels in addition to tokens. This class associates the "base" :class:`Field` with any subfields. It also handles padding the data and stacking it. Args: base_name (str): Name for the base field. base_field (Field): The token field. feats_fields (Iterable[Tuple[str, Field]]): A list of name-field pairs. Attributes: fields (Iterable[Tuple[str, Field]]): A list of name-field pairs. The order is defined as the base field first, then ``feats_fields`` in alphabetical order. """ def __init__(self, base_name, base_field, feats_fields): super(TextMultiField, self).__init__() self.fields = [(base_name, base_field)] for name, ff in sorted(feats_fields, key=lambda kv: kv[0]): self.fields.append((name, ff)) @property def base_field(self): return self.fields[0][1] def process(self, batch, device=None): """Convert outputs of preprocess into Tensors. Args: batch (List[List[List[str]]]): A list of length batch size. Each element is a list of the preprocess results for each field (which are lists of str "words" or feature tags. device (torch.device or str): The device on which the tensor(s) are built. Returns: torch.LongTensor or Tuple[LongTensor, LongTensor]: A tensor of shape ``(seq_len, batch_size, len(self.fields))`` where the field features are ordered like ``self.fields``. If the base field returns lengths, these are also returned and have shape ``(batch_size,)``. """ # batch (list(list(list))): batch_size x len(self.fields) x seq_len batch_by_feat = list(zip(*batch)) base_data = self.base_field.process(batch_by_feat[0], device=device) if self.base_field.include_lengths: # lengths: batch_size base_data, lengths = base_data feats = [ff.process(batch_by_feat[i], device=device) for i, (_, ff) in enumerate(self.fields[1:], 1)] levels = [base_data] + feats # data: seq_len x batch_size x len(self.fields) data = torch.stack(levels, 2) if self.base_field.include_lengths: return data, lengths else: return data def preprocess(self, x): """Preprocess data. Args: x (str): A sentence string (words joined by whitespace). Returns: List[List[str]]: A list of length ``len(self.fields)`` containing lists of tokens/feature tags for the sentence. The output is ordered like ``self.fields``. """ return [f.preprocess(x[fn]) for fn, f in self.fields] def __getitem__(self, item): return self.fields[item] def text_fields(**kwargs): """Create text fields. Args: base_name (str): Name associated with the field. feats (Optional[Dict]): Word level feats include_lengths (bool): Optionally return the sequence lengths. pad (str, optional): Defaults to ``""``. bos (str or NoneType, optional): Defaults to ``""``. eos (str or NoneType, optional): Defaults to ``""``. truncate (bool or NoneType, optional): Defaults to ``None``. Returns: TextMultiField """ feats = kwargs["feats"] include_lengths = kwargs["include_lengths"] base_name = kwargs["base_name"] pad = kwargs.get("pad", DefaultTokens.PAD) bos = kwargs.get("bos", DefaultTokens.BOS) eos = kwargs.get("eos", DefaultTokens.EOS) truncate = kwargs.get("truncate", None) fields_ = [] feat_delim = None # u"│" if n_feats > 0 else None # Base field tokenize = partial( _feature_tokenize, layer=None, truncate=truncate, feat_delim=feat_delim) feat = Field( init_token=bos, eos_token=eos, pad_token=pad, tokenize=tokenize, include_lengths=include_lengths) fields_.append((base_name, feat)) # Feats fields if feats: for feat_name in feats.keys(): # Legacy function, it is not really necessary tokenize = partial( _feature_tokenize, layer=None, truncate=truncate, feat_delim=feat_delim) feat = Field( init_token=bos, eos_token=eos, pad_token=pad, tokenize=tokenize, include_lengths=False) fields_.append((feat_name, feat)) assert fields_[0][0] == base_name # sanity check field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:]) return field