File size: 1,627 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn.functional as F

from .ar_loss import ARLoss


def BanlanceMultiClassCrossEntropyLoss(x_o, x_t):
    """Class-balanced cross-entropy between per-pixel logits and a mask.

    The target mask is thresholded at 0.5. Each pixel is labelled with the
    argmax over the last dimension of the thresholded mask; pixels whose
    thresholded channels are all zero are labelled as background, which is
    class ``num_cls - 1``. Within each sample, every non-background class is
    weighted by ``num_bg / (num_fg + 1e-5)`` and background keeps weight 1.

    Args:
        x_o: Logits of shape ``[B, num_cls, H, W]``.
        x_t: Target mask, reduced over its last dimension and then reshaped
            to ``[B, H * W]`` (``[B, H, W, C]`` layout). It is left
            unmodified by this function.
            NOTE(review): C is presumably ``num_cls - 1`` foreground
            channels -- confirm against the data pipeline.

    Returns:
        Scalar tensor: the mean over all ``B * H * W`` pixels of the weighted
        negative log-likelihood of the target class.
    """
    B, num_cls, H, W = x_o.shape
    bg_index = num_cls - 1
    # [B, num_cls, H, W] -> [B, H*W, num_cls]
    logits = x_o.reshape(B, num_cls, H * W).permute(0, 2, 1)

    # Build the hard labels from a fresh tensor: thresholding in place would
    # silently overwrite the caller's mask.
    binary = (x_t > 0.5).to(x_t.dtype)
    target = binary.argmax(-1)  # [B, H, W]
    target[binary.sum(-1) == 0] = bg_index  # no active channel -> background
    target = target.reshape(B, H * W)

    # Per-sample class weights: background stays at 1, every other class gets
    # the background/foreground pixel ratio.
    num_bg = (target == bg_index).sum(-1)  # [B]
    weight = x_o.new_ones((B, num_cls))
    weight[:, :-1] = (num_bg / (H * W - num_bg + 1e-5)).unsqueeze(-1)

    # Pick the target log-probability first, then weight it; this yields the
    # same values as weighting the full [B, H*W, num_cls] tensor and gathering.
    log_prob = F.log_softmax(logits, dim=-1)
    picked = log_prob.gather(2, target.unsqueeze(-1)).squeeze(-1)  # [B, H*W]
    pixel_weight = weight.gather(1, target)  # [B, H*W]
    return -(picked * pixel_weight).mean()


class CAMLoss(ARLoss):
    """Sum of the parent ``ARLoss`` loss and a weighted binary-mask loss.

    Args:
        label_smoothing: Passed to ``ARLoss.__init__`` and also stored on the
            instance.
        loss_weight_binary: Scalar multiplier applied to the mask loss.
        **kwargs: Accepted for call-site compatibility; not used here.
    """

    def __init__(self, label_smoothing=0.1, loss_weight_binary=1.5, **kwargs):
        super().__init__(label_smoothing=label_smoothing)
        self.label_smoothing = label_smoothing
        self.loss_weight_binary = loss_weight_binary

    def forward(self, pred, batch):
        """Compute the combined loss.

        Args:
            pred: Mapping read at ``'rec_output'`` (handed to the parent
                loss) and ``'pred_binary'`` (handed to the mask loss).
            batch: Sequence whose last element is the binary mask target;
                everything before it is handed to the parent loss.

        Returns:
            Dict with ``'loss'`` (the sum), ``'rec_loss'`` and
            ``'loss_binary'`` (the already-weighted mask loss).
        """
        mask_target = batch[-1]
        rec_targets = batch[:-1]

        # The parent forward returns a dict; only its 'loss' entry is used.
        rec_loss = super().forward(pred['rec_output'], rec_targets)['loss']

        mask_loss = BanlanceMultiClassCrossEntropyLoss(pred['pred_binary'],
                                                       mask_target)
        loss_binary = self.loss_weight_binary * mask_loss

        losses = {
            'loss': rec_loss + loss_binary,
            'rec_loss': rec_loss,
            'loss_binary': loss_binary
        }
        return losses