File size: 2,136 Bytes
f1dd031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Author: Siyuan Li
Licensed: Apache-2.0 License
"""

import torch
import torch.nn as nn
from mmdet.models import weight_reduce_loss
from mmdet.registry import MODELS


def multi_pos_cross_entropy(
    pred, label, weight=None, reduction="mean", avg_factor=None, pos_normalize=True,
):
    """Cross-entropy loss with multiple positive targets per sample.

    For each sample the log-probabilities of all positive classes are
    averaged; the loss is the negated mean positive log-probability
    (a supervised-contrastive-style objective).

    Args:
        pred (Tensor): Logits of shape (N, C).
        label (Tensor): Multi-hot targets of shape (N, C); a row of all
            zeros marks an invalid sample and is dropped.
        weight (Tensor | None): Optional per-sample weights of shape (N,).
        reduction (str): 'none' | 'mean' | 'sum', applied via
            ``weight_reduce_loss``.
        avg_factor (float | None): Normalizer forwarded to
            ``weight_reduce_loss``.
        pos_normalize (bool): If True, rescale the exp-logits of positive
            classes by 1 / (number of positives) before the softmax
            denominator is formed.

    Returns:
        Tensor: Reduced loss (scalar for 'mean'/'sum', per-sample for 'none').
    """
    # Drop samples with no positive label; keeping them would divide by
    # zero when averaging over positives below.
    valid_mask = label.sum(1) != 0
    pred = pred[valid_mask]
    label = label[valid_mask]
    # FIX: the original indexed `weight[valid_mask]` unconditionally, which
    # raises TypeError when called with the documented default weight=None.
    if weight is not None:
        weight = weight[valid_mask]

    # Subtract the per-row max for numerical stability; skip when the
    # filtered batch is empty (torch.max over an empty dim would error).
    if min(pred.shape) != 0:
        logits_max, _ = torch.max(pred, dim=1, keepdim=True)
        logits = pred - logits_max.detach()
    else:
        logits = pred

    if pos_normalize:
        # Scale each positive class's exp-logit by 1/(#positives in the row)
        # so many positives don't dominate the softmax denominator;
        # negative classes keep their plain exp-logit.
        pos_norm = torch.div(label, label.sum(1).reshape(-1, 1))
        exp_logits = (torch.exp(logits)) * pos_norm + (
            torch.exp(logits)
        ) * torch.logical_not(label)
    else:
        exp_logits = torch.exp(logits)
    exp_logits_input = exp_logits.sum(1, keepdim=True)
    log_prob = logits - torch.log(exp_logits_input)

    # Average log-probability over the positive classes of each sample.
    mean_log_prob_pos = (label * log_prob).sum(1) / label.sum(1)
    loss = -mean_log_prob_pos

    # Apply per-sample weights and the requested reduction.
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor
    )

    return loss


@MODELS.register_module()
class UnbiasedContrastLoss(nn.Module):
    """Weighted wrapper around :func:`multi_pos_cross_entropy`.

    Args:
        reduction (str): Default reduction mode: 'none', 'mean' or 'sum'.
        loss_weight (float): Scalar multiplier applied to the loss.
    """

    def __init__(self, reduction="mean", loss_weight=1.0):
        super(UnbiasedContrastLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(
        self,
        cls_score,
        label,
        weight=None,
        avg_factor=None,
        reduction_override=None,
        **kwargs
    ):
        """Compute the weighted multi-positive cross-entropy loss.

        Args:
            cls_score (Tensor): Logits, same shape as ``label``.
            label (Tensor): Multi-hot targets.
            weight (Tensor | None): Optional per-sample weights.
            avg_factor (float | None): Normalizer for the reduction.
            reduction_override (str | None): Per-call reduction override.

        Returns:
            Tensor: The scaled loss.
        """
        assert cls_score.size() == label.size()
        assert reduction_override in (None, "none", "mean", "sum")
        # A per-call override takes precedence over the configured default.
        if reduction_override:
            reduction = reduction_override
        else:
            reduction = self.reduction
        raw_loss = multi_pos_cross_entropy(
            cls_score,
            label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs
        )
        return self.loss_weight * raw_loss