Spaces:
Runtime error
Runtime error
File size: 10,901 Bytes
3094730 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.structures.bbox import BaseBoxes
from mmdet.utils import ConfigType
from torch import Tensor
from mmyolo.registry import TASK_UTILS
# Large finite cost used to push invalid (prior, gt) pairs to the end of
# every ranking without producing inf arithmetic.
INF = 10**8
# Small constant added before taking log(IoU) to avoid log(0).
EPS = 1e-7
def find_inside_points(boxes: Tensor,
                       points: Tensor,
                       box_dim: int = 4,
                       eps: float = 0.01) -> Tensor:
    """Find inside box points in batches. Boxes dimension must be 3.

    Args:
        boxes (Tensor): Boxes tensor. Must be batch input.
            Has shape of (batch_size, n_boxes, box_dim). Horizontal boxes
            are laid out as (x1, y1, x2, y2); rotated boxes as
            (cx, cy, w, h, angle) with the angle in radians.
        points (Tensor): Points coordinates. Has shape of (n_points, 2).
        box_dim (int): The dimension of box. 4 means horizontal box and
            5 means rotated box. Defaults to 4.
        eps (float): Make sure the points are inside not on the boundary.
            Only use in rotated boxes. Defaults to 0.01.

    Returns:
        Tensor: A BoolTensor indicating whether a point is inside
        boxes. The index has shape of (n_points, batch_size, n_boxes).

    Raises:
        NotImplementedError: If ``box_dim`` is neither 4 nor 5.
    """
    if box_dim == 4:
        # Horizontal boxes: a point is inside when its distances to all
        # four edges are strictly positive (boundary points are excluded).
        lt_ = points[:, None, None] - boxes[..., :2]
        rb_ = boxes[..., 2:] - points[:, None, None]
        deltas = torch.cat([lt_, rb_], dim=-1)
        is_in_gts = deltas.min(dim=-1).values > 0
    elif box_dim == 5:
        # Rotated boxes: rotate each point's offset from the box center by
        # -angle into the box's local frame, then compare it against the
        # half extents shrunk by ``eps``.
        points = points[:, None, None]
        ctrs, wh, t = torch.split(boxes, [2, 2, 1], dim=-1)
        cos_value, sin_value = torch.cos(t), torch.sin(t)
        matrix = torch.cat([cos_value, sin_value, -sin_value, cos_value],
                           dim=-1).reshape(*boxes.shape[:-1], 2, 2)
        offset = points - ctrs
        offset = torch.matmul(matrix, offset[..., None])
        offset = offset.squeeze(-1)
        offset_x, offset_y = offset[..., 0], offset[..., 1]
        w, h = wh[..., 0], wh[..., 1]
        is_in_gts = (offset_x <= w / 2 - eps) & (offset_x >= - w / 2 + eps) & \
            (offset_y <= h / 2 - eps) & (offset_y >= - h / 2 + eps)
    else:
        raise NotImplementedError(f'Unsupported box_dim:{box_dim}')
    return is_in_gts
def get_box_center(boxes: Tensor, box_dim: int = 4) -> Tensor:
    """Compute the (x, y) center of every box.

    Args:
        boxes (Tensor): Boxes of shape (b, n, box_dim).
        box_dim (int): 4 for horizontal (x1, y1, x2, y2) boxes, 5 for
            rotated (x, y, w, h, a) boxes. Defaults to 4.

    Returns:
        Tensor: Box centers of shape (b, n, 2).

    Raises:
        NotImplementedError: If ``box_dim`` is neither 4 nor 5.
    """
    if box_dim == 5:
        # Rotated layout stores the center directly in the first two slots.
        return boxes[..., :2]
    if box_dim != 4:
        raise NotImplementedError(f'Unsupported box_dim:{box_dim}')
    # Horizontal layout: midpoint of the top-left and bottom-right corners.
    top_left = boxes[..., :2]
    bottom_right = boxes[..., 2:]
    return (top_left + bottom_right) / 2.0
