"""Transforms for BART-style noising, based on fairseq's implementation."""
|
import math |
|
import numpy as np |
|
import torch |
|
|
|
from typing import Sequence, Callable |
|
from onmt.constants import DefaultTokens, SubwordMarker |
|
from onmt.transforms import register_transform |
|
from .transform import Transform |
|
|
|
|
|
def _subword_start_by_joiner(tokens: Sequence[str]) -> Sequence[bool]:

    """Find word starts in a subword list marked by a joiner."""
|
flag = [True] * len(tokens) |
|
for i, token in enumerate(tokens): |
|
if token.startswith(SubwordMarker.JOINER) and i != 0: |
|
flag[i] = False |
|
if token.endswith(SubwordMarker.JOINER): |
|
try: |
|
flag[i+1] = False |
|
except IndexError: |
|
            print("Sentence `{}` is not valid!".format(" ".join(tokens)))
|
raise |
|
return flag |
|
|
|
|
|
def _subword_start_by_spacer(tokens: Sequence[str]) -> Sequence[bool]:

    """Find word starts in a subword list marked by a spacer (as prefix)."""
|
flag = [x.startswith(SubwordMarker.SPACER) for x in tokens] |
|
flag[0] = True |
|
return flag |
|
|
|
|
|
def word_start_finder(ignore_subword=False, is_joiner=False) -> Callable:

    """Return a callable that finds all word starts in a token list."""
|
if not ignore_subword: |
|
if is_joiner: |
|
return _subword_start_by_joiner |
|
else: |
|
return _subword_start_by_spacer |
|
else: |
|
return lambda tokens: [True] * len(tokens) |
|
|
|
|
|
class BARTNoising(object):

    """Apply the BART noising functions to a token sequence."""
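    # Illustrative usage (a minimal sketch; the vocabulary and the
    # SentencePiece-style tokens below are made-up examples):
    #
    #     noiser = BARTNoising(vocab=['▁the', '▁cat', '▁sat', '.'],
    #                          mask_ratio=0.3, mask_length='word')
    #     noised = noiser.apply(['▁the', '▁cat', '▁sat', '▁down', '.'])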
|
|
|
def __init__(self, vocab, mask_tok=DefaultTokens.MASK, mask_ratio=0.0, |
|
insert_ratio=0.0, permute_sent_ratio=0.0, poisson_lambda=3.0, |
|
replace_length=-1, rotate_ratio=0.0, mask_length='subword', |
|
random_ratio=0.0, is_joiner=False, |
|
full_stop_token=DefaultTokens.SENT_FULL_STOPS): |
|
if vocab is None: |
|
            raise ValueError(
                "Injecting BART noise requires a valid vocabulary.")
|
self.vocab = vocab |
|
|
|
self.mask_tok = mask_tok |
|
|
|
self.mask_ratio = mask_ratio |
|
self.random_ratio = random_ratio |
|
self.insert_ratio = insert_ratio |
|
self.rotate_ratio = rotate_ratio |
|
self.permute_sent_ratio = permute_sent_ratio |
|
|
|
self.full_stop_token = full_stop_token |
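        # replace_length controls what a masked span becomes (see the
        # --replace_length help): 0 drops the span, 1 collapses it to a
        # single mask token, -1 keeps one mask per original token.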
|
|
|
|
|
|
|
|
|
if replace_length not in [-1, 0, 1]: |
|
raise ValueError(f'invalid arg: replace_length={replace_length}') |
|
self.replace_length = replace_length |
|
|
|
if mask_length not in ['subword', 'word', 'span-poisson']: |
|
raise ValueError(f'invalid arg: mask-length={mask_length}') |
|
if mask_length == 'subword' and replace_length not in [0, 1]: |
|
raise ValueError('if using subwords, use replace-length=1 or 0') |
|
|
|
if mask_length == 'subword' or is_joiner is None: |
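            # Either masking operates directly on subword units, or no
            # joiner/spacer information is available: treat every token
            # as a word start.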
|
|
|
self._is_word_start = word_start_finder(ignore_subword=True) |
|
else: |
|
self._is_word_start = word_start_finder(is_joiner=is_joiner) |
|
|
|
self.mask_span_distribution = None |
|
if mask_length == 'span-poisson': |
|
self.mask_span_distribution = self._make_poisson(poisson_lambda) |
|
self.mask_length = mask_length |
|
self.poisson_lambda = poisson_lambda |
|
|
|
@staticmethod |
|
def set_random_seed(seed): |
|
"""Call this before use to ensure reproducibility.""" |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
|
|
def _make_poisson(self, poisson_lambda): |
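        # Build a truncated Poisson(lambda) pmf over span lengths
        # k = 0..127, stopping once the tail mass falls below 1e-7, and
        # wrap it in a Categorical distribution for sampling.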
|
lambda_to_the_k = 1 |
|
e_to_the_minus_lambda = math.exp(-poisson_lambda) |
|
k_factorial = 1 |
|
ps = [] |
|
for k in range(0, 128): |
|
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial) |
|
lambda_to_the_k *= poisson_lambda |
|
k_factorial *= (k + 1) |
|
if ps[-1] < 0.0000001: |
|
break |
|
ps = torch.FloatTensor(ps) |
|
return torch.distributions.Categorical(ps) |
|
|
|
    def _get_sentence_borders(self, tokens):

        """Return the exclusive end offset of each sentence in ``tokens``."""
|
        full_stops = np.array(
            [token in self.full_stop_token for token in tokens])
|
|
|
full_stops[-1] = True |
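        # A sentence ends just after a full stop that does not directly
        # follow another full stop; +2 turns that position into an
        # exclusive end offset into the token list.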
|
|
|
sentence_lens = (full_stops[1:] * ~full_stops[:-1]).nonzero()[0] + 2 |
|
return sentence_lens |
|
|
|
def permute_sentences(self, tokens, p=1.0): |
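        """Permute a fraction ``p`` of the sentences in ``tokens``."""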
|
if len(tokens) == 1: |
|
return tokens |
|
sentence_lens = self._get_sentence_borders(tokens) |
|
n_sentences = sentence_lens.size |
|
if n_sentences == 1: |
|
return tokens |
|
|
|
n_to_permute = math.ceil((n_sentences * 2 * p) / 2.0) |
|
|
|
substitutions = np.random.permutation(n_sentences)[:n_to_permute] |
|
ordering = np.arange(0, n_sentences) |
|
ordering[substitutions] = substitutions[np.random.permutation( |
|
n_to_permute)] |
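        # Only the sampled sentence slots are shuffled among themselves;
        # all other sentences keep their original positions.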
|
|
|
result = [tok for tok in tokens] |
|
index = 0 |
|
for i in ordering: |
|
sentence = tokens[(sentence_lens[i - 1] if i > 0 else 0): |
|
sentence_lens[i]] |
|
result[index:index + len(sentence)] = sentence |
|
index += len(sentence) |
|
        assert len(result) == len(tokens), "Error when permuting sentences."
|
return result |
|
|
|
def whole_word_mask(self, tokens, p=1.0): |
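        """Mask whole words (or sampled spans) covering a fraction ``p``."""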
|
is_word_start = torch.tensor(self._is_word_start(tokens)).int() |
|
n_mask = int(math.ceil(is_word_start.sum() * p)) |
|
n_insert = 0 |
|
if n_mask == 0: |
|
return tokens |
|
|
|
if self.mask_span_distribution is not None: |
|
lengths = self.mask_span_distribution.sample( |
|
sample_shape=(n_mask,)) |
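            # Keep sampling span lengths until they cover the masking budget.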
|
|
|
|
|
cum_length = torch.cumsum(lengths, 0) |
|
while cum_length[-1] < n_mask: |
|
lengths = torch.cat([ |
|
lengths, |
|
self.mask_span_distribution.sample( |
|
sample_shape=(n_mask,)) |
|
], dim=0) |
|
cum_length = torch.cumsum(lengths, 0) |
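            # Trim the sampled lengths so their total equals the masking
            # budget; n_mask becomes the number of spans.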
|
|
|
|
|
i = 0 |
|
while cum_length[i] < n_mask: |
|
i += 1 |
|
lengths[i] = n_mask - (0 if i == 0 else cum_length[i - 1]) |
|
n_mask = i + 1 |
|
lengths = lengths[:n_mask] |
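            # Spans of length zero contribute insertions instead of masks.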
|
|
|
|
|
lengths = lengths[lengths > 0] |
|
n_insert = n_mask - lengths.size(0) |
|
n_mask -= n_insert |
|
if n_mask == 0: |
|
return self.insertion_noise(tokens, n_insert / len(tokens)) |
|
|
|
assert (lengths > 0).all() |
|
else: |
|
lengths = torch.ones((n_mask,)).long() |
|
|
|
word_starts = is_word_start.nonzero(as_tuple=False) |
|
indices = word_starts[ |
|
torch.randperm(word_starts.size(0))[:n_mask] |
|
].squeeze(1) |
|
mask_random = torch.FloatTensor(n_mask).uniform_() < self.random_ratio |
|
|
|
tokens_length = len(tokens) |
|
|
|
to_keep = torch.ones(tokens_length, dtype=torch.bool) |
|
|
|
if self.replace_length == 0: |
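            # Drop the selected word-start positions entirely.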
|
to_keep[indices] = 0 |
|
else: |
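            # Keep each selected position but overwrite it with the mask
            # token, or with a random vocabulary token for a random_ratio
            # fraction of them.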
|
|
|
for i in indices.tolist(): |
|
tokens[i] = self.mask_tok |
|
random_tok_ids = torch.randint( |
|
0, len(self.vocab), size=(mask_random.sum(),)).tolist() |
|
for i, rid in zip(indices[mask_random].tolist(), random_tok_ids): |
|
tokens[i] = self.vocab[rid] |
|
|
|
if tokens_length - 1 in indices: |
|
uncompleted = (indices != tokens_length - 1) |
|
indices = indices[uncompleted] |
|
mask_random = mask_random[uncompleted] |
|
lengths = lengths[uncompleted] |
|
|
|
|
|
is_word_start[-1] = 255 |
|
|
|
if self.mask_span_distribution is not None: |
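            # Walk each span forward, extending the mask token by token
            # until the sampled span length is used up.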
|
assert len(lengths.size()) == 1 |
|
assert lengths.size() == indices.size() |
|
lengths -= 1 |
|
while indices.size(0) > 0: |
|
assert lengths.size() == indices.size() |
|
|
|
lengths -= is_word_start[indices + 1].long() |
|
uncompleted = lengths >= 0 |
|
indices = indices[uncompleted] + 1 |
|
mask_random = mask_random[uncompleted] |
|
lengths = lengths[uncompleted] |
|
if self.replace_length != -1: |
|
|
|
to_keep[indices] = 0 |
|
else: |
|
|
|
for i in indices.tolist(): |
|
tokens[i] = self.mask_tok |
|
random_tok_ids = torch.randint( |
|
0, len(self.vocab), size=(mask_random.sum(),)).tolist() |
|
for i, rid in zip( |
|
indices[mask_random].tolist(), random_tok_ids): |
|
tokens[i] = self.vocab[rid] |
|
else: |
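            # All spans have length one: only extend masks over the
            # trailing subwords of each selected word.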
|
|
|
while indices.size(0) > 0: |
|
|
|
uncompleted = is_word_start[indices + 1] == 0 |
|
indices = indices[uncompleted] + 1 |
|
mask_random = mask_random[uncompleted] |
|
if self.replace_length != -1: |
|
|
|
to_keep[indices] = 0 |
|
else: |
|
|
|
for i in indices.tolist(): |
|
tokens[i] = self.mask_tok |
|
random_tok_ids = torch.randint( |
|
0, len(self.vocab), size=(mask_random.sum(),)).tolist() |
|
for i, rid in zip( |
|
indices[mask_random].tolist(), random_tok_ids): |
|
tokens[i] = self.vocab[rid] |
|
|
|
|
|
|
|
        tokens = [tok for tok, keep in zip(tokens, to_keep.tolist())
                  if keep]
|
|
|
if n_insert > 0: |
|
tokens = self.insertion_noise(tokens, n_insert / len(tokens)) |
|
|
|
return tokens |
|
|
|
def insertion_noise(self, tokens, p=1.0): |
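        """Insert mask or random tokens at about ``p * len(tokens)`` slots."""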
|
n_tokens = len(tokens) |
|
n_insert = math.ceil(n_tokens * p) |
|
if n_insert == 0: |
|
return tokens |
|
n_random = math.ceil(n_insert * self.random_ratio) |
|
|
|
noise_indices = np.random.permutation(n_tokens + n_insert)[:n_insert] |
|
noise_mask = np.zeros(shape=(n_tokens + n_insert,), dtype=bool) |
|
noise_mask[noise_indices] = 1 |
|
|
|
result = np.empty(shape=(n_tokens + n_insert,), dtype=object) |
|
result[noise_indices[n_random:]] = self.mask_tok |
|
if n_random > 0: |
|
result[noise_indices[:n_random]] = np.random.choice( |
|
self.vocab, size=n_random) |
|
result[~noise_mask] = tokens |
|
|
|
assert all([item is not None for item in result]),\ |
|
"Error when inserting noise." |
|
return result.tolist() |
|
|
|
def rolling_noise(self, tokens, p=1.0): |
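        """With probability ``p``, rotate ``tokens`` around a random offset."""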
|
if np.random.random() >= p: |
|
return tokens |
|
offset = np.random.randint(0, max(1, len(tokens) - 1) + 1) |
|
return tokens[offset:] + tokens[0:offset] |
|
|
|
def apply(self, tokens): |
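        """Apply sentence permutation, masking, insertion, and rotation."""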
|
if self.permute_sent_ratio > 0.0: |
|
tokens = self.permute_sentences(tokens, self.permute_sent_ratio) |
|
|
|
if self.mask_ratio > 0.0: |
|
tokens = self.whole_word_mask(tokens, self.mask_ratio) |
|
|
|
if self.insert_ratio > 0.0: |
|
tokens = self.insertion_noise(tokens, self.insert_ratio) |
|
|
|
if self.rotate_ratio > 0.0: |
|
tokens = self.rolling_noise(tokens, self.rotate_ratio) |
|
return tokens |
|
|
|
def __repr__(self): |
|
cls_name = type(self).__name__ |
|
kwargs = {} |
|
if self.permute_sent_ratio > 0.0: |
|
kwargs['permute_sent_ratio'] = self.permute_sent_ratio |
|
kwargs['full_stop_token'] = self.full_stop_token |
|
if self.insert_ratio > 0.0: |
|
kwargs['insert_ratio'] = self.insert_ratio |
|
if self.rotate_ratio > 0.0: |
|
kwargs['rotate_ratio'] = self.rotate_ratio |
|
if self.random_ratio > 0.0: |
|
kwargs['random_ratio'] = self.random_ratio |
|
if self.mask_ratio > 0.0: |
|
kwargs['mask_ratio'] = self.mask_ratio |
|
kwargs['mask_length'] = self.mask_length |
|
kwargs['poisson_lambda'] = self.poisson_lambda |
|
kwargs['replace_length'] = self.replace_length |
|
cls_args = ', '.join( |
|
[f'{kw}={arg}' for kw, arg in kwargs.items()]) |
|
return '{}({})'.format(cls_name, cls_args) |
|
|
|
|
|
@register_transform(name='bart') |
|
class BARTNoiseTransform(Transform): |
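    """Transform that injects BART-style noise into source examples.

    Illustrative config snippet (a sketch only; option names come from
    ``add_options`` below, the surrounding config layout may differ)::

        transforms: [bart]
        mask_ratio: 0.3
        mask_length: span-poisson
        poisson_lambda: 3.0
        replace_length: 1
    """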
|
def __init__(self, opts): |
|
super().__init__(opts) |
|
|
|
    def _set_seed(self, seed):

        """Set random seed to ensure reproducibility."""
|
BARTNoising.set_random_seed(seed) |
|
|
|
@classmethod |
|
    def add_options(cls, parser):

        """Available options relating to BART noise."""
|
group = parser.add_argument_group("Transform/BART") |
|
group.add("--permute_sent_ratio", "-permute_sent_ratio", |
|
type=float, default=0.0, |
|
help="Permute this proportion of sentences " |
|
"(boundaries defined by {}) in all inputs.".format( |
|
DefaultTokens.SENT_FULL_STOPS)) |
|
group.add("--rotate_ratio", "-rotate_ratio", type=float, default=0.0, |
|
help="Rotate this proportion of inputs.") |
|
group.add("--insert_ratio", "-insert_ratio", type=float, default=0.0, |
|
help="Insert this percentage of additional random tokens.") |
|
group.add("--random_ratio", "-random_ratio", type=float, default=0.0, |
|
help="Instead of using {}, use random token " |
|
"this often.".format(DefaultTokens.MASK)) |
|
|
|
group.add("--mask_ratio", "-mask_ratio", type=float, default=0.0, |
|
help="Fraction of words/subwords that will be masked.") |
|
group.add("--mask_length", "-mask_length", type=str, default="subword", |
|
choices=["subword", "word", "span-poisson"], |
|
help="Length of masking window to apply.") |
|
group.add("--poisson_lambda", "-poisson_lambda", |
|
type=float, default=3.0, |
|
help="Lambda for Poisson distribution to sample span length " |
|
"if `-mask_length` set to span-poisson.") |
|
group.add("--replace_length", "-replace_length", |
|
type=int, default=-1, choices=[-1, 0, 1], |
|
help="When masking N tokens, replace with 0, 1, " |
|
"or N tokens. (use -1 for N)") |
|
|
|
@classmethod |
|
    def require_vocab(cls):

        """Override: this transform requires a vocabulary to warm up."""
|
return True |
|
|
|
def warm_up(self, vocabs): |
|
super().warm_up(vocabs) |
|
|
|
subword_type = self.opts.src_subword_type |
|
if self.opts.mask_length == 'subword': |
|
if subword_type == 'none': |
|
raise ValueError( |
|
f'src_subword_type={subword_type} incompatible with ' |
|
f'mask_length={self.opts.mask_length}!') |
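        # BPE-style subwords are marked with a joiner, while
        # SentencePiece-style subwords mark word starts with a spacer.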
|
is_joiner = (subword_type == 'bpe') if subword_type != 'none' else None |
|
self.bart_noise = BARTNoising( |
|
self.vocabs['src'].itos, |
|
mask_tok=DefaultTokens.MASK, |
|
mask_ratio=self.opts.mask_ratio, |
|
insert_ratio=self.opts.insert_ratio, |
|
permute_sent_ratio=self.opts.permute_sent_ratio, |
|
poisson_lambda=self.opts.poisson_lambda, |
|
replace_length=self.opts.replace_length, |
|
rotate_ratio=self.opts.rotate_ratio, |
|
mask_length=self.opts.mask_length, |
|
random_ratio=self.opts.random_ratio, |
|
is_joiner=is_joiner |
|
) |
|
|
|
def apply(self, example, is_train=False, stats=None, **kwargs): |
|
"""Apply BART noise to src side tokens.""" |
|
if is_train: |
|
src = self.bart_noise.apply(example['src']) |
|
example['src'] = src |
|
return example |
|
|
|
    def _repr_args(self):

        """Return a string representing key arguments for BART."""
|
return repr(self.bart_noise) |
|
|