File size: 4,253 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform, ObservableStats
class FilterTooLongStats(ObservableStats):
    """Running statistics for FilterTooLongTransform.

    A newly created instance accounts for a single filtered example
    (its count starts at 1); ``update`` folds another instance's count
    into this one.
    """

    __slots__ = ["filtered"]

    def __init__(self):
        # One freshly built stats object == one example that was dropped.
        self.filtered = 1

    def update(self, other: "FilterTooLongStats"):
        """Add the filtered count carried by ``other`` to this object."""
        self.filtered = self.filtered + other.filtered
@register_transform(name='filtertoolong')
class FilterTooLongTransform(Transform):
    """Filter out examples whose source or target sentence is too long."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Available options related to this Transform."""
        group = parser.add_argument_group("Transform/Filter")
        group.add("--src_seq_length", "-src_seq_length", type=int, default=200,
                  help="Maximum source sequence length.")
        group.add("--tgt_seq_length", "-tgt_seq_length", type=int, default=200,
                  help="Maximum target sequence length.")

    def _parse_opts(self):
        self.src_seq_length = self.opts.src_seq_length
        self.tgt_seq_length = self.opts.tgt_seq_length

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Return None if too long else return as is.

        Args:
            example: mapping with a ``'src'`` token sequence and, optionally,
                a ``'tgt'`` token sequence (assumed dict-like; it is read
                with ``.get``).
            is_train: not used by this transform.
            stats: optional statistics collector; when given, one
                ``FilterTooLongStats`` is reported per filtered example.

        Returns:
            ``None`` if ``src`` exceeds ``src_seq_length`` or a present
            ``tgt`` exceeds ``tgt_seq_length``; otherwise ``example``
            unchanged.
        """
        tgt = example.get('tgt')
        too_long = (
            len(example['src']) > self.src_seq_length
            # The target may be missing or None; only enforce its limit
            # when it is actually present instead of crashing on len(None).
            or (tgt is not None and len(tgt) > self.tgt_seq_length)
        )
        if too_long:
            if stats is not None:
                stats.update(FilterTooLongStats())
            return None
        return example

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return '{}={}, {}={}'.format(
            'src_seq_length', self.src_seq_length,
            'tgt_seq_length', self.tgt_seq_length
        )
@register_transform(name='prefix')
class PrefixTransform(Transform):
    """Add a prefix to the src (and tgt) sentence."""

    def __init__(self, opts):
        super().__init__(opts)

    @staticmethod
    def _get_prefix(corpus):
        """Get the prefix strings of a `corpus`.

        Returns:
            ``{'src': str, 'tgt': str}`` when ``'prefix'`` is listed in the
            corpus' ``transforms``, otherwise ``None``. A side whose prefix
            is missing or empty (``None``) defaults to ``''``, so a corpus
            may configure a prefix for only one side.
        """
        if 'prefix' not in corpus['transforms']:
            return None
        return {
            'src': corpus.get('src_prefix') or '',
            'tgt': corpus.get('tgt_prefix') or '',
        }

    @classmethod
    def get_prefix_dict(cls, opts):
        """Get all prefixes needed by the corpora in `opts`.

        Returns:
            dict mapping corpus name -> prefix dict, containing only the
            corpora that enable the prefix transform.
        """
        prefix_dict = {}
        for c_name, corpus in opts.data.items():
            prefix = cls._get_prefix(corpus)
            if prefix is not None:
                logger.info(f"Get prefix for {c_name}: {prefix}")
                prefix_dict[c_name] = prefix
        return prefix_dict

    @classmethod
    def get_specials(cls, opts):
        """Get special vocabs added by prefix transform.

        Returns:
            (src_specials, tgt_specials): sets of whitespace-separated
            prefix tokens collected over all corpora.
        """
        prefix_dict = cls.get_prefix_dict(opts)
        src_specials, tgt_specials = set(), set()
        # Corpus names are irrelevant here; only the prefix strings matter.
        for prefix in prefix_dict.values():
            src_specials.update(prefix['src'].split())
            tgt_specials.update(prefix['tgt'].split())
        return (src_specials, tgt_specials)

    def warm_up(self, vocabs=None):
        """Warm up to get prefix dictionary."""
        super().warm_up(None)
        self.prefix_dict = self.get_prefix_dict(self.opts)

    def _prepend(self, example, prefix):
        """Prepend each side's `prefix` tokens to that side of `example`.

        Sides that are missing from `example` or set to ``None`` are left
        untouched rather than raising.
        """
        for side, side_prefix in prefix.items():
            tokens = example.get(side)
            if tokens is None:
                continue
            example[side] = side_prefix.split() + tokens
        return example

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Apply prefix prepend to example.

        Should provide `corpus_name` to get the corresponding prefix.

        Raises:
            ValueError: if `corpus_name` is not given, or no prefix was
                registered for it during `warm_up`.
        """
        corpus_name = kwargs.get('corpus_name', None)
        if corpus_name is None:
            raise ValueError('corpus_name is required.')
        corpus_prefix = self.prefix_dict.get(corpus_name, None)
        if corpus_prefix is None:
            raise ValueError(f'prefix for {corpus_name} does not exist.')
        return self._prepend(example, corpus_prefix)

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return '{}={}'.format('prefix_dict', self.prefix_dict)
|