KyanChen's picture
Upload 89 files
3094730
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.structures.bbox import BaseBoxes
from mmdet.utils import ConfigType
from torch import Tensor
from mmyolo.registry import TASK_UTILS
INF = 100000000
EPS = 1.0e-7
def find_inside_points(boxes: Tensor,
points: Tensor,
box_dim: int = 4,
eps: float = 0.01) -> Tensor:
"""Find inside box points in batches. Boxes dimension must be 3.
Args:
boxes (Tensor): Boxes tensor. Must be batch input.
Has shape of (batch_size, n_boxes, box_dim).
points (Tensor): Points coordinates. Has shape of (n_points, 2).
box_dim (int): The dimension of box. 4 means horizontal box and
5 means rotated box. Defaults to 4.
eps (float): Make sure the points are inside not on the boundary.
Only use in rotated boxes. Defaults to 0.01.
Returns:
Tensor: A BoolTensor indicating whether a point is inside
boxes. The index has shape of (n_points, batch_size, n_boxes).
"""
if box_dim == 4:
# Horizontal Boxes
lt_ = points[:, None, None] - boxes[..., :2]
rb_ = boxes[..., 2:] - points[:, None, None]
deltas = torch.cat([lt_, rb_], dim=-1)
is_in_gts = deltas.min(dim=-1).values > 0
elif box_dim == 5:
# Rotated Boxes
points = points[:, None, None]
ctrs, wh, t = torch.split(boxes, [2, 2, 1], dim=-1)
cos_value, sin_value = torch.cos(t), torch.sin(t)
matrix = torch.cat([cos_value, sin_value, -sin_value, cos_value],
dim=-1).reshape(*boxes.shape[:-1], 2, 2)
offset = points - ctrs
offset = torch.matmul(matrix, offset[..., None])
offset = offset.squeeze(-1)
offset_x, offset_y = offset[..., 0], offset[..., 1]
w, h = wh[..., 0], wh[..., 1]
is_in_gts = (offset_x <= w / 2 - eps) & (offset_x >= - w / 2 + eps) & \
(offset_y <= h / 2 - eps) & (offset_y >= - h / 2 + eps)
else:
raise NotImplementedError(f'Unsupport box_dim:{box_dim}')
return is_in_gts
def get_box_center(boxes: Tensor, box_dim: int = 4) -> Tensor:
"""Return a tensor representing the centers of boxes.
Args:
boxes (Tensor): Boxes tensor. Has shape of (b, n, box_dim)
box_dim (int): The dimension of box. 4 means horizontal box and
5 means rotated box. Defaults to 4.
Returns:
Tensor: Centers have shape of (b, n, 2)
"""
if box_dim == 4:
# Horizontal Boxes, (x1, y1, x2, y2)
return (boxes[..., :2] + boxes[..., 2:]) / 2.0
elif box_dim == 5:
# Rotated Boxes, (x, y, w, h, a)
return boxes[..., :2]
else:
raise NotImplementedError(f'Unsupported box_dim:{box_dim}')
@TASK_UTILS.register_module()
class BatchDynamicSoftLabelAssigner(nn.Module):
"""Computes matching between predictions and ground truth with dynamic soft
label assignment.
Args:
num_classes (int): number of class
soft_center_radius (float): Radius of the soft center prior.
Defaults to 3.0.
topk (int): Select top-k predictions to calculate dynamic k
best matches for each gt. Defaults to 13.
iou_weight (float): The scale factor of iou cost. Defaults to 3.0.
iou_calculator (ConfigType): Config of overlaps Calculator.
Defaults to dict(type='BboxOverlaps2D').
batch_iou (bool): Use batch input when calculate IoU.
If set to False use loop instead. Defaults to True.
"""
def __init__(
self,
num_classes,
soft_center_radius: float = 3.0,
topk: int = 13,
iou_weight: float = 3.0,
iou_calculator: ConfigType = dict(type='mmdet.BboxOverlaps2D'),
batch_iou: bool = True,
) -> None:
super().__init__()
self.num_classes = num_classes
self.soft_center_radius = soft_center_radius
self.topk = topk
self.iou_weight = iou_weight
self.iou_calculator = TASK_UTILS.build(iou_calculator)
self.batch_iou = batch_iou
@torch.no_grad()
def forward(self, pred_bboxes: Tensor, pred_scores: Tensor, priors: Tensor,
gt_labels: Tensor, gt_bboxes: Tensor,
pad_bbox_flag: Tensor) -> dict:
num_gt = gt_bboxes.size(1)
decoded_bboxes = pred_bboxes
batch_size, num_bboxes, box_dim = decoded_bboxes.size()
if num_gt == 0 or num_bboxes == 0:
return {
'assigned_labels':
gt_labels.new_full(
pred_scores[..., 0].shape,
self.num_classes,
dtype=torch.long),
'assigned_labels_weights':
gt_bboxes.new_full(pred_scores[..., 0].shape, 1),
'assigned_bboxes':
gt_bboxes.new_full(pred_bboxes.shape, 0),
'assign_metrics':
gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
}
prior_center = priors[:, :2]
if isinstance(gt_bboxes, BaseBoxes):
raise NotImplementedError(
f'type of {type(gt_bboxes)} are not implemented !')
else:
is_in_gts = find_inside_points(gt_bboxes, prior_center, box_dim)
# (N_points, B, N_boxes)
is_in_gts = is_in_gts * pad_bbox_flag[..., 0][None]
# (N_points, B, N_boxes) -> (B, N_points, N_boxes)
is_in_gts = is_in_gts.permute(1, 0, 2)
# (B, N_points)
valid_mask = is_in_gts.sum(dim=-1) > 0
gt_center = get_box_center(gt_bboxes, box_dim)
strides = priors[..., 2]
distance = (priors[None].unsqueeze(2)[..., :2] -
gt_center[:, None, :, :]
).pow(2).sum(-1).sqrt() / strides[None, :, None]
# prevent overflow
distance = distance * valid_mask.unsqueeze(-1)
soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
if self.batch_iou:
pairwise_ious = self.iou_calculator(decoded_bboxes, gt_bboxes)
else:
ious = []
for box, gt in zip(decoded_bboxes, gt_bboxes):
iou = self.iou_calculator(box, gt)
ious.append(iou)
pairwise_ious = torch.stack(ious, dim=0)
iou_cost = -torch.log(pairwise_ious + EPS) * self.iou_weight
# select the predicted scores corresponded to the gt_labels
pairwise_pred_scores = pred_scores.permute(0, 2, 1)
idx = torch.zeros([2, batch_size, num_gt], dtype=torch.long)
idx[0] = torch.arange(end=batch_size).view(-1, 1).repeat(1, num_gt)
idx[1] = gt_labels.long().squeeze(-1)
pairwise_pred_scores = pairwise_pred_scores[idx[0],
idx[1]].permute(0, 2, 1)
# classification cost
scale_factor = pairwise_ious - pairwise_pred_scores.sigmoid()
pairwise_cls_cost = F.binary_cross_entropy_with_logits(
pairwise_pred_scores, pairwise_ious,
reduction='none') * scale_factor.abs().pow(2.0)
cost_matrix = pairwise_cls_cost + iou_cost + soft_center_prior
max_pad_value = torch.ones_like(cost_matrix) * INF
cost_matrix = torch.where(valid_mask[..., None].repeat(1, 1, num_gt),
cost_matrix, max_pad_value)
(matched_pred_ious, matched_gt_inds,
fg_mask_inboxes) = self.dynamic_k_matching(cost_matrix, pairwise_ious,
pad_bbox_flag)
del pairwise_ious, cost_matrix
batch_index = (fg_mask_inboxes > 0).nonzero(as_tuple=True)[0]
assigned_labels = gt_labels.new_full(pred_scores[..., 0].shape,
self.num_classes)
assigned_labels[fg_mask_inboxes] = gt_labels[
batch_index, matched_gt_inds].squeeze(-1)
assigned_labels = assigned_labels.long()
assigned_labels_weights = gt_bboxes.new_full(pred_scores[..., 0].shape,
1)
assigned_bboxes = gt_bboxes.new_full(pred_bboxes.shape, 0)
assigned_bboxes[fg_mask_inboxes] = gt_bboxes[batch_index,
matched_gt_inds]
assign_metrics = gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
assign_metrics[fg_mask_inboxes] = matched_pred_ious
return dict(
assigned_labels=assigned_labels,
assigned_labels_weights=assigned_labels_weights,
assigned_bboxes=assigned_bboxes,
assign_metrics=assign_metrics)
def dynamic_k_matching(
self, cost_matrix: Tensor, pairwise_ious: Tensor,
pad_bbox_flag: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Use IoU and matching cost to calculate the dynamic top-k positive
targets.
Args:
cost_matrix (Tensor): Cost matrix.
pairwise_ious (Tensor): Pairwise iou matrix.
num_gt (int): Number of gt.
valid_mask (Tensor): Mask for valid bboxes.
Returns:
tuple: matched ious and gt indexes.
"""
matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
# select candidate topk ious for dynamic-k calculation
candidate_topk = min(self.topk, pairwise_ious.size(1))
topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
# calculate dynamic k for each gt
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
num_gts = pad_bbox_flag.sum((1, 2)).int()
# sorting the batch cost matirx is faster than topk
_, sorted_indices = torch.sort(cost_matrix, dim=1)
for b in range(pad_bbox_flag.shape[0]):
for gt_idx in range(num_gts[b]):
topk_ids = sorted_indices[b, :dynamic_ks[b, gt_idx], gt_idx]
matching_matrix[b, :, gt_idx][topk_ids] = 1
del topk_ious, dynamic_ks
prior_match_gt_mask = matching_matrix.sum(2) > 1
if prior_match_gt_mask.sum() > 0:
cost_min, cost_argmin = torch.min(
cost_matrix[prior_match_gt_mask, :], dim=1)
matching_matrix[prior_match_gt_mask, :] *= 0
matching_matrix[prior_match_gt_mask, cost_argmin] = 1
# get foreground mask inside box and center prior
fg_mask_inboxes = matching_matrix.sum(2) > 0
matched_pred_ious = (matching_matrix *
pairwise_ious).sum(2)[fg_mask_inboxes]
matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
return matched_pred_ious, matched_gt_inds, fg_mask_inboxes