Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmdet.structures.bbox import BaseBoxes | |
from mmdet.utils import ConfigType | |
from torch import Tensor | |
from mmyolo.registry import TASK_UTILS | |
# Cost assigned to (prior, gt) pairs that must never be matched, i.e. priors
# lying outside every real gt box in BatchDynamicSoftLabelAssigner.
INF = 10**8
# Added inside ``log(iou + EPS)`` so that a zero IoU yields a finite cost.
EPS = 1e-7
def find_inside_points(boxes: Tensor,
                       points: Tensor,
                       box_dim: int = 4,
                       eps: float = 0.01) -> Tensor:
    """Find inside box points in batches. The boxes tensor must be 3-D.

    Args:
        boxes (Tensor): Boxes tensor. Must be batch input.
            Has shape of (batch_size, n_boxes, box_dim). Horizontal boxes
            are laid out as (x1, y1, x2, y2); rotated boxes as
            (cx, cy, w, h, angle) with the angle in radians.
        points (Tensor): Points coordinates. Has shape of (n_points, 2).
        box_dim (int): The dimension of box. 4 means horizontal box and
            5 means rotated box. Defaults to 4.
        eps (float): Margin that keeps accepted points strictly inside a
            box rather than on its boundary. Only used for rotated boxes.
            Defaults to 0.01.

    Returns:
        Tensor: A BoolTensor of shape (n_points, batch_size, n_boxes) that
        is True where a point lies inside a box. Horizontal boxes use a
        strict inequality, so a point exactly on an edge counts as outside.

    Raises:
        NotImplementedError: If ``box_dim`` is neither 4 nor 5.
    """
    if box_dim == 4:
        # Horizontal Boxes: signed distances from each point to the
        # left/top and right/bottom edges, broadcast to
        # (n_points, batch_size, n_boxes, 4).
        lt_ = points[:, None, None] - boxes[..., :2]
        rb_ = boxes[..., 2:] - points[:, None, None]
        deltas = torch.cat([lt_, rb_], dim=-1)
        # Inside iff every edge distance is positive.
        is_in_gts = deltas.min(dim=-1).values > 0
    elif box_dim == 5:
        # Rotated Boxes
        points = points[:, None, None]
        ctrs, wh, t = torch.split(boxes, [2, 2, 1], dim=-1)
        cos_value, sin_value = torch.cos(t), torch.sin(t)
        # [[cos, sin], [-sin, cos]] rotates by -t, expressing the
        # point-to-center offset in the box's own axis-aligned frame.
        matrix = torch.cat([cos_value, sin_value, -sin_value, cos_value],
                           dim=-1).reshape(*boxes.shape[:-1], 2, 2)
        offset = points - ctrs
        offset = torch.matmul(matrix, offset[..., None])
        offset = offset.squeeze(-1)
        offset_x, offset_y = offset[..., 0], offset[..., 1]
        w, h = wh[..., 0], wh[..., 1]
        is_in_gts = (offset_x <= w / 2 - eps) & (offset_x >= - w / 2 + eps) & \
                    (offset_y <= h / 2 - eps) & (offset_y >= - h / 2 + eps)
    else:
        raise NotImplementedError(f'Unsupported box_dim:{box_dim}')
    return is_in_gts
def get_box_center(boxes: Tensor, box_dim: int = 4) -> Tensor:
    """Compute the center point of every box.

    Args:
        boxes (Tensor): Boxes with shape (b, n, box_dim). Horizontal boxes
            are (x1, y1, x2, y2); rotated boxes are (x, y, w, h, a).
        box_dim (int): 4 for horizontal boxes, 5 for rotated boxes.
            Defaults to 4.

    Returns:
        Tensor: Box centers with shape (b, n, 2).

    Raises:
        NotImplementedError: If ``box_dim`` is neither 4 nor 5.
    """
    if box_dim == 5:
        # Rotated boxes already store their center in the first two slots.
        return boxes[..., :2]
    if box_dim == 4:
        # Horizontal boxes: midpoint between the two opposite corners.
        top_left = boxes[..., :2]
        bottom_right = boxes[..., 2:]
        return (top_left + bottom_right) / 2.0
    raise NotImplementedError(f'Unsupported box_dim:{box_dim}')
