Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Tuple | |
import torch | |
import torch.nn.functional as F | |
from torch import Tensor | |
def select_candidates_in_gts(priors_points: Tensor,
                             gt_bboxes: Tensor,
                             eps: float = 1e-9) -> Tensor:
    """Select the priors whose points lie strictly inside each gt box.

    A prior is a candidate for a gt when all four distances from its point
    to the box sides (left, top, right, bottom) are greater than ``eps``.

    Args:
        priors_points (Tensor): Model priors points,
            shape(num_priors, 2), ordered like the first two box coords.
        gt_bboxes (Tensor): Ground-truth bboxes as (x1, y1, x2, y2),
            shape(batch_size, num_gt, 4).
        eps (float): Minimum distance to every box side for a point to be
            counted as inside. Defaults to 1e-9.

    Return:
        (Tensor): Candidate mask in the dtype of ``gt_bboxes`` (1 for
        candidates, 0 otherwise), shape(batch_size, num_gt, num_priors).
    """
    # Broadcast instead of materialising repeated copies:
    # points (1, 1, num_priors, 2) vs. gt corners (batch_size, num_gt, 1, 2).
    points = priors_points[None, None, :, :]
    gt_lt = gt_bboxes[..., None, 0:2]
    gt_rb = gt_bboxes[..., None, 2:4]

    # left, top, right, bottom distances between each prior point and the gt
    # sides -> shape(batch_size, num_gt, num_priors, 4)
    bbox_deltas = torch.cat([points - gt_lt, gt_rb - points], dim=-1)

    # No reshape with an inferred (-1) dimension, so empty inputs such as
    # num_gt == 0 produce an empty mask instead of raising.
    return (bbox_deltas.min(dim=-1)[0] > eps).to(gt_bboxes.dtype)
def select_highest_overlaps(pos_mask: Tensor, overlaps: Tensor,
                            num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
    """If a prior is assigned to multiple gts, keep only the gt with the
    highest IoU.

    Args:
        pos_mask (Tensor): The assigned positive sample mask,
            shape(batch_size, num_gt, num_priors)
        overlaps (Tensor): IoU between all bboxes and ground truth,
            shape(batch_size, num_gt, num_priors)
        num_gt (int): Number of ground truth boxes. Used as the number of
            classes of the one-hot encoding, so it is expected to equal
            ``overlaps.size(1)``.

    Return:
        gt_idx_pre_prior (Tensor): Index of the gt assigned to each prior
            (0 for priors with no assigned gt), shape(batch_size, num_priors)
        fg_mask_pre_prior (Tensor): Number of gts assigned to each prior
            after de-duplication, shape(batch_size, num_priors)
        pos_mask (Tensor): The de-duplicated positive sample mask,
            shape(batch_size, num_gt, num_priors)
    """
    fg_mask_pre_prior = pos_mask.sum(dim=-2)

    # Make sure every positive prior matches only one gt: the one with the
    # largest IoU. The numel() guard keeps max() from raising on an empty
    # tensor (e.g. num_priors == 0).
    if fg_mask_pre_prior.numel() > 0 and fg_mask_pre_prior.max() > 1:
        # shape(batch_size, 1, num_priors); broadcasts over the gt axis below
        mask_multi_gts = fg_mask_pre_prior.unsqueeze(1) > 1
        index = overlaps.argmax(dim=1)
        # one_hot gives (batch_size, num_priors, num_gt); move gts to axis 1
        is_max_overlaps = F.one_hot(index, num_gt).permute(0, 2, 1).to(
            overlaps.dtype)
        pos_mask = torch.where(mask_multi_gts, is_max_overlaps, pos_mask)
        fg_mask_pre_prior = pos_mask.sum(dim=-2)

    if pos_mask.size(-2) == 0:
        # argmax over an empty gt axis raises; with no gts every prior is
        # unassigned, which argmax reports as index 0 in the general case.
        gt_idx_pre_prior = torch.zeros_like(
            fg_mask_pre_prior, dtype=torch.long)
    else:
        gt_idx_pre_prior = pos_mask.argmax(dim=-2)
    return gt_idx_pre_prior, fg_mask_pre_prior, pos_mask
# TODO: 'mmdet.BboxOverlaps2D' causes gradient inconsistency here; this is
# to be investigated and fixed in a later version.
def yolov6_iou_calculator(bbox1: Tensor,
                          bbox2: Tensor,
                          eps: float = 1e-9) -> Tensor:
    """Compute pairwise IoU between two batched sets of boxes.

    Boxes are read as (x1, y1, x2, y2). Negative widths/heights are clipped
    to zero before areas are taken.

    Args:
        bbox1 (Tensor): shape(batch_size, num_gt, 4)
        bbox2 (Tensor): shape(batch_size, num_priors, 4)
        eps (float): Added to the union to avoid division by zero.
            Defaults to 1e-9.

    Return:
        (Tensor): IoU, shape(batch_size, num_gt, num_priors)
    """
    # Insert singleton axes so every box of one set meets every box of the
    # other: (N, M1, 4) -> (N, M1, 1, 4) and (N, M2, 4) -> (N, 1, M2, 4).
    boxes_a = bbox1[:, :, None, :]
    boxes_b = bbox2[:, None, :, :]

    top_left_a, bottom_right_a = boxes_a[..., 0:2], boxes_a[..., 2:4]
    top_left_b, bottom_right_b = boxes_b[..., 0:2], boxes_b[..., 2:4]

    # Intersection rectangle; disjoint pairs get a zero-sized side.
    inter_wh = torch.minimum(bottom_right_a, bottom_right_b) - \
        torch.maximum(top_left_a, top_left_b)
    intersection = inter_wh.clip(0).prod(-1)

    area_a = (bottom_right_a - top_left_a).clip(0).prod(-1)
    area_b = (bottom_right_b - top_left_b).clip(0).prod(-1)

    return intersection / (area_a + area_b - intersection + eps)