"""Class Declaration of Transformer's Label Smootion loss.""" |
|
|
|
import logging |
|
|
|
import chainer |
|
|
|
import chainer.functions as F |
|
|
|
|
|
class LabelSmoothingLoss(chainer.Chain): |
|
"""Label Smoothing Loss. |
|
|
|
Args: |
|
smoothing (float): smoothing rate (0.0 means the conventional CE). |
|
n_target_vocab (int): number of classes. |
|
normalize_length (bool): normalize loss by sequence length if True. |
|
|
|
""" |
    def __init__(self, smoothing, n_target_vocab, normalize_length=False, ignore_id=-1):
        """Initialize the loss."""
        super(LabelSmoothingLoss, self).__init__()
        self.use_label_smoothing = smoothing > 0.0
        if self.use_label_smoothing:
            logging.info("Use label smoothing")
        # Set unconditionally so the attributes exist even when smoothing is 0.0.
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.n_target_vocab = n_target_vocab
        self.normalize_length = normalize_length
        self.ignore_id = ignore_id
        self.acc = None
|
    def forward(self, ys_block, ys_pad):
        """Compute the loss.

        Args:
            ys_block (chainer.Variable): Predicted logits of shape (batch, length, dims).
            ys_pad (xp.ndarray): Padded target label ids of shape (batch, length);
                ids below zero are ignored.

        Returns:
            chainer.Variable: Scalar training loss.

        """
        batch, length, dims = ys_block.shape
        # Flatten logits to (batch * length, dims) and targets to (batch * length,).
        concat_logit_block = ys_block.reshape(-1, dims)
        concat_t_block = ys_pad.reshape(batch * length)
        # Padded positions carry a negative id and are masked out of the loss.
        ignore_mask = concat_t_block >= 0
        n_token = ignore_mask.sum()
        # Normalize by token count or by batch size (per-sequence average).
        normalizer = n_token if self.normalize_length else batch

        if not self.use_label_smoothing:
            loss = F.softmax_cross_entropy(concat_logit_block, concat_t_block)
            loss = loss * n_token / normalizer
        else:
            log_prob = F.log_softmax(concat_logit_block)
            broad_ignore_mask = self.xp.broadcast_to(
                ignore_mask[:, None], concat_logit_block.shape
            )
            # Masked negative log-likelihood of the true labels.
            pre_loss = (
                ignore_mask * log_prob[self.xp.arange(batch * length), concat_t_block]
            )
            loss = -F.sum(pre_loss) / normalizer
            # Smoothing term: cross entropy against the uniform distribution.
            label_smoothing = broad_ignore_mask * -1.0 / self.n_target_vocab * log_prob
            label_smoothing = F.sum(label_smoothing) / normalizer
            loss = self.confidence * loss + self.smoothing * label_smoothing
        return loss
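

# Minimal usage sketch (an addition for illustration, not part of the original
# module): feed random logits and padded integer targets through the loss.
# Shapes and the smoothing rate below are illustrative assumptions.
if __name__ == "__main__":
    import numpy as np

    batch, length, vocab = 2, 3, 5
    logits = chainer.Variable(
        np.random.randn(batch, length, vocab).astype(np.float32)
    )
    # -1 marks padded positions, matching the default ignore_id.
    targets = np.array([[0, 3, -1], [4, -1, -1]], dtype=np.int32)
    loss_fn = LabelSmoothingLoss(smoothing=0.1, n_target_vocab=vocab)
    loss = loss_fn.forward(logits, targets)
    print("loss:", float(loss.data))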
|
|