sakharamg's picture
Uploading all files
158b61b
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
import re
from collections import defaultdict
@register_transform(name='filterfeats')
class FilterFeatsTransform(Transform):
"""Filter out examples with a mismatch between source and features."""
def __init__(self, opts):
super().__init__(opts)
@classmethod
def add_options(cls, parser):
pass
def _parse_opts(self):
pass
def apply(self, example, is_train=False, stats=None, **kwargs):
"""Return None if mismatch"""
if 'src_feats' not in example:
# Do nothing
return example
for feat_name, feat_values in example['src_feats'].items():
if len(example['src']) != len(feat_values):
logger.warning(
f"Skipping example due to mismatch "
f"between source and feature {feat_name}")
return None
return example
def _repr_args(self):
return ''
@register_transform(name='inferfeats')
class InferFeatsTransform(Transform):
"""Infer features for subword tokenization."""
def __init__(self, opts):
super().__init__(opts)
@classmethod
def add_options(cls, parser):
"""Avalilable options related to this Transform."""
group = parser.add_argument_group("Transform/InferFeats")
group.add("--reversible_tokenization", "-reversible_tokenization",
default="joiner", choices=["joiner", "spacer"],
help="Type of reversible tokenization "
"applied on the tokenizer.")
group.add("--prior_tokenization", "-prior_tokenization",
default=False, action="store_true",
help="Whether the input has already been tokenized.")
def _parse_opts(self):
super()._parse_opts()
self.reversible_tokenization = self.opts.reversible_tokenization
self.prior_tokenization = self.opts.prior_tokenization
def apply(self, example, is_train=False, stats=None, **kwargs):
if "src_feats" not in example:
# Do nothing
return example
if self.reversible_tokenization == "joiner":
original_src = example["src_original"] \
if self.prior_tokenization else None
word_to_subword_mapping = subword_map_by_joiner(
example["src"], original_subwords=original_src)
else: # Spacer
word_to_subword_mapping = subword_map_by_spacer(example["src"])
inferred_feats = defaultdict(list)
for subword, word_id in zip(example["src"], word_to_subword_mapping):
for feat_name, feat_values in example["src_feats"].items():
# Punctuation only
if not re.sub(r'(\W)+', '', subword).strip() \
and not self.prior_tokenization:
inferred_feat = "<null>"
else:
inferred_feat = feat_values[word_id]
inferred_feats[feat_name].append(inferred_feat)
for feat_name, feat_values in inferred_feats.items():
example["src_feats"][feat_name] = inferred_feats[feat_name]
return example
def _repr_args(self):
return ''