File size: 3,198 Bytes
c9019cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
"""
This code is based on the following file:
https://github.com/tztztztztz/eqlv2/blob/master/mmdet/models/losses/eql.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
def get_image_count_frequency(version="v0_5"):
    """Return the per-class statistics table for the requested dataset.

    The loader module is imported lazily, so only the module matching
    ``version`` has to be importable.

    Args:
        version: One of ``"v0_5"``, ``"v1"``, ``"openimage"`` or ``"NDL"``.

    Returns:
        Whatever the selected loader returns: ``get_image_count_frequency()``
        for the two LVIS versions, ``get_instance_count()`` for
        ``"openimage"`` and ``"NDL"``.

    Raises:
        KeyError: If ``version`` is not one of the supported keys.
    """
    # Each branch only binds the loader; the single call happens at the end.
    if version == "v0_5":
        from mmdet.utils.lvis_v0_5_categories import (
            get_image_count_frequency as load_table)
    elif version == "v1":
        from mmdet.utils.lvis_v1_0_categories import (
            get_image_count_frequency as load_table)
    elif version == "openimage":
        from mmdet.utils.openimage_categories import (
            get_instance_count as load_table)
    elif version == "NDL":
        from mmdet.utils.ndl_categories import (
            get_instance_count as load_table)
    else:
        raise KeyError(f"version {version} is not supported")
    return load_table()
@LOSSES.register_module()
class EQL(nn.Module):
    """EQL classification loss (ported from the upstream ``eql.py``).

    Computes an element-wise sigmoid cross-entropy over all classes and
    masks out the negative-class term for selected classes. A class is
    selected when its entry in the table returned by
    ``get_image_count_frequency(version)`` is below ``lambda_``; the mask is
    only applied to samples whose label is not the background index
    (``label == number of classes``).

    Args:
        use_sigmoid: Stored on the instance; not read by ``forward``.
        reduction: Stored on the instance; not read by ``forward``.
        class_weight: Stored on the instance; not read by ``forward``.
        loss_weight: Scalar multiplier applied to the returned loss.
        lambda_: Threshold compared against the per-class table.
        version: Dataset key forwarded to ``get_image_count_frequency``.
    """

    def __init__(self,
                 use_sigmoid=True,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 lambda_=0.00177,
                 version="v0_5"):
        super(EQL, self).__init__()
        self.use_sigmoid = use_sigmoid
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.lambda_ = lambda_
        self.version = version

        self.freq_info = torch.FloatTensor(get_image_count_frequency(version))

        # Convert to a plain int so the log shows "N", not "tensor(N)".
        num_class_included = int(torch.sum(self.freq_info < self.lambda_))
        print(f"set up EQL (version {version}), {num_class_included} classes included.")

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the loss for one batch.

        Args:
            cls_score: 2-D logits of shape ``(n_i, n_c)``.
            label: Class index per sample; the value ``n_c`` marks background.
            weight: Accepted for interface compatibility; not used.
            avg_factor: Accepted for interface compatibility; not used.
            reduction_override: Accepted for interface compatibility; not used.
            **kwargs: Ignored.

        Returns:
            Scalar tensor: the weighted loss summed over all elements, divided
            by the number of samples and scaled by ``loss_weight``.
        """
        # The helper methods below read these attributes, so they are
        # refreshed on every call.
        self.n_i, self.n_c = cls_score.size()
        self.gt_classes = label
        self.pred_class_logits = cls_score

        # One-hot targets with an extra trailing column that absorbs the
        # background label; dropping it leaves background rows all-zero.
        target = cls_score.new_zeros(self.n_i, self.n_c + 1)
        target[torch.arange(self.n_i), label] = 1
        target = target[:, :self.n_c]

        # Weight is 0 only for negative entries (target == 0) of selected
        # classes on non-background samples; everything else keeps weight 1.
        eql_w = 1 - self.exclude_func() * self.threshold_func() * (1 - target)

        cls_loss = F.binary_cross_entropy_with_logits(cls_score, target,
                                                      reduction='none')
        # Clamp the denominator so an empty batch yields 0 instead of NaN.
        cls_loss = torch.sum(cls_loss * eql_w) / max(self.n_i, 1)
        return self.loss_weight * cls_loss

    def exclude_func(self):
        """Return the instance-level weight, shape ``(n_i, n_c)``.

        A row is 1 where the sample's label is not the background index
        (``n_c``) and 0 otherwise.
        """
        bg_ind = self.n_c
        weight = (self.gt_classes != bg_ind).float()
        return weight.view(self.n_i, 1).expand(self.n_i, self.n_c)

    def threshold_func(self):
        """Return the class-level weight, shape ``(n_i, n_c)``.

        A column is 1 where the class's table entry is below ``lambda_`` and
        0 otherwise. ``freq_info`` must hold at least ``n_c`` entries; only
        the first ``n_c`` are used.
        """
        logits = self.pred_class_logits
        # Single vectorized comparison on the logits' device instead of a
        # per-class Python loop; result is cast to the logits' dtype.
        below = self.freq_info[:self.n_c].to(logits.device) < self.lambda_
        weight = below.to(dtype=logits.dtype)
        return weight.view(1, self.n_c).expand(self.n_i, self.n_c)