|
import mmcv |
|
import torch |
|
|
|
from mmdet.core import bbox_overlaps |
|
|
|
|
|
@mmcv.jit(derivate=True, coderize=True)
def isr_p(cls_score,
          bbox_pred,
          bbox_targets,
          rois,
          sampling_results,
          loss_cls,
          bbox_coder,
          k=2,
          bias=0,
          num_class=80):
    """Importance-based Sample Reweighting (ISR_P), positive part.

    Positive samples are ranked by IoU Hierarchical Local Rank (IoU-HLR).
    They are first ranked by IoU among the samples assigned to the same GT
    box, then re-ranked among all positive samples of the same class. The
    rank is mapped linearly to (0, 1] and then non-linearly through
    ``bias`` and ``k``. Finally the weights are rescaled so that the
    weighted positive classification loss sums to the same value as with
    the original label weights.

    Args:
        cls_score (Tensor): Predicted classification scores.
        bbox_pred (Tensor): Predicted bbox deltas. If the last dim is
            larger than 4 it is viewed as (n, -1, 4) and the deltas of
            each sample's label are selected.
        bbox_targets (tuple[Tensor]): A tuple of bbox targets, they are
            labels, label_weights, bbox_targets, bbox_weights, respectively.
        rois (Tensor): Anchors (single_stage) in shape (n, 4) or RoIs
            (two_stage) in shape (n, 5).
        sampling_results (obj): Sampling results, one per image; each
            must expose ``pos_assigned_gt_inds``.
        loss_cls (func): Classification loss func of the head.
        bbox_coder (obj): BBox coder of the head.
        k (float): Power of the non-linear mapping.
        bias (float): Shift of the non-linear mapping.
        num_class (int): Number of classes, default: 80.

    Returns:
        tuple([Tensor]): labels, imp_based_label_weights, bbox_targets,
            bbox_target_weights. The ``label_weights`` tensor taken from
            ``bbox_targets`` is updated in place and returned.
    """
    labels, label_weights, bbox_targets, bbox_weights = bbox_targets
    pos_label_inds = ((labels >= 0) &
                      (labels < num_class)).nonzero().reshape(-1)
    pos_labels = labels[pos_label_inds]

    # Nothing to reweight without positive samples.
    num_pos = float(pos_label_inds.size(0))
    if num_pos == 0:
        return labels, label_weights, bbox_targets, bbox_weights

    # Merge the per-image pos_assigned_gt_inds into a single tensor. Each
    # image's indices are shifted by a running offset so GT indices from
    # different images never collide. The offset must accumulate: resetting
    # it to the previous image's unshifted max + 1 would let image i + 2
    # reuse the index range of image i.
    gts = []
    gt_offset = 0
    for sampling_result in sampling_results:
        gt_i = sampling_result.pos_assigned_gt_inds
        gts.append(gt_i + gt_offset)
        if gt_i.numel() != 0:
            gt_offset = gt_offset + gt_i.max() + 1
    gts = torch.cat(gts)
    assert len(gts) == num_pos

    # The weights are derived from predictions but must not carry gradients.
    cls_score = cls_score.detach()
    bbox_pred = bbox_pred.detach()

    # Anchors (single-stage) have 4 columns; RoIs (two-stage) have 5 and
    # the first column is dropped here.
    if rois.size(-1) == 5:
        pos_rois = rois[pos_label_inds][:, 1:]
    else:
        pos_rois = rois[pos_label_inds]

    # Class-specific regression: pick the deltas of each sample's label.
    if bbox_pred.size(-1) > 4:
        bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_delta_pred = bbox_pred[pos_label_inds, pos_labels].view(-1, 4)
    else:
        pos_delta_pred = bbox_pred[pos_label_inds].view(-1, 4)

    # IoU between each decoded prediction and its decoded regression target.
    pos_delta_target = bbox_targets[pos_label_inds].view(-1, 4)
    pos_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_pred)
    target_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_target)
    ious = bbox_overlaps(pos_bbox_pred, target_bbox_pred, is_aligned=True)

    # Advanced indexing returns a copy, so label_weights is not touched yet.
    pos_ori_weights = label_weights[pos_label_inds]
    pos_imp_weights = pos_ori_weights.clone()

    # IoU-HLR in two steps: samples are first ranked by IoU within the
    # group of the same GT, then ranked again within the same class.
    max_l_num = pos_labels.bincount().max()
    for label in pos_labels.unique():
        l_inds = (pos_labels == label).nonzero().view(-1)
        l_gts = gts[l_inds]
        for t in l_gts.unique():
            t_inds = l_inds[l_gts == t]
            t_ious = ious[t_inds]
            _, t_iou_rank_idx = t_ious.sort(descending=True)
            _, t_iou_rank = t_iou_rank_idx.sort()
            # Integer bonus decreasing with the local rank, so the
            # class-level sort below orders by local rank first and by
            # IoU second.
            ious[t_inds] += max_l_num - t_iou_rank.float()
        l_ious = ious[l_inds]
        _, l_iou_rank_idx = l_ious.sort(descending=True)
        _, l_iou_rank = l_iou_rank_idx.sort()  # IoU-HLR
        # Linearly map the rank into (0, 1]; rank 0 keeps its full weight.
        hlr_scale = (max_l_num - l_iou_rank.float()) / max_l_num
        pos_imp_weights[l_inds] *= hlr_scale

    # Non-linear mapping: shift by bias, then raise to the power k.
    pos_imp_weights = (bias + pos_imp_weights * (1 - bias)).pow(k)

    # Rescale the importance weights so the weighted positive classification
    # loss sums to the same value as with the original label weights.
    pos_loss_cls = loss_cls(
        cls_score[pos_label_inds], pos_labels, reduction_override='none')
    if pos_loss_cls.dim() > 1:
        ori_pos_loss_cls = pos_loss_cls * pos_ori_weights[:, None]
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights[:, None]
    else:
        ori_pos_loss_cls = pos_loss_cls * pos_ori_weights
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights
    pos_loss_cls_ratio = ori_pos_loss_cls.sum() / new_pos_loss_cls.sum()
    pos_imp_weights = pos_imp_weights * pos_loss_cls_ratio
    # In-place update of the caller's label_weights tensor.
    label_weights[pos_label_inds] = pos_imp_weights

    bbox_targets = labels, label_weights, bbox_targets, bbox_weights
    return bbox_targets
