import torch
import torch.nn as nn
import torch.nn.functional as F
from uniperceiver.config import configurable
from .build import LOSSES_REGISTRY


@LOSSES_REGISTRY.register()
class LabelSmoothingCrossEntropy(nn.Module):

    @configurable
    def __init__(
        self,
        *,
        label_smoothing,
        loss_weight,
        loss_fp32,
    ):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.label_smoothing = label_smoothing
        self.confidence = 1.0 - self.label_smoothing
        # Fall back to a weight of 1.0 when no valid float weight is configured.
        if not isinstance(loss_weight, float):
            self.loss_weight = 1.0
        else:
            self.loss_weight = loss_weight
        self.loss_fp32 = loss_fp32

    @classmethod
    def from_config(cls, cfg):
        return {
            "label_smoothing": cfg.LOSSES.LABELSMOOTHING,
            "loss_weight": getattr(cfg.LOSSES, 'LOSS_WEIGHT', None),
            "loss_fp32": getattr(cfg.LOSSES, 'LOSS_FP32', False),
        }

    def Forward(self, x, target):
        # Smoothed cross-entropy for a single (logits, targets) pair.
        # Optionally compute the log-softmax in fp32 for numerical stability
        # when the logits arrive in lower precision (e.g. fp16 under AMP).
        if self.loss_fp32 and x.dtype != torch.float32:
            logprobs = F.log_softmax(x, dim=-1,
                                     dtype=torch.float32).to(x.dtype)
        else:
            logprobs = F.log_softmax(x, dim=-1)
        # Negative log-likelihood of the target class.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        # Smoothing term: mean negative log-probability over all classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.label_smoothing * smooth_loss
        return loss.mean()

    def forward(self, outputs_dict):
        ret = {}
        # Compute one named loss per (logits, targets) pair in the outputs dict.
        for logit, target, loss_identification in zip(outputs_dict['logits'],
                                                      outputs_dict['targets'],
                                                      outputs_dict['loss_names']):
            loss = self.Forward(logit, target)
            if self.loss_weight != 1.0:
                loss *= self.loss_weight
            loss_name = 'LabelSmoothing'
            if len(loss_identification) > 0:
                loss_name = loss_name + f' ({loss_identification})'
            ret.update({loss_name: loss})
        return ret
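

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module).
    # Because of the relative import above, this block must be run as a module
    # inside the package rather than as a standalone script. It also assumes
    # the @configurable decorator, like detectron2's, accepts explicit keyword
    # arguments so the loss can be built without a full cfg object.
    torch.manual_seed(0)
    criterion = LabelSmoothingCrossEntropy(label_smoothing=0.1,
                                           loss_weight=1.0,
                                           loss_fp32=False)
    logits = torch.randn(4, 10)           # batch of 4 samples, 10 classes
    targets = torch.randint(0, 10, (4,))  # integer class indices
    outputs_dict = {
        'logits': [logits],
        'targets': [targets],
        'loss_names': ['toy_task'],       # hypothetical task name for the demo
    }
    # Expected output: {'LabelSmoothing (toy_task)': tensor(...)}
    print(criterion(outputs_dict))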