"""
This code is based on the following file:
https://github.com/tztztztztz/eqlv2/blob/master/mmdet/models/losses/group_softmax.py
"""
import bisect
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from mmdet.utils import get_root_logger


def get_instance_count(version="v1"):
if version == "v0_5":
from mmdet.utils.lvis_v0_5_categories import get_instance_count
return get_instance_count()
elif version == "v1":
from mmdet.utils.lvis_v1_0_categories import get_instance_count
return get_instance_count()
else:
        raise KeyError(f"version {version} is not supported")


@LOSSES.register_module()
class GroupSoftmax(nn.Module):
"""
This uses a different encoding from v1.
    v1:   [cls1, cls2, ..., others_for_group0, others_for_group1, bg, bg_others]
    this: [group0_others, group0_cls0, ..., group1_others, group1_cls0, ...]
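    Illustrative layout (derived from the code below, not stated in the
    original): with bin_split=(10, 100, 1000) there are four foreground
    groups plus one background group, so the logits are laid out as
        [g0_others, g0_cls..., g1_others, g1_cls..., ..., bg_others, bg]
    for num_classes + 1 + num_group channels in total (see get_channel_num).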
"""
def __init__(self,
use_sigmoid=False,
reduction='mean',
class_weight=None,
loss_weight=1.0,
beta=8,
bin_split=(10, 100, 1000),
version="v1"):
        super(GroupSoftmax, self).__init__()
        # grouped softmax is softmax-based by construction, so sigmoid mode is
        # always disabled regardless of the `use_sigmoid` argument
        self.use_sigmoid = False
        self.group = True
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.beta = beta
self.bin_split = bin_split
self.version = version
self.cat_instance_count = get_instance_count(version=self.version)
self.num_classes = len(self.cat_instance_count)
assert not self.use_sigmoid
self._assign_group()
self._prepare_for_label_remapping()
logger = get_root_logger()
        logger.info(f"set up {self.__class__.__name__}, version {self.version}, "
                    f"beta: {self.beta}, num classes {self.num_classes}")

    def _assign_group(self):
self.num_group = (len(self.bin_split) + 1) + 1 # add a group for background
self.group_cls_ids = [[] for _ in range(self.num_group)]
        # bisect_right buckets each per-category instance count by bin_split
        group_ids = [bisect.bisect_right(self.bin_split, count)
                     for count in self.cat_instance_count]
        # the background class is appended as the sole member of the last group
        self.group_ids = group_ids + [self.num_group - 1]
for cls_id, group_id in enumerate(self.group_ids):
self.group_cls_ids[group_id].append(cls_id)
logger = get_root_logger()
logger.info(f"group number {self.num_group}")
        self.n_cls_group = [len(cls_ids) for cls_ids in self.group_cls_ids]
logger.info(f"number of classes in group: {self.n_cls_group}")
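        # Worked example for the bisect above (illustrative): with
        # bin_split=(10, 100, 1000), bisect_right buckets instance counts as
        #   count 7    -> group 0 (count < 10)
        #   count 55   -> group 1 (10 <= count < 100)
        #   count 800  -> group 2 (100 <= count < 1000)
        #   count 5000 -> group 3 (count >= 1000)
        # and the final group (group 4) holds only the background class.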
def _get_group_pred(self, cls_score, apply_activation_func=False):
group_pred = []
start = 0
for group_id, n_cls in enumerate(self.n_cls_group):
num_logits = n_cls + 1 # + 1 for "others"
pred = cls_score.narrow(1, start, num_logits)
start = start + num_logits
if apply_activation_func:
pred = F.softmax(pred, dim=1)
group_pred.append(pred)
assert start == self.num_classes + 1 + self.num_group
return group_pred
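        # (On _get_group_pred above) Each group owns a contiguous block of
        # n_cls + 1 channels of cls_score: one "others" logit at channel 0,
        # then one logit per class. Summing n_cls + 1 over all groups gives
        # (num_classes + 1) + num_group, which the assert checks.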
def _prepare_for_label_remapping(self):
group_label_maps = []
for group_id, n_cls in enumerate(self.n_cls_group):
label_map = [0 for _ in range(self.num_classes + 1)]
group_label_maps.append(label_map)
# init value is 1 because 0 is set for "others"
_tmp_group_num = [1 for _ in range(self.num_group)]
for cls_id, group_id in enumerate(self.group_ids):
g_p = _tmp_group_num[group_id] # position in group
group_label_maps[group_id][cls_id] = g_p
_tmp_group_num[group_id] += 1
self.group_label_maps = torch.LongTensor(group_label_maps)
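        # (On the maps built above) group_label_maps is a
        # (num_group, num_classes + 1) lookup table. Illustrative example: if
        # classes 4 and 9 are the only members of group 2, row 2 maps label
        # 4 -> 1, label 9 -> 2, and every other label -> 0 ("others").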
    def _remap_labels(self, labels):
        # Each GT label turns into `num_group` per-group labels: its
        # within-group index in the group that owns it, and 0 ("others")
        # in every other group.
        new_labels = []
        new_weights = []  # used to subsample the dominant "others" entries
        new_avg = []
for group_id in range(self.num_group):
mapping = self.group_label_maps[group_id]
            # the label map lives on CPU, so index there and move the result back
            new_bin_label = mapping[labels.cpu()].to(labels.device)
if self.is_background_group(group_id):
weight = torch.ones_like(new_bin_label)
else:
weight = self._sample_others(new_bin_label)
new_labels.append(new_bin_label)
new_weights.append(weight)
avg_factor = max(torch.sum(weight).float().item(), 1.)
new_avg.append(avg_factor)
        return new_labels, new_weights, new_avg

    def _sample_others(self, label):
        # only used for the foreground groups; the background group keeps all entries
fg = torch.where(label > 0, torch.ones_like(label),
torch.zeros_like(label))
fg_idx = fg.nonzero(as_tuple=True)[0]
fg_num = fg_idx.shape[0]
if fg_num == 0:
return torch.zeros_like(label)
bg = 1 - fg
bg_idx = bg.nonzero(as_tuple=True)[0]
bg_num = bg_idx.shape[0]
        bg_sample_num = int(fg_num * self.beta)  # keep at most beta "others" per fg
if bg_sample_num >= bg_num:
weight = torch.ones_like(label)
else:
sample_idx = np.random.choice(bg_idx.cpu().numpy(),
(bg_sample_num, ), replace=False)
sample_idx = torch.from_numpy(sample_idx)
fg[sample_idx] = 1
weight = fg
return weight.to(label.device)
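        # (On the sampling above) Example: with beta=8 and 4 positives in a
        # group, at most 32 "others" entries keep weight 1 and the rest get 0,
        # so "others" cannot swamp that group's softmax loss.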
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
        # keep raw logits: F.cross_entropy applies log_softmax internally
        group_preds = self._get_group_pred(cls_score, apply_activation_func=False)
new_labels, new_weights, new_avg = self._remap_labels(label)
cls_loss = []
for group_id in range(self.num_group):
pred_in_group = group_preds[group_id]
label_in_group = new_labels[group_id]
weight_in_group = new_weights[group_id]
avg_in_group = new_avg[group_id]
loss_in_group = F.cross_entropy(pred_in_group,
label_in_group,
reduction='none',
ignore_index=-1)
            # normalize by the number of kept entries (positives + sampled "others")
            loss_in_group = torch.sum(loss_in_group * weight_in_group)
            loss_in_group /= avg_in_group
cls_loss.append(loss_in_group)
cls_loss = sum(cls_loss)
        return cls_loss * self.loss_weight

    def get_activation(self, cls_score):
n_i, n_c = cls_score.size()
group_activation = self._get_group_pred(cls_score, apply_activation_func=True)
        bg_score = group_activation[-1]  # background group softmax: [fg "others", bg]
        activation = cls_score.new_zeros((n_i, len(self.group_ids)))
        for group_id, cls_ids in enumerate(self.group_cls_ids[:-1]):
            # channel 0 of each group is "others"; class scores start at channel 1
            activation[:, cls_ids] = group_activation[group_id][:, 1:]
        # scale foreground scores by p(foreground) from the background group
        activation *= bg_score[:, [0]]
        activation[:, -1] = bg_score[:, 1]
return activation
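        # (On the composition above) For a class c in foreground group g:
        #   p(c) = softmax_g(c) * p(fg)
        # where p(fg) is channel 0 of the background group's softmax and
        # p(bg) is its channel 1.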
def get_channel_num(self, num_classes):
num_channel = num_classes + 1 + self.num_group
        return num_channel

    def is_background_group(self, group_id):
return group_id == self.num_group - 1
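# Usage sketch (illustrative; the head-side wiring is an assumption, not shown
# in this file): the loss is selected through the LOSSES registry, e.g.
#     loss_cls=dict(type='GroupSoftmax', beta=8, bin_split=(10, 100, 1000),
#                   version='v1')
# The classification branch must then emit get_channel_num(num_classes) logits,
# and inference should use get_activation(cls_score) instead of a plain softmax.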