|
|
|
|
|
@mmcv.jit(derivate=True, coderize=True)
def carl_loss(cls_score,
              labels,
              bbox_pred,
              bbox_targets,
              loss_bbox,
              k=1,
              bias=0.2,
              avg_factor=None,
              sigmoid=False,
              num_class=80):
    """Classification-Aware Regression Loss (CARL).

    The regression loss of every positive sample is weighted by a
    non-linear function of its predicted score for its label. Because
    ``cls_score`` is not detached, the regression loss also back-propagates
    into the classification branch. The weights are normalized so that
    they sum to the number of positive samples.

    Args:
        cls_score (Tensor): Predicted classification scores.
        labels (Tensor): Targets of classification.
        bbox_pred (Tensor): Predicted bbox deltas. If the last dim is
            larger than 4 it is viewed as (n, -1, 4) and the deltas of
            each sample's label are selected.
        bbox_targets (Tensor): Target of bbox regression.
        loss_bbox (func): Regression loss func of the head. It is called
            with ``reduction_override='none'`` and may return either a
            per-coordinate loss (e.g. shape (n, 4)) or a per-sample loss
            (shape (n,)).
        k (float): Power of the non-linear mapping.
        bias (float): Shift of the non-linear mapping.
        avg_factor (int): Average factor used in regression loss.
            Defaults to ``bbox_targets.size(0)`` when None.
        sigmoid (bool): Activation of the classification score: sigmoid
            when True, softmax over the last dim otherwise.
        num_class (int): Number of classes, default: 80.

    Returns:
        dict: CARL loss dict with key ``loss_carl`` holding a 1-element
            tensor.
    """
    pos_label_inds = ((labels >= 0) &
                      (labels < num_class)).nonzero().reshape(-1)
    if pos_label_inds.numel() == 0:
        # Zero loss that stays connected to cls_score.
        return dict(loss_carl=cls_score.sum()[None] * 0.)
    pos_labels = labels[pos_label_inds]

    # Score of each positive sample for its label; gradients are kept.
    if sigmoid:
        pos_cls_score = cls_score.sigmoid()[pos_label_inds, pos_labels]
    else:
        pos_cls_score = cls_score.softmax(-1)[pos_label_inds, pos_labels]
    carl_loss_weights = (bias + (1 - bias) * pos_cls_score).pow(k)

    # Normalize the weights so that their sum equals the number of
    # positives. Out-of-place, so the tensor weight_ratio was computed from
    # is not mutated.
    num_pos = float(pos_cls_score.size(0))
    weight_ratio = num_pos / carl_loss_weights.sum()
    carl_loss_weights = carl_loss_weights * weight_ratio

    if avg_factor is None:
        avg_factor = bbox_targets.size(0)
    # Class-specific regression: select the deltas of each sample's label.
    if bbox_pred.size(-1) > 4:
        bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_bbox_preds = bbox_pred[pos_label_inds, pos_labels]
    else:
        pos_bbox_preds = bbox_pred[pos_label_inds]
    ori_loss_reg = loss_bbox(
        pos_bbox_preds,
        bbox_targets[pos_label_inds],
        reduction_override='none') / avg_factor
    # Broadcast one weight per sample over the trailing loss dims. For a
    # 2-D loss this is (n, 1), as before. For a 1-D per-sample loss it
    # stays (n,), instead of broadcasting (n,) * (n, 1) into (n, n).
    weight_shape = (-1, ) + (1, ) * (ori_loss_reg.dim() - 1)
    loss_carl = (ori_loss_reg * carl_loss_weights.view(weight_shape)).sum()
    return dict(loss_carl=loss_carl[None])
|
|