import torch


def penalty_builder(penalty_config):
    """
    Build a length-penalty function f(length, logprobs) from a config
    string such as 'wu_0.9' or 'avg_1.0'; the empty string means no
    penalty is applied.
    """
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)
    # Fail loudly instead of silently returning None for an unknown type.
    raise ValueError('Unknown penalty type: %s' % pen_type)
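
# A minimal usage sketch (illustrative; the names below are made up for
# this comment and are not part of the module):
#
#   penalize = penalty_builder('wu_0.9')
#   score = penalize(seq_len, summed_logprobs)  # length_wu with alpha=0.9
#   no_op = penalty_builder('')                 # returns logprobs as-is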


def length_wu(length, logprobs, alpha=0.):
    """
    NMT length re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.

    The score is logprobs / ((5 + length) / 6) ** alpha: alpha = 0
    leaves scores unchanged, and larger alpha counteracts the usual
    beam-search bias toward short sequences.
    """
    modifier = (((5 + length) ** alpha) /
                ((5 + 1) ** alpha))
    return logprobs / modifier
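
# Worked example (illustrative numbers): with alpha = 1.0, a length-10
# hypothesis with summed logprob -6.0 scores -6.0 / ((5 + 10) / 6) = -2.4,
# and a length-5 hypothesis at -4.0 scores -4.0 / ((5 + 5) / 6) = -2.4,
# so the longer candidate is no longer dominated outright.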


def length_average(length, logprobs, alpha=0.):
    """
    Returns the average log-probability per token of a sequence.
    `alpha` is unused here; it is accepted only so the signature
    matches length_wu.
    """
    return logprobs / length
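
# Illustrative sketch: under the 'avg' penalty, beam scores become
# per-token averages, e.g. length_average(4, -8.0) == -2.0.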


def split_tensors(n, x):
    """
    Inverse of repeat_tensors: regroup a leading dimension of size B*n
    into n pieces of size B. Tensors come back as an n-tuple, lists and
    tuples are split recursively, and None becomes a list of n Nones.
    """
    if torch.is_tensor(x):
        assert x.shape[0] % n == 0
        # (B*n, ...) -> (B, n, ...) -> n-tuple of (B, ...)
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif type(x) is list or type(x) is tuple:
        x = [split_tensors(n, _) for _ in x]
    elif x is None:
        x = [None] * n
    return x
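
# Illustrative sketch: split_tensors undoes the interleaved repeat
# performed by repeat_tensors below.
#
#   >>> split_tensors(2, torch.arange(6))
#   (tensor([0, 2, 4]), tensor([1, 3, 5]))
#   >>> split_tensors(2, repeat_tensors(2, torch.arange(3)))
#   (tensor([0, 1, 2]), tensor([0, 1, 2]))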


def repeat_tensors(n, x):
    """
    For a tensor of shape Bx..., repeat each row n times to make it
    Bnx... (row i becomes output rows i*n .. i*n + n - 1). For lists
    and tuples, repeat recursively.
    """
    if torch.is_tensor(x):
        x = x.unsqueeze(1)                                # Bx1x...
        x = x.expand(-1, n, *([-1] * len(x.shape[2:])))   # Bxnx...
        x = x.reshape(x.shape[0] * n, *x.shape[2:])       # Bnx...
    elif type(x) is list or type(x) is tuple:
        x = [repeat_tensors(n, _) for _ in x]
    return x
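
# Illustrative sketch: for tensors this is equivalent to
# repeat_interleave along the batch dimension (available in recent
# PyTorch releases):
#
#   >>> repeat_tensors(3, torch.arange(2))
#   tensor([0, 0, 0, 1, 1, 1])
#   >>> torch.arange(2).repeat_interleave(3, dim=0)
#   tensor([0, 0, 0, 1, 1, 1])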