""" This code is based on the following file: https://github.com/tztztztztz/eqlv2/blob/master/mmdet/models/losses/group_softmax.py """ import bisect import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from ..builder import LOSSES from mmdet.utils import get_root_logger def get_instance_count(version="v1"): if version == "v0_5": from mmdet.utils.lvis_v0_5_categories import get_instance_count return get_instance_count() elif version == "v1": from mmdet.utils.lvis_v1_0_categories import get_instance_count return get_instance_count() else: raise KeyError(f"version {version} is not supported") @LOSSES.register_module() class GroupSoftmax(nn.Module): """ This uses a different encoding from v1. v1: [cls1, cls2, ..., other1_for_group0, other_for_group_1, bg, bg_others] this: [group0_others, group0_cls0, ..., group1_others, group1_cls0, ...] """ def __init__(self, use_sigmoid=False, reduction='mean', class_weight=None, loss_weight=1.0, beta=8, bin_split=(10, 100, 1000), version="v1"): super(GroupSoftmax, self).__init__() self.use_sigmoid = False self.group = True self.reduction = reduction self.loss_weight = loss_weight self.class_weight = class_weight self.beta = beta self.bin_split = bin_split self.version = version self.cat_instance_count = get_instance_count(version=self.version) self.num_classes = len(self.cat_instance_count) assert not self.use_sigmoid self._assign_group() self._prepare_for_label_remapping() logger = get_root_logger() logger.info(f"set up {self.__class__.__name__}, version {self.version}, beta: {self.beta}, num classes {self.num_classes}") def _assign_group(self): self.num_group = (len(self.bin_split) + 1) + 1 # add a group for background self.group_cls_ids = [[] for _ in range(self.num_group)] group_ids = list(map(lambda x: bisect.bisect_right(self.bin_split, x), self.cat_instance_count)) self.group_ids = group_ids + [self.num_group - 1] for cls_id, group_id in enumerate(self.group_ids): self.group_cls_ids[group_id].append(cls_id) logger = get_root_logger() logger.info(f"group number {self.num_group}") self.n_cls_group = list(map(lambda x: len(x), self.group_cls_ids)) logger.info(f"number of classes in group: {self.n_cls_group}") def _get_group_pred(self, cls_score, apply_activation_func=False): group_pred = [] start = 0 for group_id, n_cls in enumerate(self.n_cls_group): num_logits = n_cls + 1 # + 1 for "others" pred = cls_score.narrow(1, start, num_logits) start = start + num_logits if apply_activation_func: pred = F.softmax(pred, dim=1) group_pred.append(pred) assert start == self.num_classes + 1 + self.num_group return group_pred def _prepare_for_label_remapping(self): group_label_maps = [] for group_id, n_cls in enumerate(self.n_cls_group): label_map = [0 for _ in range(self.num_classes + 1)] group_label_maps.append(label_map) # init value is 1 because 0 is set for "others" _tmp_group_num = [1 for _ in range(self.num_group)] for cls_id, group_id in enumerate(self.group_ids): g_p = _tmp_group_num[group_id] # position in group group_label_maps[group_id][cls_id] = g_p _tmp_group_num[group_id] += 1 self.group_label_maps = torch.LongTensor(group_label_maps) def _remap_labels(self, labels): new_labels = [] new_weights = [] # use this for sampling others new_avg = [] for group_id in range(self.num_group): mapping = self.group_label_maps[group_id] new_bin_label = mapping[labels] new_bin_label = torch.LongTensor(new_bin_label).to(labels.device) if self.is_background_group(group_id): weight = torch.ones_like(new_bin_label) else: weight = 

    def _remap_labels(self, labels):
        new_labels = []
        new_weights = []  # used for sampling "others"
        new_avg = []
        for group_id in range(self.num_group):
            # Keep the mapping on the same device as the labels so the
            # advanced indexing below does not mix devices.
            mapping = self.group_label_maps[group_id].to(labels.device)
            new_bin_label = mapping[labels]
            if self.is_background_group(group_id):
                weight = torch.ones_like(new_bin_label)
            else:
                weight = self._sample_others(new_bin_label)
            new_labels.append(new_bin_label)
            new_weights.append(weight)
            avg_factor = max(torch.sum(weight).float().item(), 1.)
            new_avg.append(avg_factor)
        return new_labels, new_weights, new_avg

    def _sample_others(self, label):
        # Only used for the non-background groups: keep every foreground
        # sample and at most beta * num_fg randomly chosen "others".
        fg = torch.where(label > 0, torch.ones_like(label),
                         torch.zeros_like(label))
        fg_idx = fg.nonzero(as_tuple=True)[0]
        fg_num = fg_idx.shape[0]
        if fg_num == 0:
            return torch.zeros_like(label)
        bg = 1 - fg
        bg_idx = bg.nonzero(as_tuple=True)[0]
        bg_num = bg_idx.shape[0]
        bg_sample_num = int(fg_num * self.beta)
        if bg_sample_num >= bg_num:
            weight = torch.ones_like(label)
        else:
            sample_idx = np.random.choice(bg_idx.cpu().numpy(),
                                          (bg_sample_num, ), replace=False)
            sample_idx = torch.from_numpy(sample_idx).to(label.device)
            fg[sample_idx] = 1
            weight = fg
        return weight.to(label.device)

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        group_preds = self._get_group_pred(cls_score, apply_activation_func=False)
        new_labels, new_weights, new_avg = self._remap_labels(label)
        cls_loss = []
        for group_id in range(self.num_group):
            pred_in_group = group_preds[group_id]
            label_in_group = new_labels[group_id]
            weight_in_group = new_weights[group_id]
            avg_in_group = new_avg[group_id]
            loss_in_group = F.cross_entropy(pred_in_group,
                                            label_in_group,
                                            reduction='none',
                                            ignore_index=-1)
            loss_in_group = torch.sum(loss_in_group * weight_in_group)
            loss_in_group /= avg_in_group
            cls_loss.append(loss_in_group)
        cls_loss = sum(cls_loss)
        return cls_loss * self.loss_weight

    def get_activation(self, cls_score):
        n_i, n_c = cls_score.size()
        group_activation = self._get_group_pred(cls_score,
                                                apply_activation_func=True)
        bg_score = group_activation[-1]
        activation = cls_score.new_zeros((n_i, len(self.group_ids)))
        for group_id, cls_ids in enumerate(self.group_cls_ids[:-1]):
            activation[:, cls_ids] = group_activation[group_id][:, 1:]
        # Scale foreground scores by the probability of not being background
        # (column 0 of the background group is its "others" channel).
        activation *= bg_score[:, [0]]
        activation[:, -1] = bg_score[:, 1]
        return activation

    def get_channel_num(self, num_classes):
        num_channel = num_classes + 1 + self.num_group
        return num_channel

    def is_background_group(self, group_id):
        return group_id == self.num_group - 1
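
# A minimal usage sketch (not part of the original file); it assumes mmdet
# and the LVIS category files are importable, and the batch size and label
# values are illustrative:
#
#   loss_fn = GroupSoftmax(beta=8, version="v1")
#   num_channels = loss_fn.get_channel_num(loss_fn.num_classes)
#   cls_score = torch.randn(32, num_channels)                  # 32 proposals
#   labels = torch.randint(0, loss_fn.num_classes + 1, (32,))  # bg label == num_classes
#   loss = loss_fn(cls_score, labels)
#   probs = loss_fn.get_activation(cls_score)                  # (32, num_classes + 1)
#
# Note that the classification branch feeding this loss must output
# num_classes + 1 + num_group logits rather than the usual num_classes + 1.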