File size: 1,458 Bytes
4ea50ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from torch import nn
from torch.nn import functional as F

def Focal_Loss(pred, gt, alpha=0.25, gamma=2):
    """Compute the mean binary (sigmoid) focal loss.

    With ``p = sigmoid(pred)`` the per-element loss is::

        -alpha * (1 - p)**gamma * gt * log(p)
        - (1 - alpha) * p**gamma * (1 - gt) * log(1 - p)

    and the result is its mean over every element.

    Args:
        pred: Tensor of raw logits (pre-sigmoid).
        gt: Target tensor, broadcastable with ``pred``. Presumably binary
            0/1 labels (NOTE(review): confirm with callers); the formula
            also evaluates for values in [0, 1].
        alpha: Weight of the positive (``gt``) term; the negative term is
            weighted by ``1 - alpha``. Defaults to 0.25, the value that was
            previously hard-coded.
        gamma: Focusing exponent applied to ``1 - p`` / ``p``. Defaults to 2,
            the value that was previously hard-coded.

    Returns:
        A 0-dim tensor holding the mean loss.
    """
    # Take logs via logsigmoid on the logits instead of log(sigmoid(...)).
    # In float32, sigmoid saturates to exactly 0.0 or 1.0 for moderately
    # large |pred| (e.g. sigmoid(20) == 1.0). That made log() return -inf,
    # and the masked-out term then evaluated 0 * -inf = NaN.
    log_p = F.logsigmoid(pred)        # log(sigmoid(pred))
    log_not_p = F.logsigmoid(-pred)   # log(1 - sigmoid(pred))
    p = torch.sigmoid(pred)
    not_p = torch.sigmoid(-pred)      # 1 - p without catastrophic cancellation

    pos_term = alpha * not_p ** gamma * (gt * log_p)
    neg_term = (1 - alpha) * p ** gamma * ((1 - gt) * log_not_p)
    loss = -pos_term - neg_term
    return loss.mean()