File size: 2,114 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
import torch.nn.functional as F
from uniperceiver.config import configurable
from .build import LOSSES_REGISTRY




@LOSSES_REGISTRY.register()
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    For every position the loss is::

        (1 - label_smoothing) * (-log p[target]) + label_smoothing * mean_c(-log p[c])

    i.e. the target distribution puts ``1 - label_smoothing`` on the true
    class and spreads ``label_smoothing`` uniformly over all classes. The
    per-position losses are averaged into a single scalar.
    """

    @configurable
    def __init__(
        self,
        *,
        label_smoothing,
        loss_weight,
        loss_fp32,
    ):
        """
        Args:
            label_smoothing: smoothing factor; ``1 - label_smoothing`` is the
                weight given to the ground-truth class.
            loss_weight: scalar multiplier applied to every returned loss.
                Any non-numeric value (e.g. ``None`` when the config key is
                absent) falls back to ``1.0``.
            loss_fp32: if true and the logits are not float32, the
                log-softmax is computed in float32.
        """
        super().__init__()
        self.label_smoothing = label_smoothing
        self.confidence = 1.0 - self.label_smoothing
        # Accept ints as well as floats so a config value such as
        # ``LOSS_WEIGHT: 2`` is honored instead of being replaced by 1.0.
        # bool is a subclass of int and is deliberately not treated as a weight.
        if isinstance(loss_weight, (int, float)) and not isinstance(loss_weight, bool):
            self.loss_weight = float(loss_weight)
        else:
            self.loss_weight = 1.0
        self.loss_fp32 = loss_fp32

    @classmethod
    def from_config(cls, cfg):
        """Build constructor kwargs from ``cfg.LOSSES``.

        ``LABELSMOOTHING`` is required; ``LOSS_WEIGHT`` defaults to ``None``
        (meaning weight 1.0) and ``LOSS_FP32`` defaults to ``False``.
        """
        return {
            "label_smoothing": cfg.LOSSES.LABELSMOOTHING,
            'loss_weight': getattr(cfg.LOSSES, 'LOSS_WEIGHT', None),
            'loss_fp32': getattr(cfg.LOSSES, 'LOSS_FP32', False),
        }

    def Forward(self, x, target):
        """Compute the mean label-smoothed cross-entropy for one logit tensor.

        Note: this capitalized helper is distinct from ``nn.Module.forward``;
        ``forward`` below calls it once per (logits, target) pair.

        Args:
            x: logits with classes on the last dimension, shape ``(..., C)``.
            target: integer class indices with shape ``x.shape[:-1]``
                (``torch.gather`` requires an int64 index tensor).

        Returns:
            Scalar tensor: the loss averaged over all positions.
        """
        if self.loss_fp32 and x.dtype != torch.float32:
            # Only the normalization runs in fp32 (presumably for numerical
            # stability with half-precision logits); the result is cast back
            # so the remaining arithmetic stays in x.dtype.
            logprobs = F.log_softmax(x, dim=-1,
                                     dtype=torch.float32).to(x.dtype)
        else:
            logprobs = F.log_softmax(x, dim=-1)
        # Add/remove the gather axis at the *last* dim (not dim 1) so logits
        # with extra leading dims, e.g. (B, T, C) with target (B, T), work.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = nll_loss.squeeze(-1)
        # Cross-entropy against the uniform distribution over the C classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.label_smoothing * smooth_loss
        return loss.mean()

    def forward(self, outputs_dict):
        """Compute one weighted loss per (logits, target) pair.

        Args:
            outputs_dict: mapping holding parallel sequences under
                ``'logits'``, ``'targets'`` and ``'loss_names'``. Iteration
                stops at the shortest of the three (``zip`` semantics).

        Returns:
            dict mapping ``'LabelSmoothing'`` -- or
            ``'LabelSmoothing (<name>)'`` when the loss name is non-empty --
            to the scalar loss scaled by ``loss_weight``. Pairs whose names
            produce the same key overwrite earlier entries.
        """
        ret = {}
        for logit, target, loss_identification in zip(outputs_dict['logits'],
                                                      outputs_dict['targets'],
                                                      outputs_dict['loss_names']):
            loss = self.Forward(logit, target)
            if self.loss_weight != 1.0:
                # Out-of-place: avoid mutating the autograd output in place.
                loss = loss * self.loss_weight
            loss_name = 'LabelSmoothing'
            # Truthiness treats '' and None alike (len(None) would raise).
            if loss_identification:
                loss_name = loss_name + f' ({loss_identification})'
            ret[loss_name] = loss
        return ret