"""
This code is based on the following file:
https://github.com/tztztztztz/eqlv2/blob/master/mmdet/models/losses/eqlv2.py
"""
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from mmdet.utils import get_root_logger

from ..builder import LOSSES


@LOSSES.register_module()
class EQLv2(nn.Module):
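    """Equalization Loss v2 (Tan et al., CVPR 2021).

    A gradient-guided reweighting mechanism for long-tailed object
    detection: it tracks the ratio of accumulated positive to negative
    gradients per class and uses it to up-weight positive and down-weight
    negative samples of rare classes.
    """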
    def __init__(self,
                 use_sigmoid=True,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 num_classes=15,  # 1203 for lvis v1.0, 1230 for lvis v0.5
                 gamma=12,
                 mu=0.8,
                 alpha=4.0,
                 vis_grad=False):
        super().__init__()
        # EQLv2 is sigmoid-based; the use_sigmoid flag is kept only for API
        # compatibility with other mmdet losses.
        self.use_sigmoid = True
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.num_classes = num_classes
        self.group = True

        # cfg for eqlv2
        self.vis_grad = vis_grad
        self.gamma = gamma
        self.mu = mu
        self.alpha = alpha

        # initial variables
        self._pos_grad = None
        self._neg_grad = None
        self.pos_neg = None

        # Sigmoid mapping f(x) = 1 / (1 + exp(-gamma * (x - mu))) that turns
        # the accumulated pos/neg gradient ratio into a weight in (0, 1).
        def _func(x, gamma, mu):
            return 1 / (1 + torch.exp(-gamma * (x - mu)))
        self.map_func = partial(_func, gamma=self.gamma, mu=self.mu)

        logger = get_root_logger()
        logger.info(f"build EQL v2, gamma: {gamma}, mu: {mu}, alpha: {alpha}")
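
    # With the accumulated gradient statistics, the loss computed by forward()
    # is a class-reweighted sigmoid BCE:
    #     L = (1 / n_i) * sum_{i,j} w_ij * BCE_with_logits(s_ij, y_ij)
    # where w_ij = pos_w_j on the ground-truth channel of sample i and
    # neg_w_j elsewhere (see get_weight below).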
    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        # weight, avg_factor and reduction_override are part of the standard
        # mmdet loss signature but are not used here.
        self.n_i, self.n_c = cls_score.size()

        self.gt_classes = label
        self.pred_class_logits = cls_score

        def expand_label(pred, gt_classes):
            # One-hot encode the labels over all n_c channels (including the
            # objectness/background channel).
            target = pred.new_zeros(self.n_i, self.n_c)
            target[torch.arange(self.n_i), gt_classes] = 1
            return target

        target = expand_label(cls_score, label)

        pos_w, neg_w = self.get_weight(cls_score)
        weight = pos_w * target + neg_w * (1 - target)

        cls_loss = F.binary_cross_entropy_with_logits(cls_score, target,
                                                      reduction='none')
        cls_loss = torch.sum(cls_loss * weight) / self.n_i

        self.collect_grad(cls_score.detach(), target.detach(), weight.detach())

        return self.loss_weight * cls_loss
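
    # The classification head is expected to output num_classes + 1 channels;
    # the extra last channel is an objectness/background score (see
    # get_activation and collect_grad).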
    def get_channel_num(self, num_classes):
        num_channel = num_classes + 1
        return num_channel
    def get_activation(self, cls_score):
        cls_score = torch.sigmoid(cls_score)
        n_i, n_c = cls_score.size()
        # scale foreground probabilities by (1 - background score)
        bg_score = cls_score[:, -1].view(n_i, 1)
        cls_score[:, :-1] *= (1 - bg_score)
        return cls_score
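
    # For sigmoid BCE the per-logit gradient is d(loss)/d(s) = sigmoid(s) - y,
    # so its magnitude is y * (1 - p) + (1 - y) * p with p = sigmoid(s).
    # collect_grad accumulates these magnitudes per class and synchronizes
    # them across GPUs; it therefore assumes torch.distributed has been
    # initialized, as in standard mmdet distributed training.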
    def collect_grad(self, cls_score, target, weight):
        prob = torch.sigmoid(cls_score)
        grad = target * (prob - 1) + (1 - target) * prob
        grad = torch.abs(grad)

        # do not collect grad for the objectness branch [:-1]
        pos_grad = torch.sum(grad * target * weight, dim=0)[:-1]
        neg_grad = torch.sum(grad * (1 - target) * weight, dim=0)[:-1]

        dist.all_reduce(pos_grad)
        dist.all_reduce(neg_grad)

        self._pos_grad += pos_grad
        self._neg_grad += neg_grad
        self.pos_neg = self._pos_grad / (self._neg_grad + 1e-10)
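
    # map_func squashes the accumulated pos/neg gradient ratio into (0, 1):
    # rare classes (small ratio) get a small negative weight, suppressing the
    # flood of negative gradients, and a positive weight of up to 1 + alpha.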
    def get_weight(self, cls_score):
        # no gradient statistics are available before the first batch
        if self._pos_grad is None:
            self._pos_grad = cls_score.new_zeros(self.num_classes)
            self._neg_grad = cls_score.new_zeros(self.num_classes)
            neg_w = cls_score.new_ones((self.n_i, self.n_c))
            pos_w = cls_score.new_ones((self.n_i, self.n_c))
        else:
            # the negative weight for the objectness channel is always 1
            neg_w = torch.cat([self.map_func(self.pos_neg), cls_score.new_ones(1)])
            pos_w = 1 + self.alpha * (1 - neg_w)
            neg_w = neg_w.view(1, -1).expand(self.n_i, self.n_c)
            pos_w = pos_w.view(1, -1).expand(self.n_i, self.n_c)
        return pos_w, neg_w
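
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). Because the class is
# registered via @LOSSES.register_module(), an mmdet config can refer to it
# by name; the exact head fields below are illustrative:
#
#     loss_cls = dict(
#         type='EQLv2',
#         num_classes=15,
#         gamma=12,
#         mu=0.8,
#         alpha=4.0)
#
# The head must output num_classes + 1 logits per proposal and should call
# get_activation() at inference to obtain class probabilities.
#
# Direct call (illustrative): collect_grad() uses dist.all_reduce, so a
# single-process gloo group must be initialized first:
#
#     import os
#     os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
#     os.environ.setdefault("MASTER_PORT", "29500")
#     dist.init_process_group("gloo", rank=0, world_size=1)
#
#     criterion = EQLv2(num_classes=15)
#     logits = torch.randn(8, 16)          # n_i=8 samples, n_c=16 channels
#     labels = torch.randint(0, 16, (8,))  # label 15 == background
#     loss = criterion(logits, labels)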