"""
This code is based on the following file:

https://github.com/tztztztztz/eqlv2/blob/master/mmdet/models/losses/eql.py
"""
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from ..builder import LOSSES |
|
|
|
|
|
def get_image_count_frequency(version="v0_5"):
    """Return the per-category statistics table for a dataset version.

    The provider module is imported lazily inside the matching branch, so
    only the module belonging to the requested ``version`` is imported.

    Args:
        version: Dataset identifier. Supported values are ``"v0_5"``,
            ``"v1"``, ``"openimage"`` and ``"NDL"``.

    Returns:
        Whatever the selected provider returns. For ``"v0_5"`` and ``"v1"``
        this is the provider's ``get_image_count_frequency()``; for
        ``"openimage"`` and ``"NDL"`` it is the provider's
        ``get_instance_count()``.

    Raises:
        KeyError: If ``version`` is not one of the supported identifiers.
    """
    # Each branch binds the provider to the same local alias so the call
    # happens in exactly one place below.
    if version == "v0_5":
        from mmdet.utils.lvis_v0_5_categories import (
            get_image_count_frequency as load_stats)
    elif version == "v1":
        from mmdet.utils.lvis_v1_0_categories import (
            get_image_count_frequency as load_stats)
    elif version == "openimage":
        from mmdet.utils.openimage_categories import (
            get_instance_count as load_stats)
    elif version == "NDL":
        from mmdet.utils.ndl_categories import (
            get_instance_count as load_stats)
    else:
        raise KeyError(f"version {version} is not supported")

    return load_stats()
|
|
|
|
|
@LOSSES.register_module()
class EQL(nn.Module):
    """Sigmoid cross-entropy loss with per-class down-weighting of negatives.

    For every foreground sample, the negative (non-target) terms of classes
    whose entry in ``freq_info`` is below ``lambda_`` are removed from the
    binary cross-entropy sum. Samples labelled with the background index
    (``label == n_c``) keep all of their terms.

    NOTE(review): the name and the upstream file referenced at the top of
    this module suggest this is the "Equalization Loss"; confirm against
    the upstream project before relying on that.

    Args:
        use_sigmoid: Stored on the instance; not read by this class.
        reduction: Stored on the instance; ``forward`` always sums and
            divides by the number of samples regardless of this value.
        class_weight: Stored on the instance; not read by this class.
        loss_weight: Scalar multiplier applied to the final loss.
        lambda_: Threshold compared against ``freq_info``; classes with a
            strictly smaller value are the ones whose negative terms are
            suppressed.
        version: Dataset identifier forwarded to
            ``get_image_count_frequency``.
    """

    def __init__(self,
                 use_sigmoid=True,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 lambda_=0.00177,
                 version="v0_5"):
        super(EQL, self).__init__()
        self.use_sigmoid = use_sigmoid
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.lambda_ = lambda_
        self.version = version
        # Plain attribute (not a registered buffer), so it stays on the CPU
        # and is not part of the module's state_dict.
        self.freq_info = torch.FloatTensor(get_image_count_frequency(version))

        num_class_included = torch.sum(self.freq_info < self.lambda_)
        print(f"set up EQL (version {version}), {num_class_included} classes included.")

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the loss for a batch of class logits.

        Args:
            cls_score: 2-D tensor of logits with shape ``(n_i, n_c)``.
            label: Integer class index per sample, used to index a
                ``(n_i, n_c + 1)`` one-hot matrix; the value ``n_c`` is
                treated as background and produces an all-zero target row.
            weight: Accepted for interface compatibility; not used.
            avg_factor: Accepted for interface compatibility; not used.
            reduction_override: Accepted for interface compatibility;
                not used.
            **kwargs: Ignored.

        Returns:
            Scalar tensor: ``loss_weight * sum(bce * eql_w) / n_i``. For an
            empty batch (``n_i == 0``) the result is zero instead of NaN.
        """
        self.n_i, self.n_c = cls_score.size()

        # exclude_func() / threshold_func() read these from the instance.
        self.gt_classes = label
        self.pred_class_logits = cls_score

        # One-hot targets with an extra background column that is dropped
        # again, so background samples end up with an all-zero row.
        one_hot = cls_score.new_zeros(self.n_i, self.n_c + 1)
        rows = torch.arange(self.n_i, device=label.device)
        one_hot[rows, label] = 1
        target = one_hot[:, :self.n_c]

        # Weight is 0 only for negative terms of below-threshold classes on
        # foreground samples; everything else keeps weight 1.
        eql_w = 1 - self.exclude_func() * self.threshold_func() * (1 - target)

        cls_loss = F.binary_cross_entropy_with_logits(cls_score, target,
                                                      reduction='none')

        # Guard the divisor: with n_i == 0 the sum is 0 and 0 / 0 would
        # otherwise yield NaN and poison the total loss.
        cls_loss = torch.sum(cls_loss * eql_w) / max(self.n_i, 1)

        return self.loss_weight * cls_loss

    def exclude_func(self):
        """Return an ``(n_i, n_c)`` mask that is 1 for non-background samples.

        Reads ``gt_classes``, ``n_i`` and ``n_c`` set by the latest
        ``forward`` call.
        """
        bg_ind = self.n_c
        weight = (self.gt_classes != bg_ind).float()
        weight = weight.view(self.n_i, 1).expand(self.n_i, self.n_c)
        return weight

    def threshold_func(self):
        """Return an ``(n_i, n_c)`` mask that is 1 for below-threshold classes.

        A class column is 1 when its ``freq_info`` entry is strictly less
        than ``lambda_``. Reads ``pred_class_logits``, ``n_i`` and ``n_c``
        set by the latest ``forward`` call.

        Raises:
            IndexError: If ``freq_info`` has fewer than ``n_c`` entries.
        """
        if self.freq_info.numel() < self.n_c:
            raise IndexError(
                f"freq_info has {self.freq_info.numel()} entries but "
                f"cls_score has {self.n_c} classes")

        # Single vectorized comparison instead of one tensor comparison per
        # class in a Python loop on every forward pass.
        below_threshold = self.freq_info[:self.n_c] < self.lambda_
        weight = below_threshold.to(device=self.pred_class_logits.device,
                                    dtype=self.pred_class_logits.dtype)
        weight = weight.view(1, self.n_c).expand(self.n_i, self.n_c)
        return weight