@TASK_UTILS.register_module()
class BatchDynamicSoftLabelAssigner(nn.Module):
    """Computes matching between predictions and ground truth with dynamic soft
    label assignment, for a whole padded batch at once.

    Args:
        num_classes (int): Number of classes. Unmatched (background) priors
            receive this value as their label.
        soft_center_radius (float): Radius of the soft center prior.
            Defaults to 3.0.
        topk (int): Select top-k predictions to calculate dynamic k
            best matches for each gt. Defaults to 13.
        iou_weight (float): The scale factor of iou cost. Defaults to 3.0.
        iou_calculator (ConfigType): Config of overlaps Calculator.
            Defaults to dict(type='mmdet.BboxOverlaps2D').
        batch_iou (bool): Use batch input when calculate IoU.
            If set to False use loop instead. Defaults to True.
    """

    def __init__(
        self,
        num_classes: int,
        soft_center_radius: float = 3.0,
        topk: int = 13,
        iou_weight: float = 3.0,
        iou_calculator: ConfigType = dict(type='mmdet.BboxOverlaps2D'),
        batch_iou: bool = True,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.soft_center_radius = soft_center_radius
        self.topk = topk
        self.iou_weight = iou_weight
        self.iou_calculator = TASK_UTILS.build(iou_calculator)
        self.batch_iou = batch_iou

    @torch.no_grad()
    def forward(self, pred_bboxes: Tensor, pred_scores: Tensor, priors: Tensor,
                gt_labels: Tensor, gt_bboxes: Tensor,
                pad_bbox_flag: Tensor) -> dict:
        """Assign every prior to one gt or to background.

        Args:
            pred_bboxes (Tensor): Decoded predicted boxes, (B, N, box_dim).
            pred_scores (Tensor): Per-class scores indexed as
                (B, N, num_classes). They are treated as logits (``sigmoid``
                and ``binary_cross_entropy_with_logits`` are applied).
            priors (Tensor): Priors of shape (N, >=3). Columns 0-1 are the
                prior centers and column 2 is the stride.
            gt_labels (Tensor): Gt labels, expected (B, G, 1).
            gt_bboxes (Tensor): Padded gt boxes, (B, G, box_dim).
            pad_bbox_flag (Tensor): (B, G, 1). Non-zero for real gts and
                zero for padding.

        Returns:
            dict: Contains the following keys.

            - ``assigned_labels``: (B, N) long. Background priors get
              ``num_classes``.
            - ``assigned_labels_weights``: (B, N), all ones.
            - ``assigned_bboxes``: (B, N, box_dim). Zeros for background.
            - ``assign_metrics``: (B, N). IoU of the matched pair, zero for
              background.
        """
        num_gt = gt_bboxes.size(1)
        decoded_bboxes = pred_bboxes
        batch_size, num_bboxes, box_dim = decoded_bboxes.size()

        if num_gt == 0 or num_bboxes == 0:
            # Nothing to match: every prior is background.
            return {
                'assigned_labels':
                gt_labels.new_full(
                    pred_scores[..., 0].shape,
                    self.num_classes,
                    dtype=torch.long),
                'assigned_labels_weights':
                gt_bboxes.new_full(pred_scores[..., 0].shape, 1),
                'assigned_bboxes':
                gt_bboxes.new_full(pred_bboxes.shape, 0),
                'assign_metrics':
                gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
            }

        prior_center = priors[:, :2]
        if isinstance(gt_bboxes, BaseBoxes):
            raise NotImplementedError(
                f'type of {type(gt_bboxes)} are not implemented !')
        # (N_points, B, N_boxes); padded gts are masked so they contain
        # no prior.
        is_in_gts = find_inside_points(gt_bboxes, prior_center, box_dim)
        is_in_gts = is_in_gts * pad_bbox_flag[..., 0][None]
        # (N_points, B, N_boxes) -> (B, N_points, N_boxes)
        is_in_gts = is_in_gts.permute(1, 0, 2)
        # (B, N_points): prior lies inside at least one real gt
        valid_mask = is_in_gts.sum(dim=-1) > 0

        # Soft center prior: distance to each gt center, in units of stride.
        gt_center = get_box_center(gt_bboxes, box_dim)
        strides = priors[..., 2]
        distance = (priors[None].unsqueeze(2)[..., :2] -
                    gt_center[:, None, :, :]
                    ).pow(2).sum(-1).sqrt() / strides[None, :, None]
        # Prevent overflow: invalid priors' costs are replaced by INF below,
        # so zero their distance instead of evaluating 10 ** (large).
        distance = distance * valid_mask.unsqueeze(-1)
        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)

        if self.batch_iou:
            pairwise_ious = self.iou_calculator(decoded_bboxes, gt_bboxes)
        else:
            pairwise_ious = torch.stack([
                self.iou_calculator(box, gt)
                for box, gt in zip(decoded_bboxes, gt_bboxes)
            ], dim=0)
        iou_cost = -torch.log(pairwise_ious + EPS) * self.iou_weight

        # Select the predicted score of each gt's class -> (B, N, G).
        # The index tensors live on the score device, so indexing needs no
        # host/device copies.
        device = pred_scores.device
        batch_idx = torch.arange(
            batch_size, device=device)[:, None].expand(batch_size, num_gt)
        label_idx = gt_labels.long().reshape(batch_size, num_gt).to(device)
        pairwise_pred_scores = pred_scores.permute(
            0, 2, 1)[batch_idx, label_idx].permute(0, 2, 1)

        # Classification cost: BCE against the IoU target, re-weighted by
        # the squared gap between IoU and predicted probability.
        scale_factor = pairwise_ious - pairwise_pred_scores.sigmoid()
        pairwise_cls_cost = F.binary_cross_entropy_with_logits(
            pairwise_pred_scores, pairwise_ious,
            reduction='none') * scale_factor.abs().pow(2.0)

        cost_matrix = pairwise_cls_cost + iou_cost + soft_center_prior

        # Priors outside every real gt go to the end of each gt's ranking.
        # Multiplying ones by INF (rather than a fill) stays safe for
        # low-precision dtypes.
        max_pad_value = torch.ones_like(cost_matrix) * INF
        cost_matrix = torch.where(valid_mask[..., None], cost_matrix,
                                  max_pad_value)

        (matched_pred_ious, matched_gt_inds,
         fg_mask_inboxes) = self.dynamic_k_matching(cost_matrix, pairwise_ious,
                                                    pad_bbox_flag)
        del pairwise_ious, cost_matrix

        # Batch index of each foreground prior, aligned with matched_gt_inds.
        batch_index = (fg_mask_inboxes > 0).nonzero(as_tuple=True)[0]

        assigned_labels = gt_labels.new_full(pred_scores[..., 0].shape,
                                             self.num_classes)
        assigned_labels[fg_mask_inboxes] = gt_labels[
            batch_index, matched_gt_inds].squeeze(-1)
        assigned_labels = assigned_labels.long()

        assigned_labels_weights = gt_bboxes.new_full(pred_scores[..., 0].shape,
                                                     1)

        assigned_bboxes = gt_bboxes.new_full(pred_bboxes.shape, 0)
        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[batch_index,
                                                     matched_gt_inds]

        assign_metrics = gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
        assign_metrics[fg_mask_inboxes] = matched_pred_ious

        return dict(
            assigned_labels=assigned_labels,
            assigned_labels_weights=assigned_labels_weights,
            assigned_bboxes=assigned_bboxes,
            assign_metrics=assign_metrics)

    def dynamic_k_matching(
            self, cost_matrix: Tensor, pairwise_ious: Tensor,
            pad_bbox_flag: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Use IoU and matching cost to calculate the dynamic top-k positive
        targets.

        Args:
            cost_matrix (Tensor): Matching cost, (B, N, G).
            pairwise_ious (Tensor): IoU between predictions and gts,
                (B, N, G).
            pad_bbox_flag (Tensor): (B, G, 1). Non-zero for real gts and
                zero for padded ones.

        Returns:
            tuple[Tensor, Tensor, Tensor]: A tuple of three tensors.

            - ``matched_pred_ious``: (num_fg,). IoU of each foreground prior
              with its matched gt.
            - ``matched_gt_inds``: (num_fg,). Per-image index of the gt each
              foreground prior is matched to.
            - ``fg_mask_inboxes``: (B, N) bool mask of foreground priors.
        """
        num_priors = pairwise_ious.size(1)
        # (B, G): which gt columns are real rather than padding.
        valid_gt = pad_bbox_flag[..., 0] > 0

        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.topk, num_priors)
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
        # calculate dynamic k for each gt: (B, G)
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
        del topk_ious

        # sorting the batch cost matrix is faster than topk
        _, sorted_indices = torch.sort(cost_matrix, dim=1)
        # The prior at sorted position r is a positive for gt g iff
        # r < dynamic_ks[b, g] and g is a real gt. Scatter that decision back
        # to prior order. sorted_indices is a permutation along dim 1, so
        # every entry of matching_matrix is written exactly once.
        ranks = torch.arange(num_priors, device=cost_matrix.device)
        selected = ((ranks[None, :, None] < dynamic_ks[:, None, :]) &
                    valid_gt[:, None, :])
        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
        matching_matrix.scatter_(1, sorted_indices, selected.to(torch.uint8))
        del dynamic_ks, selected

        # A prior picked by several gts keeps only its cheapest gt. Padded gt
        # columns are excluded so they can never win this tie-break.
        prior_match_gt_mask = matching_matrix.sum(2) > 1
        if prior_match_gt_mask.any():
            real_gt_cost = cost_matrix.masked_fill(~valid_gt[:, None, :],
                                                   float('inf'))
            cost_argmin = real_gt_cost[prior_match_gt_mask, :].argmin(dim=1)
            matching_matrix[prior_match_gt_mask, :] = 0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1

        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(2) > 0
        matched_pred_ious = (matching_matrix *
                             pairwise_ious).sum(2)[fg_mask_inboxes]
        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
|