KyanChen's picture
Upload 89 files
3094730
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
def select_candidates_in_gts(priors_points: Tensor,
                             gt_bboxes: Tensor,
                             eps: float = 1e-9) -> Tensor:
    """Select the priors whose center points lie strictly inside a gt box.

    Args:
        priors_points (Tensor): Model prior center points,
            shape(num_priors, 2).
        gt_bboxes (Tensor): Ground truth bboxes as (x1, y1, x2, y2),
            shape(batch_size, num_gt, 4).
        eps (float): A point counts as inside only if its distance to every
            side of the box is greater than ``eps``. Default to 1e-9.

    Return:
        (Tensor): Mask with 1 where the prior point is inside the gt box and
        0 elsewhere, in the dtype of ``gt_bboxes``,
        shape(batch_size, num_gt, num_priors).
    """
    # Broadcast (1, 1, num_priors, 2) against (batch_size, num_gt, 1, 2)
    # instead of using ``repeat``: no batch_size * num_gt copies of the priors
    # are materialised, and num_gt == 0 works (a reshape with -1 on a
    # zero-element tensor is ambiguous and raises).
    points = priors_points[None, None, :, :]
    gt_lt = gt_bboxes[..., None, 0:2]
    gt_rb = gt_bboxes[..., None, 2:4]
    # Left/top and right/bottom distances between each prior center and the
    # gt sides; all four are positive only when the point is inside the box.
    bbox_deltas = torch.cat([points - gt_lt, gt_rb - points], dim=-1)
    return (bbox_deltas.min(dim=-1).values > eps).to(gt_bboxes.dtype)
def select_highest_overlaps(pos_mask: Tensor, overlaps: Tensor,
                            num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
    """If a prior is assigned to multiple gts, keep only the gt with the
    highest overlap.

    Args:
        pos_mask (Tensor): The assigned positive sample mask,
            shape(batch_size, num_gt, num_priors)
        overlaps (Tensor): IoU between all bbox and ground truth,
            shape(batch_size, num_gt, num_priors)
        num_gt (int): Number of ground truth.

    Return:
        gt_idx_pre_prior (Tensor): Index of the gt assigned to each prior
            (argmax over the gt dimension of the returned ``pos_mask``),
            shape(batch_size, num_priors)
        fg_mask_pre_prior (Tensor): Number of gts assigned to each prior
            after deduplication (0 or 1), shape(batch_size, num_priors)
        pos_mask (Tensor): The deduplicated positive sample mask,
            shape(batch_size, num_gt, num_priors)
    """
    fg_mask_pre_prior = pos_mask.sum(dim=-2)
    multi_gts = fg_mask_pre_prior > 1
    # ``any`` instead of ``max() > 1``: same meaning, but does not raise on
    # an empty tensor (num_priors == 0 or batch_size == 0).
    if multi_gts.any():
        # shape(batch_size, 1, num_priors); broadcast over gts by torch.where
        mask_multi_gts = multi_gts.unsqueeze(1)
        # One-hot of the best-overlap gt per prior, moved to
        # shape(batch_size, num_gt, num_priors).
        index = overlaps.argmax(dim=1)
        is_max_overlaps = F.one_hot(index, num_gt).permute(0, 2, 1).to(
            overlaps.dtype)
        pos_mask = torch.where(mask_multi_gts, is_max_overlaps, pos_mask)
        fg_mask_pre_prior = pos_mask.sum(dim=-2)
    gt_idx_pre_prior = pos_mask.argmax(dim=-2)
    return gt_idx_pre_prior, fg_mask_pre_prior, pos_mask
# TODO: 'mmdet.BboxOverlaps2D' causes gradient inconsistency here; the cause
# will be investigated and fixed in a later version.
def yolov6_iou_calculator(bbox1: Tensor,
                          bbox2: Tensor,
                          eps: float = 1e-9) -> Tensor:
    """Compute pairwise IoU between two batches of (x1, y1, x2, y2) boxes.

    Args:
        bbox1 (Tensor): shape(batch_size, num_gt, 4)
        bbox2 (Tensor): shape(batch_size, num_priors, 4)
        eps (float): Added to the union to avoid division by zero.
            Default to 1e-9.

    Return:
        (Tensor): IoU, shape(batch_size, num_gt, num_priors). Boxes with
        negative width or height are treated as having zero area.
    """
    # Pair every box of bbox1 with every box of bbox2 by broadcasting:
    # [N, M1, 4] -> [N, M1, 1, 4] and [N, M2, 4] -> [N, 1, M2, 4].
    boxes_a = bbox1[:, :, None, :]
    boxes_b = bbox2[:, None, :, :]
    a_min, a_max = boxes_a[..., 0:2], boxes_a[..., 2:4]
    b_min, b_max = boxes_b[..., 0:2], boxes_b[..., 2:4]

    inter_wh = torch.minimum(a_max, b_max) - torch.maximum(a_min, b_min)
    inter_area = inter_wh.clip(0).prod(-1)

    area_a = (a_max - a_min).clip(0).prod(-1)
    area_b = (b_max - b_min).clip(0).prod(-1)
    return inter_area / (area_a + area_b - inter_area + eps)