File size: 3,432 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
import re
from collections import defaultdict


@register_transform(name='filterfeats')
class FilterFeatsTransform(Transform):
    """Drop examples whose source length disagrees with a feature length.

    An example that carries no ``'src_feats'`` entry is passed through
    untouched.
    """

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """This transform registers no command-line options."""

    def _parse_opts(self):
        """This transform reads nothing from the parsed options."""

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Return ``example``, or ``None`` on a source/feature mismatch.

        Every value list in ``example['src_feats']`` is compared against
        ``example['src']``; the first one whose length differs triggers a
        warning naming that feature and the example is rejected.
        """
        if 'src_feats' in example:
            for name, values in example['src_feats'].items():
                if len(example['src']) == len(values):
                    continue
                # First offending feature: report it and reject the example.
                logger.warning(
                    f"Skipping example due to mismatch "
                    f"between source and feature {name}")
                return None

        # Either there are no features at all or every feature lined up.
        return example

    def _repr_args(self):
        """Return the (empty) argument string for this transform."""
        return ''


@register_transform(name='inferfeats')
class InferFeatsTransform(Transform):
    """Infer features for subword tokenization.

    Rewrites every list in ``example["src_feats"]`` so that it holds one
    value per entry of ``example["src"]``, using a subword-to-index mapping
    obtained from ``subword_map_by_joiner`` or ``subword_map_by_spacer``.
    """

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Available options related to this Transform."""
        group = parser.add_argument_group("Transform/InferFeats")
        group.add("--reversible_tokenization", "-reversible_tokenization",
                  default="joiner", choices=["joiner", "spacer"],
                  help="Type of reversible tokenization "
                       "applied on the tokenizer.")
        group.add("--prior_tokenization", "-prior_tokenization",
                  default=False, action="store_true",
                  help="Whether the input has already been tokenized.")

    def _parse_opts(self):
        """Cache the options this transform reads on the instance."""
        super()._parse_opts()
        self.reversible_tokenization = self.opts.reversible_tokenization
        self.prior_tokenization = self.opts.prior_tokenization

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Expand the features of ``example`` to subword granularity.

        Args:
            example: mapping with a ``"src"`` sequence of subwords and,
                optionally, a ``"src_feats"`` mapping of feature name to a
                list of values. ``"src_original"`` is also read when the
                joiner mode is combined with ``prior_tokenization``.
            is_train: unused here.
            stats: unused here.

        Returns:
            The same ``example`` object. When ``"src_feats"`` is present,
            each feature that received at least one inferred value is
            replaced in place by its inferred list; otherwise the example
            is returned unchanged.
        """
        if "src_feats" not in example:
            # Do nothing
            return example

        # NOTE(review): both helpers are assumed to return, for each subword
        # of example["src"], the index of the word it came from - confirm
        # against onmt.utils.alignment.
        if self.reversible_tokenization == "joiner":
            original_src = example["src_original"] \
                if self.prior_tokenization else None
            word_to_subword_mapping = subword_map_by_joiner(
                example["src"], original_subwords=original_src)
        else:  # Spacer
            word_to_subword_mapping = subword_map_by_spacer(example["src"])

        src_feats = example["src_feats"]
        inferred_feats = defaultdict(list)
        for subword, word_id in zip(example["src"], word_to_subword_mapping):
            # Whether the subword is punctuation only does not depend on the
            # feature, so decide it once per subword instead of once per
            # (subword, feature) pair. With prior tokenization the check is
            # never applied, so the regex is skipped altogether.
            punctuation_only = (
                not self.prior_tokenization
                and not re.sub(r'(\W)+', '', subword).strip())

            for feat_name, feat_values in src_feats.items():
                # Punctuation-only subwords get a placeholder; everything
                # else takes the value at the mapped index.
                inferred_feats[feat_name].append(
                    "<null>" if punctuation_only else feat_values[word_id])

        # Only features that actually received values are overwritten, so an
        # empty source leaves the existing feature lists as they were.
        src_feats.update(inferred_feats)

        return example

    def _repr_args(self):
        """Return the (empty) argument string for this transform."""
        return ''