class BatchDynamicSoftLabelAssigner(nn.Module):
    """Computes matching between predictions and ground truth with dynamic soft
    label assignment.

    All images of a batch are processed at once: ground truths are stacked
    into tensors with a common ``num_gt`` dimension, and ``pad_bbox_flag``
    marks which of those slots hold real ground truths.

    Args:
        num_classes (int): number of class. This value is also used as the
            label of priors that are not matched to any ground truth.
        soft_center_radius (float): Radius of the soft center prior.
            Defaults to 3.0.
        topk (int): Select top-k predictions to calculate dynamic k
            best matches for each gt. Defaults to 13.
        iou_weight (float): The scale factor of iou cost. Defaults to 3.0.
        iou_calculator (ConfigType): Config of overlaps Calculator.
            Defaults to dict(type='mmdet.BboxOverlaps2D').
        batch_iou (bool): Use batch input when calculate IoU.
            If set to False use loop instead. Defaults to True.
    """

    def __init__(
        self,
        num_classes,
        soft_center_radius: float = 3.0,
        topk: int = 13,
        iou_weight: float = 3.0,
        iou_calculator: ConfigType = dict(type='mmdet.BboxOverlaps2D'),
        batch_iou: bool = True,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.soft_center_radius = soft_center_radius
        self.topk = topk
        self.iou_weight = iou_weight
        self.iou_calculator = TASK_UTILS.build(iou_calculator)
        self.batch_iou = batch_iou

    def forward(self, pred_bboxes: Tensor, pred_scores: Tensor, priors: Tensor,
                gt_labels: Tensor, gt_bboxes: Tensor,
                pad_bbox_flag: Tensor) -> dict:
        """Assign ground truths to the priors of a whole batch.

        Args:
            pred_bboxes (Tensor): Predicted (decoded) boxes with shape
                (B, num_priors, box_dim); box_dim is 4 (horizontal) or
                5 (rotated).
            pred_scores (Tensor): Classification logits with shape
                (B, num_priors, num_classes).
            priors (Tensor): Priors with at least 3 columns. Columns 0-1 are
                the prior point coordinates; column 2 is used as the stride
                that normalises center distances.
            gt_labels (Tensor): Ground-truth labels with shape
                (B, num_gt, 1).
            gt_bboxes (Tensor): Ground-truth boxes with shape
                (B, num_gt, box_dim).
            pad_bbox_flag (Tensor): Shape (B, num_gt, 1); positive for real
                ground truths and zero for padding slots.

        Returns:
            dict: With keys
                - ``assigned_labels`` (B, num_priors), long: matched gt
                  label, or ``num_classes`` for unmatched priors.
                - ``assigned_labels_weights`` (B, num_priors): all ones.
                - ``assigned_bboxes`` (B, num_priors, box_dim): matched gt
                  box, zeros for unmatched priors.
                - ``assign_metrics`` (B, num_priors): IoU between the
                  prediction and its matched gt, zero for unmatched priors.
        """
        num_gt = gt_bboxes.size(1)
        decoded_bboxes = pred_bboxes
        batch_size, num_bboxes, box_dim = decoded_bboxes.size()

        if num_gt == 0 or num_bboxes == 0:
            return {
                'assigned_labels':
                gt_labels.new_full(
                    pred_scores[..., 0].shape,
                    self.num_classes,
                    dtype=torch.long),
                'assigned_labels_weights':
                gt_bboxes.new_full(pred_scores[..., 0].shape, 1),
                'assigned_bboxes':
                gt_bboxes.new_full(pred_bboxes.shape, 0),
                'assign_metrics':
                gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
            }

        prior_center = priors[:, :2]
        if isinstance(gt_bboxes, BaseBoxes):
            raise NotImplementedError(
                f'type of {type(gt_bboxes)} are not implemented !')
        else:
            is_in_gts = find_inside_points(gt_bboxes, prior_center, box_dim)

        # (N_points, B, N_boxes); padding slots never contain a prior
        is_in_gts = is_in_gts * pad_bbox_flag[..., 0][None]
        # (N_points, B, N_boxes) -> (B, N_points, N_boxes)
        is_in_gts = is_in_gts.permute(1, 0, 2)
        # (B, N_points): priors lying inside at least one real gt box
        valid_mask = is_in_gts.sum(dim=-1) > 0

        gt_center = get_box_center(gt_bboxes, box_dim)

        # Prior-to-gt center distance in units of the prior's stride:
        # (B, N_points, N_boxes).
        strides = priors[..., 2]
        distance = (priors[None].unsqueeze(2)[..., :2] -
                    gt_center[:, None, :, :]
                    ).pow(2).sum(-1).sqrt() / strides[None, :, None]

        # prevent overflow: priors outside every gt get distance 0 here;
        # their cost is replaced by INF below anyway.
        distance = distance * valid_mask.unsqueeze(-1)
        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)

        if self.batch_iou:
            pairwise_ious = self.iou_calculator(decoded_bboxes, gt_bboxes)
        else:
            pairwise_ious = torch.stack(
                [
                    self.iou_calculator(box, gt)
                    for box, gt in zip(decoded_bboxes, gt_bboxes)
                ],
                dim=0)

        iou_cost = -torch.log(pairwise_ious + EPS) * self.iou_weight

        # Select the predicted scores corresponding to the gt_labels:
        # pairwise_pred_scores[b, n, g] = pred_scores[b, n, gt_labels[b, g]].
        # gather keeps the index on pred_scores' device, so no CPU index
        # tensor or host/device copy is needed.
        gt_label_inds = gt_labels.long().squeeze(-1)
        pairwise_pred_scores = pred_scores.gather(
            2, gt_label_inds[:, None, :].expand(-1, num_bboxes, -1))

        # classification cost: BCE against the IoU used as a soft target,
        # re-weighted by the squared gap between IoU and predicted score
        scale_factor = pairwise_ious - pairwise_pred_scores.sigmoid()
        pairwise_cls_cost = F.binary_cross_entropy_with_logits(
            pairwise_pred_scores, pairwise_ious,
            reduction='none') * scale_factor.abs().pow(2.0)

        cost_matrix = pairwise_cls_cost + iou_cost + soft_center_prior

        # priors outside every real gt box get the maximal cost
        max_pad_value = torch.ones_like(cost_matrix) * INF
        cost_matrix = torch.where(valid_mask[..., None], cost_matrix,
                                  max_pad_value)

        (matched_pred_ious, matched_gt_inds,
         fg_mask_inboxes) = self.dynamic_k_matching(cost_matrix, pairwise_ious,
                                                    pad_bbox_flag)

        del pairwise_ious, cost_matrix

        # batch image index of every positive prior, aligned with
        # matched_gt_inds / matched_pred_ious
        batch_index = fg_mask_inboxes.nonzero(as_tuple=True)[0]

        assigned_labels = gt_labels.new_full(pred_scores[..., 0].shape,
                                             self.num_classes)
        assigned_labels[fg_mask_inboxes] = gt_labels[
            batch_index, matched_gt_inds].squeeze(-1)
        assigned_labels = assigned_labels.long()

        assigned_labels_weights = gt_bboxes.new_full(pred_scores[..., 0].shape,
                                                     1)

        assigned_bboxes = gt_bboxes.new_full(pred_bboxes.shape, 0)
        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[batch_index,
                                                     matched_gt_inds]

        assign_metrics = gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
        assign_metrics[fg_mask_inboxes] = matched_pred_ious

        return dict(
            assigned_labels=assigned_labels,
            assigned_labels_weights=assigned_labels_weights,
            assigned_bboxes=assigned_bboxes,
            assign_metrics=assign_metrics)

    def dynamic_k_matching(
            self, cost_matrix: Tensor, pairwise_ious: Tensor,
            pad_bbox_flag: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Use IoU and matching cost to calculate the dynamic top-k positive
        targets.

        Args:
            cost_matrix (Tensor): Cost matrix with shape
                (B, num_priors, num_gt).
            pairwise_ious (Tensor): Pairwise iou matrix, same shape as
                ``cost_matrix``.
            pad_bbox_flag (Tensor): Shape (B, num_gt, 1); positive for real
                ground truths and zero for padding slots.

        Returns:
            tuple: ``(matched_pred_ious, matched_gt_inds, fg_mask_inboxes)``
            where ``fg_mask_inboxes`` is a (B, num_priors) bool mask of
            positive priors, and the other two hold, for every positive
            prior (in mask order), the IoU with its gt and the gt's index
            within its image.
        """
        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.topk, pairwise_ious.size(1))
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
        # calculate dynamic k for each gt
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
        del topk_ious

        # (B, num_gt) mask of real (non-padded) gts
        gt_valid = pad_bbox_flag[..., 0] > 0

        # sorting the batch cost matrix is faster than topk
        _, sorted_indices = torch.sort(cost_matrix, dim=1)
        # Move the small per-gt bookkeeping to the host once instead of
        # synchronising with the device for every gt inside the loop, and
        # only visit real gts (no assumption that they form a prefix).
        for b, (ks, valid) in enumerate(
                zip(dynamic_ks.tolist(), gt_valid.tolist())):
            for gt_idx, (k, is_valid) in enumerate(zip(ks, valid)):
                if not is_valid:
                    continue
                topk_ids = sorted_indices[b, :k, gt_idx]
                matching_matrix[b, topk_ids, gt_idx] = 1
        del dynamic_ks

        # A prior matched to several gts keeps only its cheapest gt.
        prior_match_gt_mask = matching_matrix.sum(2) > 1
        if prior_match_gt_mask.any():
            # Padding slots are not real targets but still carry a finite
            # cost for priors inside real gts; exclude them so a prior can
            # never be re-assigned to a padded (fake) gt.
            resolve_cost = cost_matrix.masked_fill(~gt_valid[:, None, :],
                                                   float('inf'))
            _, cost_argmin = torch.min(
                resolve_cost[prior_match_gt_mask, :], dim=1)
            matching_matrix[prior_match_gt_mask, :] = 0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1

        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(2) > 0
        matched_pred_ious = (matching_matrix *
                             pairwise_ious).sum(2)[fg_mask_inboxes]
        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes