"""Transforms relate to noising from BART: based on code of fairseq."""
import math
import numpy as np
import torch
from typing import Sequence, Callable
from onmt.constants import DefaultTokens, SubwordMarker
from onmt.transforms import register_transform
from .transform import Transform
def _subword_start_by_joiner(tokens: Sequence[str]) -> Sequence[bool]:
"""Find word start in a subword list marked by joiner."""
flag = [True] * len(tokens)
for i, token in enumerate(tokens):
if token.startswith(SubwordMarker.JOINER) and i != 0:
flag[i] = False
if token.endswith(SubwordMarker.JOINER):
try:
flag[i+1] = False
except IndexError:
print("Sentence `{}` not correct!".format(" ".join(token)))
raise
return flag
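
# Example (a sketch, assuming OpenNMT's default joiner "￭"):
#   _subword_start_by_joiner(["he", "￭llo", "wor￭", "ld"])
#   -> [True, False, True, False]
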
def _subword_start_by_spacer(tokens: Sequence[str]) -> Sequence[bool]:
"""Find word start in a subword list marked by spacer(as prefix)."""
flag = [x.startswith(SubwordMarker.SPACER) for x in tokens]
flag[0] = True
return flag
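
# Example (a sketch, assuming OpenNMT's default spacer "▁"):
#   _subword_start_by_spacer(["▁hel", "lo", "▁world"])
#   -> [True, False, True]
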
def word_start_finder(ignore_subword=False, is_joiner=False) -> Callable:
"""Return callable to find all word start in the token list."""
if not ignore_subword:
if is_joiner:
return _subword_start_by_joiner
else:
return _subword_start_by_spacer
else:
return lambda tokens: [True] * len(tokens)
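
# Usage sketch: pick the finder once, then reuse it on each token list, e.g.
#   finder = word_start_finder(is_joiner=True)
#   finder(["he", "￭llo"])  # -> [True, False]
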
class BARTNoising(object):
"""Noise from BART."""
def __init__(self, vocab, mask_tok=DefaultTokens.MASK, mask_ratio=0.0,
insert_ratio=0.0, permute_sent_ratio=0.0, poisson_lambda=3.0,
replace_length=-1, rotate_ratio=0.0, mask_length='subword',
random_ratio=0.0, is_joiner=False,
full_stop_token=DefaultTokens.SENT_FULL_STOPS):
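        """Set up BART noising (a summary; semantics follow fairseq).

        ``mask_ratio`` controls word masking/span infilling,
        ``insert_ratio`` random token insertion, ``permute_sent_ratio``
        sentence permutation, and ``rotate_ratio`` document rotation;
        ``mask_length`` picks subword/word/span-poisson masking and
        ``replace_length`` how masked positions are replaced.
        """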
if vocab is None:
            raise ValueError(
                "Injecting BART noise requires a valid vocabulary.")
self.vocab = vocab
self.mask_tok = mask_tok
self.mask_ratio = mask_ratio
self.random_ratio = random_ratio
self.insert_ratio = insert_ratio
self.rotate_ratio = rotate_ratio
self.permute_sent_ratio = permute_sent_ratio
self.full_stop_token = full_stop_token
        # replace_length semantics:
        # -1: replace each masked token with a mask token (length preserved)
        #  0: delete masked tokens outright (no mask token inserted)
        #  1: replace each masked span with a single mask token
if replace_length not in [-1, 0, 1]:
raise ValueError(f'invalid arg: replace_length={replace_length}')
self.replace_length = replace_length
if mask_length not in ['subword', 'word', 'span-poisson']:
raise ValueError(f'invalid arg: mask-length={mask_length}')
if mask_length == 'subword' and replace_length not in [0, 1]:
raise ValueError('if using subwords, use replace-length=1 or 0')
if mask_length == 'subword' or is_joiner is None:
            # treat every token as a word start
            # (subword-level masking, or the input is word-level tokens)
self._is_word_start = word_start_finder(ignore_subword=True)
else:
self._is_word_start = word_start_finder(is_joiner=is_joiner)
self.mask_span_distribution = None
if mask_length == 'span-poisson':
self.mask_span_distribution = self._make_poisson(poisson_lambda)
self.mask_length = mask_length
self.poisson_lambda = poisson_lambda
@staticmethod
def set_random_seed(seed):
"""Call this before use to ensure reproducibility."""
np.random.seed(seed)
torch.manual_seed(seed)
def _make_poisson(self, poisson_lambda):
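        # Build a truncated Poisson over span lengths: accumulate the PMF
        # P(k) = exp(-lambda) * lambda^k / k! incrementally, stop once the
        # tail mass is negligible, and let Categorical renormalize.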
lambda_to_the_k = 1
e_to_the_minus_lambda = math.exp(-poisson_lambda)
k_factorial = 1
ps = []
for k in range(0, 128):
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
lambda_to_the_k *= poisson_lambda
k_factorial *= (k + 1)
            if ps[-1] < 1e-7:
break
ps = torch.FloatTensor(ps)
return torch.distributions.Categorical(ps)
def _get_sentence_borders(self, tokens):
"""Return lengths of each sentence in the token sequence."""
        full_stops = np.array(
            [token in self.full_stop_token for token in tokens]
        )
# Pretend it ends with a full stop so last span is a sentence
full_stops[-1] = True
# Tokens that are full stops, where the previous token is not
sentence_lens = (full_stops[1:] * ~full_stops[:-1]).nonzero()[0] + 2
return sentence_lens
def permute_sentences(self, tokens, p=1.0):
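        """Shuffle roughly a proportion ``p`` of the sentences in ``tokens``.

        E.g. with p=1.0, ["a", ".", "b", ".", "c", "."] comes back with
        its three sentences in a random order.
        """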
if len(tokens) == 1:
return tokens
sentence_lens = self._get_sentence_borders(tokens)
n_sentences = sentence_lens.size
if n_sentences == 1:
return tokens
n_to_permute = math.ceil((n_sentences * 2 * p) / 2.0)
substitutions = np.random.permutation(n_sentences)[:n_to_permute]
ordering = np.arange(0, n_sentences)
ordering[substitutions] = substitutions[np.random.permutation(
n_to_permute)]
result = [tok for tok in tokens]
index = 0
for i in ordering:
sentence = tokens[(sentence_lens[i - 1] if i > 0 else 0):
sentence_lens[i]]
result[index:index + len(sentence)] = sentence
index += len(sentence)
        assert len(result) == len(tokens), "Error when permuting sentences."
return result
def whole_word_mask(self, tokens, p=1.0): # text span mask/infilling
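        """Mask whole words (or spans) in ``tokens`` with probability ``p``.

        When ``mask_span_distribution`` is set, spans of Poisson-sampled
        length are masked (BART's text infilling); otherwise single words
        are masked. ``replace_length`` decides whether masked positions are
        deleted, collapsed to one mask per span, or each replaced by a mask.
        """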
is_word_start = torch.tensor(self._is_word_start(tokens)).int()
n_mask = int(math.ceil(is_word_start.sum() * p))
n_insert = 0
if n_mask == 0:
return tokens
if self.mask_span_distribution is not None: # Text (span) Infilling
lengths = self.mask_span_distribution.sample(
sample_shape=(n_mask,))
# Make sure we have enough to mask
cum_length = torch.cumsum(lengths, 0)
while cum_length[-1] < n_mask:
lengths = torch.cat([
lengths,
self.mask_span_distribution.sample(
sample_shape=(n_mask,))
], dim=0)
cum_length = torch.cumsum(lengths, 0)
# Trim to masking budget
i = 0
while cum_length[i] < n_mask:
i += 1
lengths[i] = n_mask - (0 if i == 0 else cum_length[i - 1])
n_mask = i + 1
lengths = lengths[:n_mask]
# Handle 0-length mask (inserts) separately
lengths = lengths[lengths > 0]
n_insert = n_mask - lengths.size(0)
n_mask -= n_insert
if n_mask == 0:
return self.insertion_noise(tokens, n_insert / len(tokens))
assert (lengths > 0).all()
else: # Token Masking
lengths = torch.ones((n_mask,)).long()
# assert is_word_start[-1] == 0
word_starts = is_word_start.nonzero(as_tuple=False)
indices = word_starts[
torch.randperm(word_starts.size(0))[:n_mask]
].squeeze(1)
mask_random = torch.FloatTensor(n_mask).uniform_() < self.random_ratio
tokens_length = len(tokens)
# assert tokens_length - 1 not in indices
to_keep = torch.ones(tokens_length, dtype=torch.bool)
if self.replace_length == 0:
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
for i in indices.tolist():
tokens[i] = self.mask_tok
random_tok_ids = torch.randint(
0, len(self.vocab), size=(mask_random.sum(),)).tolist()
for i, rid in zip(indices[mask_random].tolist(), random_tok_ids):
tokens[i] = self.vocab[rid]
if tokens_length - 1 in indices:
uncompleted = (indices != tokens_length - 1)
indices = indices[uncompleted]
mask_random = mask_random[uncompleted]
lengths = lengths[uncompleted]
# acts as a long length, so spans don't go over the end of doc
is_word_start[-1] = 255
if self.mask_span_distribution is not None:
assert len(lengths.size()) == 1
assert lengths.size() == indices.size()
lengths -= 1 # 1 for the position already masked
while indices.size(0) > 0:
assert lengths.size() == indices.size()
# next position from each word_start
lengths -= is_word_start[indices + 1].long()
uncompleted = lengths >= 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
lengths = lengths[uncompleted]
if self.replace_length != -1:
# delete token: 1 mask/remove per span
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]: 1 mask per token
for i in indices.tolist():
tokens[i] = self.mask_tok
random_tok_ids = torch.randint(
0, len(self.vocab), size=(mask_random.sum(),)).tolist()
for i, rid in zip(
indices[mask_random].tolist(), random_tok_ids):
tokens[i] = self.vocab[rid]
else:
# A bit faster when all lengths are 1
while indices.size(0) > 0:
# to cover whole token
uncompleted = is_word_start[indices + 1] == 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
for i in indices.tolist():
tokens[i] = self.mask_tok
random_tok_ids = torch.randint(
0, len(self.vocab), size=(mask_random.sum(),)).tolist()
for i, rid in zip(
indices[mask_random].tolist(), random_tok_ids):
tokens[i] = self.vocab[rid]
# assert tokens_length - 1 not in indices
        tokens = [tok for tok, keep in zip(tokens, to_keep.tolist())
                  if keep]
if n_insert > 0:
tokens = self.insertion_noise(tokens, n_insert / len(tokens))
return tokens
def insertion_noise(self, tokens, p=1.0):
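        """Insert ``ceil(len(tokens) * p)`` noise tokens at random positions.

        Insertions are ``mask_tok`` except for a ``random_ratio`` fraction
        drawn uniformly from the vocabulary.
        """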
n_tokens = len(tokens)
n_insert = math.ceil(n_tokens * p)
if n_insert == 0:
return tokens
n_random = math.ceil(n_insert * self.random_ratio)
noise_indices = np.random.permutation(n_tokens + n_insert)[:n_insert]
noise_mask = np.zeros(shape=(n_tokens + n_insert,), dtype=bool)
noise_mask[noise_indices] = 1
result = np.empty(shape=(n_tokens + n_insert,), dtype=object)
result[noise_indices[n_random:]] = self.mask_tok
if n_random > 0:
result[noise_indices[:n_random]] = np.random.choice(
self.vocab, size=n_random)
result[~noise_mask] = tokens
assert all([item is not None for item in result]),\
"Error when inserting noise."
return result.tolist()
def rolling_noise(self, tokens, p=1.0):
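        """With probability ``p``, rotate ``tokens`` around a random offset,
        e.g. offset 2 turns [a, b, c, d] into [c, d, a, b]."""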
if np.random.random() >= p:
return tokens
offset = np.random.randint(0, max(1, len(tokens) - 1) + 1)
return tokens[offset:] + tokens[0:offset]
def apply(self, tokens):
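        """Apply the enabled noising functions in BART's order:
        sentence permutation, whole-word masking, insertion, rotation."""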
if self.permute_sent_ratio > 0.0:
tokens = self.permute_sentences(tokens, self.permute_sent_ratio)
if self.mask_ratio > 0.0:
tokens = self.whole_word_mask(tokens, self.mask_ratio)
if self.insert_ratio > 0.0:
tokens = self.insertion_noise(tokens, self.insert_ratio)
if self.rotate_ratio > 0.0:
tokens = self.rolling_noise(tokens, self.rotate_ratio)
return tokens
def __repr__(self):
cls_name = type(self).__name__
kwargs = {}
if self.permute_sent_ratio > 0.0:
kwargs['permute_sent_ratio'] = self.permute_sent_ratio
kwargs['full_stop_token'] = self.full_stop_token
if self.insert_ratio > 0.0:
kwargs['insert_ratio'] = self.insert_ratio
if self.rotate_ratio > 0.0:
kwargs['rotate_ratio'] = self.rotate_ratio
if self.random_ratio > 0.0:
kwargs['random_ratio'] = self.random_ratio
if self.mask_ratio > 0.0:
kwargs['mask_ratio'] = self.mask_ratio
kwargs['mask_length'] = self.mask_length
kwargs['poisson_lambda'] = self.poisson_lambda
kwargs['replace_length'] = self.replace_length
cls_args = ', '.join(
[f'{kw}={arg}' for kw, arg in kwargs.items()])
return '{}({})'.format(cls_name, cls_args)
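
# Minimal usage sketch (toy vocabulary is hypothetical; the mask token
# defaults to DefaultTokens.MASK):
#   BARTNoising.set_random_seed(42)
#   noiser = BARTNoising(vocab=["the", "cat", "sat", "down", "."],
#                        mask_ratio=0.3, mask_length='subword',
#                        replace_length=1)
#   noised = noiser.apply("the cat sat down .".split())
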
@register_transform(name='bart')
class BARTNoiseTransform(Transform):
def __init__(self, opts):
super().__init__(opts)
def _set_seed(self, seed):
"""set seed to ensure reproducibility."""
BARTNoising.set_random_seed(seed)
@classmethod
def add_options(cls, parser):
"""Avalilable options relate to BART."""
group = parser.add_argument_group("Transform/BART")
group.add("--permute_sent_ratio", "-permute_sent_ratio",
type=float, default=0.0,
help="Permute this proportion of sentences "
"(boundaries defined by {}) in all inputs.".format(
DefaultTokens.SENT_FULL_STOPS))
group.add("--rotate_ratio", "-rotate_ratio", type=float, default=0.0,
help="Rotate this proportion of inputs.")
group.add("--insert_ratio", "-insert_ratio", type=float, default=0.0,
help="Insert this percentage of additional random tokens.")
group.add("--random_ratio", "-random_ratio", type=float, default=0.0,
help="Instead of using {}, use random token "
"this often.".format(DefaultTokens.MASK))
group.add("--mask_ratio", "-mask_ratio", type=float, default=0.0,
help="Fraction of words/subwords that will be masked.")
group.add("--mask_length", "-mask_length", type=str, default="subword",
choices=["subword", "word", "span-poisson"],
help="Length of masking window to apply.")
group.add("--poisson_lambda", "-poisson_lambda",
type=float, default=3.0,
help="Lambda for Poisson distribution to sample span length "
"if `-mask_length` set to span-poisson.")
group.add("--replace_length", "-replace_length",
type=int, default=-1, choices=[-1, 0, 1],
help="When masking N tokens, replace with 0, 1, "
"or N tokens. (use -1 for N)")
@classmethod
def require_vocab(cls):
"""Override this method to inform it need vocab to start."""
return True
def warm_up(self, vocabs):
super().warm_up(vocabs)
subword_type = self.opts.src_subword_type
if self.opts.mask_length == 'subword':
if subword_type == 'none':
raise ValueError(
f'src_subword_type={subword_type} incompatible with '
f'mask_length={self.opts.mask_length}!')
is_joiner = (subword_type == 'bpe') if subword_type != 'none' else None
self.bart_noise = BARTNoising(
self.vocabs['src'].itos,
mask_tok=DefaultTokens.MASK,
mask_ratio=self.opts.mask_ratio,
insert_ratio=self.opts.insert_ratio,
permute_sent_ratio=self.opts.permute_sent_ratio,
poisson_lambda=self.opts.poisson_lambda,
replace_length=self.opts.replace_length,
rotate_ratio=self.opts.rotate_ratio,
mask_length=self.opts.mask_length,
random_ratio=self.opts.random_ratio,
is_joiner=is_joiner
)
def apply(self, example, is_train=False, stats=None, **kwargs):
"""Apply BART noise to src side tokens."""
if is_train:
src = self.bart_noise.apply(example['src'])
example['src'] = src
return example
def _repr_args(self):
"""Return str represent key arguments for BART."""
return repr(self.bart_noise)
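
# Typical configuration sketch (YAML for OpenNMT-py; values illustrative):
#   data:
#     corpus_1:
#       path_src: ...
#       path_tgt: ...
#       transforms: [bart]
#   src_subword_type: sentencepiece
#   mask_ratio: 0.3
#   mask_length: span-poisson
#   poisson_lambda: 3.0
#   replace_length: 1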