# unit_test/uniperceiver/losses/label_smoothing.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from uniperceiver.config import configurable
from .build import LOSSES_REGISTRY


@LOSSES_REGISTRY.register()
class LabelSmoothingCrossEntropy(nn.Module):

    @configurable
    def __init__(
        self,
        *,
        label_smoothing,
        loss_weight,
        loss_fp32,
    ):
        super().__init__()
        self.label_smoothing = label_smoothing
        self.confidence = 1.0 - self.label_smoothing
        # Fall back to a weight of 1.0 when the config supplies no float
        # (LOSS_WEIGHT defaults to None in from_config below).
        if not isinstance(loss_weight, float):
            self.loss_weight = 1.0
        else:
            self.loss_weight = loss_weight
        self.loss_fp32 = loss_fp32

    @classmethod
    def from_config(cls, cfg):
        return {
            "label_smoothing": cfg.LOSSES.LABELSMOOTHING,
            "loss_weight": getattr(cfg.LOSSES, "LOSS_WEIGHT", None),
            "loss_fp32": getattr(cfg.LOSSES, "LOSS_FP32", False),
        }
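
    # Config fields consumed by from_config above (example values are
    # illustrative, not taken from a specific Uni-Perceiver config):
    #   cfg.LOSSES.LABELSMOOTHING - smoothing epsilon, e.g. 0.1
    #   cfg.LOSSES.LOSS_WEIGHT    - optional float multiplier on the loss
    #   cfg.LOSSES.LOSS_FP32      - run log-softmax in fp32 (default False)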

    def _compute_loss(self, x, target):
        # Optionally compute log-softmax in fp32 for numerical stability
        # under mixed-precision training, then cast back to the input dtype.
        if self.loss_fp32 and x.dtype != torch.float32:
            logprobs = F.log_softmax(x, dim=-1, dtype=torch.float32).to(x.dtype)
        else:
            logprobs = F.log_softmax(x, dim=-1)
        # Negative log-likelihood of the ground-truth class.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        # Uniform smoothing term: mean negative log-probability over classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.label_smoothing * smooth_loss
        return loss.mean()
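
    # Note: for 1-D integer targets this is the standard smoothed cross
    # entropy (1 - eps) * NLL + eps * mean(-log p), i.e. eps / num_classes
    # probability mass is spread uniformly over all classes. It should
    # therefore coincide with F.cross_entropy(x, target, label_smoothing=eps)
    # on PyTorch >= 1.10; the smoke test at the bottom of the file checks
    # this numerically.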

    def forward(self, outputs_dict):
        # outputs_dict carries parallel lists of logits, targets, and loss
        # names; one (optionally weighted) smoothed loss is returned per pair.
        ret = {}
        for logit, target, loss_identification in zip(outputs_dict['logits'],
                                                      outputs_dict['targets'],
                                                      outputs_dict['loss_names']):
            loss = self._compute_loss(logit, target)
            if self.loss_weight != 1.0:
                loss *= self.loss_weight
            loss_name = 'LabelSmoothing'
            if len(loss_identification) > 0:
                loss_name = f'{loss_name} ({loss_identification})'
            ret.update({loss_name: loss})
        return ret
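

# --- Usage sketch (hypothetical, not part of the original module) ---
# A minimal smoke test, assuming the @configurable decorator (modeled on
# detectron2's) passes explicit keyword arguments straight through when no
# cfg object is given. It also compares the single-pair loss against
# F.cross_entropy's built-in label_smoothing (PyTorch >= 1.10), which uses
# the same formulation.
if __name__ == "__main__":
    torch.manual_seed(0)
    eps = 0.1
    criterion = LabelSmoothingCrossEntropy(
        label_smoothing=eps,
        loss_weight=1.0,
        loss_fp32=False,
    )

    logits = torch.randn(4, 10)           # (batch, num_classes)
    targets = torch.randint(0, 10, (4,))  # 1-D integer class indices

    # Dict protocol expected by forward(): parallel lists of logits,
    # targets, and loss names.
    outputs_dict = {
        'logits': [logits],
        'targets': [targets],
        'loss_names': ['demo'],
    }
    losses = criterion(outputs_dict)
    print(losses)  # {'LabelSmoothing (demo)': tensor(...)}

    reference = F.cross_entropy(logits, targets, label_smoothing=eps)
    assert torch.allclose(losses['LabelSmoothing (demo)'], reference, atol=1e-6)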