# encoding: utf-8
"""Class Declaration of Transformer's Label Smoothing loss."""
import logging

import chainer
import chainer.functions as F


class LabelSmoothingLoss(chainer.Chain):
"""Label Smoothing Loss.
Args:
smoothing (float): smoothing rate (0.0 means the conventional CE).
n_target_vocab (int): number of classes.
normalize_length (bool): normalize loss by sequence length if True.
"""
    def __init__(self, smoothing, n_target_vocab, normalize_length=False, ignore_id=-1):
        """Initialize Loss."""
        super(LabelSmoothingLoss, self).__init__()
        self.use_label_smoothing = False
        if smoothing > 0.0:
            logging.info("Use label smoothing")
            self.smoothing = smoothing
            self.confidence = 1.0 - smoothing
            self.use_label_smoothing = True
            self.n_target_vocab = n_target_vocab
        self.normalize_length = normalize_length
        self.ignore_id = ignore_id
        self.acc = None
    def forward(self, ys_block, ys_pad):
        """Forward Loss.

        Args:
            ys_block (chainer.Variable): Predicted logits with shape (batch, length, n_target_vocab).
            ys_pad (chainer.Variable): Padded target (true) label ids with shape (batch, length).

        Returns:
            chainer.Variable: Training loss.

        """
        # Output (all together at once for efficiency)
        batch, length, dims = ys_block.shape
        concat_logit_block = ys_block.reshape(-1, dims)

        # Target reshape
        concat_t_block = ys_pad.reshape((batch * length))
        ignore_mask = concat_t_block >= 0
        n_token = ignore_mask.sum()
        normalizer = n_token if self.normalize_length else batch
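        # ``normalizer`` divides the summed loss: by the number of non-padded
        # tokens when ``normalize_length`` is True, otherwise by the batch size.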
        if not self.use_label_smoothing:
            loss = F.softmax_cross_entropy(concat_logit_block, concat_t_block)
            loss = loss * n_token / normalizer
        else:
            log_prob = F.log_softmax(concat_logit_block)
            broad_ignore_mask = self.xp.broadcast_to(
                ignore_mask[:, None], concat_logit_block.shape
            )
            pre_loss = (
                ignore_mask * log_prob[self.xp.arange(batch * length), concat_t_block]
            )
            loss = -F.sum(pre_loss) / normalizer
            label_smoothing = broad_ignore_mask * -1.0 / self.n_target_vocab * log_prob
            label_smoothing = F.sum(label_smoothing) / normalizer
            loss = self.confidence * loss + self.smoothing * label_smoothing
        return loss
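

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the shapes,
    # values, and smoothing rate below are made up for illustration only.
    import numpy as np

    vocab = 5
    # Random logits for 2 sequences of length 3; the last frame of the
    # second sequence is padding (target id -1) and is ignored by the loss.
    ys_block = np.random.randn(2, 3, vocab).astype(np.float32)
    ys_pad = np.array([[1, 2, 3], [4, 0, -1]], dtype=np.int32)

    criterion = LabelSmoothingLoss(smoothing=0.1, n_target_vocab=vocab)
    loss = criterion.forward(ys_block, ys_pad)
    print("label-smoothed loss:", float(loss.data))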