File size: 3,678 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import torch
class PenaltyBuilder(object):
    """Returns the Length and Coverage Penalty function for Beam Search.

    Args:
        cov_pen (str): option name of cov pen (``"wu"``, ``"summary"``,
            ``"none"`` or ``None``)
        length_pen (str): option name of length pen (``"wu"``, ``"avg"``,
            ``"none"`` or ``None``)

    Attributes:
        has_cov_pen (bool): Whether a coverage penalty is selected, i.e.
            ``cov_pen`` is neither ``"none"`` nor ``None``. When False,
            applying the coverage penalty is a no-op. Note that the
            converse isn't true: with beta set to 0 the "wu" and
            "summary" penalties evaluate to zero as well.
        has_len_pen (bool): Whether a length penalty is selected, i.e.
            ``length_pen`` is neither ``"none"`` nor ``None``. When False,
            applying the length penalty is a no-op. Note that the
            converse isn't true: with alpha set to 0 the "wu" penalty
            evaluates to 1 as well.
        coverage_penalty (callable[[FloatTensor, float], FloatTensor]):
            Calculates the coverage penalty.
        length_penalty (callable[[int, float], float]): Calculates
            the length penalty.

    Raises:
        NotImplementedError: If either option name is not supported.
    """

    def __init__(self, cov_pen, length_pen):
        self.has_cov_pen = not self._pen_is_none(cov_pen)
        self.coverage_penalty = self._coverage_penalty(cov_pen)
        self.has_len_pen = not self._pen_is_none(length_pen)
        self.length_penalty = self._length_penalty(length_pen)

    @staticmethod
    def _pen_is_none(pen):
        """Return True when ``pen`` means "no penalty"."""
        return pen == "none" or pen is None

    def _coverage_penalty(self, cov_pen):
        """Map a coverage penalty option name to its implementation."""
        if cov_pen == "wu":
            return self.coverage_wu
        elif cov_pen == "summary":
            return self.coverage_summary
        elif self._pen_is_none(cov_pen):
            return self.coverage_none
        else:
            # Plain "{}" instead of "{:s}": the latter raises ValueError for
            # non-string values and would mask the intended error.
            raise NotImplementedError("No '{}' coverage penalty.".format(
                cov_pen))

    def _length_penalty(self, length_pen):
        """Map a length penalty option name to its implementation."""
        if length_pen == "wu":
            return self.length_wu
        elif length_pen == "avg":
            return self.length_average
        elif self._pen_is_none(length_pen):
            return self.length_none
        else:
            # See _coverage_penalty for why "{}" is used rather than "{:s}".
            raise NotImplementedError("No '{}' length penalty.".format(
                length_pen))

    # Below are all the different penalty terms implemented so far.
    # Subtract coverage penalty from topk log probs.
    # Divide topk log probs by length penalty.

    def coverage_wu(self, cov, beta=0.):
        """GNMT coverage re-ranking score.

        See "Google's Neural Machine Translation System" :cite:`wu2016google`.
        ``cov`` is expected to be sized ``(*, seq_len)``, where ``*`` is
        probably ``batch_size x beam_size`` but could be several
        dimensions like ``(batch_size, beam_size)``. If ``cov`` is attention,
        then the ``seq_len`` axis probably sums to (almost) 1.

        Returns ``beta * -sum(log(min(cov, 1)))`` reduced over the last axis.
        """
        # Cap coverage at 1 so that over-attended positions contribute
        # log(1) = 0 to the penalty.
        capped = torch.min(cov, torch.ones_like(cov))
        penalty = -capped.log().sum(-1)
        return beta * penalty

    def coverage_summary(self, cov, beta=0.):
        """Our summary penalty.

        Returns ``beta * (sum(max(cov, 1)) - seq_len)`` reduced over the last
        axis, i.e. only coverage in excess of 1 per position is penalized.
        """
        floored = torch.max(cov, torch.ones_like(cov))
        penalty = floored.sum(-1) - cov.size(-1)
        return beta * penalty

    def coverage_none(self, cov, beta=0.):
        """Returns zero as penalty.

        The result is a float tensor on ``cov``'s device with shape ``(1,)``,
        or ``(1, 1)`` when ``cov`` is 3-dimensional. ``beta`` is ignored.
        """
        none = torch.zeros((1,), device=cov.device,
                           dtype=torch.float)
        if cov.dim() == 3:
            none = none.unsqueeze(0)
        return none

    def length_wu(self, cur_len, alpha=0.):
        """GNMT length re-ranking score.

        See "Google's Neural Machine Translation System" :cite:`wu2016google`.
        Returns ``((5 + cur_len) / 6) ** alpha``, which is 1 when alpha is 0.
        """
        return ((5 + cur_len) / 6.0) ** alpha

    def length_average(self, cur_len, alpha=0.):
        """Returns the current sequence length. ``alpha`` is ignored."""
        return cur_len

    def length_none(self, cur_len, alpha=0.):
        """Returns 1.0, so dividing scores by it leaves them unmodified."""
        return 1.0
|