Spaces:
Running
Running
File size: 1,627 Bytes
29f689c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import torch
import torch.nn.functional as F
from .ar_loss import ARLoss
def BanlanceMultiClassCrossEntropyLoss(x_o, x_t):
    """Class-balanced multi-class cross-entropy over per-pixel predictions.

    The last class index (``num_cls - 1``) is used as the background label:
    any pixel whose thresholded target has no active channel is assigned to
    it. For each sample, every non-background class is weighted by
    ``num_bg / (H * W - num_bg + 1e-5)``; the background weight stays 1.

    Args:
        x_o: Predicted class scores of shape ``[B, num_cls, H, W]``.
        x_t: Target mask with leading dims ``[B, H, W]`` and a trailing
            channel dim that is reduced with ``sum``/``argmax``. Values are
            thresholded at 0.5. (The channel count is presumably the number
            of foreground classes -- TODO confirm against the data loader.)
            This tensor is NOT modified.

    Returns:
        A scalar tensor: the mean weighted negative log-likelihood over all
        ``B * H * W`` pixels.
    """
    batch, num_cls, height, width = x_o.shape
    num_pixels = height * width
    bg_index = num_cls - 1

    # [B, num_cls, H, W] -> [B, H*W, num_cls]
    scores = x_o.reshape(batch, num_cls, num_pixels).permute(0, 2, 1)

    # Build hard labels from the target. Threshold into a NEW tensor rather
    # than assigning in place, so the caller's mask is left untouched.
    mask = (x_t > 0.5).to(x_t.dtype)
    active = mask.sum(-1)  # [B, H, W]; 0 where no channel is on
    labels = mask.argmax(-1)  # [B, H, W]
    labels[active == 0] = bg_index  # pixels with no active channel
    labels = labels.reshape(batch, num_pixels)

    # Per-sample class weights; the background column keeps weight 1.
    weight = torch.ones((batch, num_cls)).type_as(scores)
    num_bg = (labels == bg_index).sum(-1)  # [B]
    # The epsilon keeps the ratio finite when a sample has no foreground.
    fg_weight = num_bg / (num_pixels - num_bg + 1e-5)
    weight[:, :-1] = fg_weight.unsqueeze(-1).expand(-1, bg_index)

    log_prob = F.log_softmax(scores, dim=-1)  # [B, H*W, num_cls]
    log_prob = log_prob * weight.unsqueeze(1)
    loss = -log_prob.gather(2, labels.unsqueeze(-1).long())
    return loss.mean()
class CAMLoss(ARLoss):
    """Loss that adds a weighted binary-mask term to the parent ARLoss.

    The total loss is the parent's ``'loss'`` computed on
    ``pred['rec_output']`` plus ``loss_weight_binary`` times the
    class-balanced cross-entropy of ``pred['pred_binary']`` against the
    last element of ``batch``.
    """

    def __init__(self, label_smoothing=0.1, loss_weight_binary=1.5, **kwargs):
        """Store the loss hyper-parameters.

        Args:
            label_smoothing: Forwarded to the parent constructor and also
                kept on the instance.
            loss_weight_binary: Multiplier applied to the mask loss.
            **kwargs: Accepted and ignored; not forwarded to the parent.
        """
        super().__init__(label_smoothing=label_smoothing)
        self.label_smoothing = label_smoothing
        self.loss_weight_binary = loss_weight_binary

    def forward(self, pred, batch):
        """Compute the combined loss.

        Args:
            pred: Mapping that provides ``'rec_output'`` and
                ``'pred_binary'``.
            batch: Sequence whose last element is the mask target; the
                remaining elements are passed to the parent ``forward``.

        Returns:
            Dict with ``'loss'`` (the sum), ``'rec_loss'`` and
            ``'loss_binary'`` (already multiplied by its weight).
        """
        mask_target = batch[-1]

        # The parent loss sees the batch without the trailing mask.
        parent_out = super().forward(pred['rec_output'], batch[:-1])
        rec_loss = parent_out['loss']

        mask_loss = BanlanceMultiClassCrossEntropyLoss(
            pred['pred_binary'], mask_target)
        loss_binary = self.loss_weight_binary * mask_loss

        total = rec_loss + loss_binary
        return {
            'loss': total,
            'rec_loss': rec_loss,
            'loss_binary': loss_binary
        }
|