Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.utils import ConfigType
from torch import Tensor

from mmyolo.registry import TASK_UTILS
from .utils import (select_candidates_in_gts, select_highest_overlaps,
                    yolov6_iou_calculator)
def bbox_center_distance(bboxes: Tensor,
                         priors: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute the center distance between bboxes and priors.

    Args:
        bboxes (Tensor): Shape (n, 4) for bbox, "xyxy" format.
        priors (Tensor): Shape (num_priors, 4) for priors, "xyxy" format.

    Returns:
        distances (Tensor): Euclidean distances between the center of
            every bbox and the center of every prior,
            shape (n, num_priors).
        priors_points (Tensor): Priors cx cy points,
            shape (num_priors, 2).
    """
    # Box centers: mean of the (x1, y1) and (x2, y2) corners.
    bbox_points = (bboxes[:, 0:2] + bboxes[:, 2:4]) / 2.0
    priors_points = (priors[:, 0:2] + priors[:, 2:4]) / 2.0
    # (n, 1, 2) - (1, num_priors, 2) broadcasts to (n, num_priors, 2).
    distances = (bbox_points[:, None, :] -
                 priors_points[None, :, :]).pow(2).sum(-1).sqrt()
    return distances, priors_points
class BatchATSSAssigner(nn.Module):
    """Assign a batch of corresponding gt bboxes or background to each prior.

    This code is based on
    https://github.com/meituan/YOLOv6/blob/main/yolov6/assigners/atss_assigner.py

    Every prior receives either a ground-truth target or background. In the
    returned dict a prior is positive where ``fg_mask_pre_prior`` is True.
    Negative priors are labelled ``num_classes`` and get all-zero score
    vectors. Bbox targets are not masked, so they should only be read
    where the foreground mask is set.

    Args:
        num_classes (int): Number of classes; ``num_classes`` itself is
            used as the background label.
        iou_calculator (:obj:`ConfigDict` or dict): Config dict for iou
            calculator. Defaults to ``dict(type='mmdet.BboxOverlaps2D')``
        topk (int): Number of priors selected as candidates on each
            pyramid level for each gt. Defaults to 9.
    """

    def __init__(
            self,
            num_classes: int,
            iou_calculator: ConfigType = dict(type='mmdet.BboxOverlaps2D'),
            topk: int = 9):
        super().__init__()
        self.num_classes = num_classes
        self.iou_calculator = TASK_UTILS.build(iou_calculator)
        self.topk = topk

    def forward(self, pred_bboxes: Optional[Tensor], priors: Tensor,
                num_level_priors: List, gt_labels: Tensor, gt_bboxes: Tensor,
                pad_bbox_flag: Tensor) -> dict:
        """Assign gt to priors.

        The assignment is done in following steps

        1. compute iou between all priors (of all pyramid levels) and gts
        2. compute center distance between all priors and gts
        3. on each pyramid level, for each gt, select the k priors whose
           centers are closest to the gt center, so in total about k * l
           priors are selected as candidates for each gt
        4. get the iou of these candidates, compute their mean and std,
           and use mean + std as the iou threshold
        5. select the candidates whose iou is strictly greater than the
           threshold as positive
        6. keep only positives whose center lies inside the gt

        Args:
            pred_bboxes (Tensor, optional): Predicted bounding boxes,
                shape(batch_size, num_priors, 4). If None, the assigned
                scores are not re-weighted by IoU.
            priors (Tensor): Model priors with stride, shape(num_priors, 4).
                Columns 0-1 are read as the prior point and columns 2-3 as
                its stride.
            num_level_priors (List): Number of priors in each pyramid level
            gt_labels (Tensor): Ground truth label,
                shape(batch_size, num_gt, 1)
            gt_bboxes (Tensor): Ground truth bbox,
                shape(batch_size, num_gt, 4)
            pad_bbox_flag (Tensor): Ground truth bbox mask,
                1 means bbox, 0 means no bbox,
                shape(batch_size, num_gt, 1)

        Returns:
            assigned_result (dict): Assigned result
                'assigned_labels' (Tensor): long, shape(batch_size,
                    num_priors)
                'assigned_bboxes' (Tensor): shape(batch_size, num_priors, 4)
                'assigned_scores' (Tensor):
                    shape(batch_size, num_priors, num_classes)
                'fg_mask_pre_prior' (Tensor): bool, shape(batch_size,
                    num_priors)
        """
        # Expand every prior point into an "xyxy" box extending 2.5 strides
        # on each side of the point.
        cell_half_size = priors[:, 2:] * 2.5
        priors_gen = torch.zeros_like(priors)
        priors_gen[:, :2] = priors[:, :2] - cell_half_size
        priors_gen[:, 2:] = priors[:, :2] + cell_half_size
        priors = priors_gen

        batch_size = gt_bboxes.size(0)
        num_gt, num_priors = gt_bboxes.size(1), priors.size(0)

        if num_gt == 0:
            # No gts at all: every prior is background. Dtypes match the
            # regular path below (long labels, bool foreground mask).
            return {
                'assigned_labels':
                gt_bboxes.new_full([batch_size, num_priors],
                                   self.num_classes,
                                   dtype=torch.long),
                'assigned_bboxes':
                gt_bboxes.new_zeros([batch_size, num_priors, 4]),
                'assigned_scores':
                gt_bboxes.new_zeros(
                    [batch_size, num_priors, self.num_classes]),
                'fg_mask_pre_prior':
                gt_bboxes.new_zeros([batch_size, num_priors],
                                    dtype=torch.bool)
            }

        # iou between all priors (of all pyramid levels) and gts,
        # reshaped to (batch_size, num_gt, num_priors)
        overlaps = self.iou_calculator(gt_bboxes.reshape([-1, 4]), priors)
        overlaps = overlaps.reshape([batch_size, -1, num_priors])

        # center distance between all priors and gts
        distances, priors_points = bbox_center_distance(
            gt_bboxes.reshape([-1, 4]), priors)
        distances = distances.reshape([batch_size, -1, num_priors])

        # per level, the closest priors to each gt become its candidates
        is_in_candidate, candidate_idxs = self.select_topk_candidates(
            distances, num_level_priors, pad_bbox_flag)

        # iou threshold per gt = mean + std of its candidates' ious
        overlaps_thr_per_gt, iou_candidates = self.threshold_calculator(
            is_in_candidate, candidate_idxs, overlaps, num_priors, batch_size,
            num_gt)

        # candidates whose iou is strictly above the threshold are positive;
        # the (batch_size, num_gt, 1) threshold broadcasts over the priors
        is_pos = torch.where(iou_candidates > overlaps_thr_per_gt,
                             is_in_candidate,
                             torch.zeros_like(is_in_candidate))

        is_in_gts = select_candidates_in_gts(priors_points, gt_bboxes)
        pos_mask = is_pos * is_in_gts * pad_bbox_flag

        # if a prior is assigned to multiple gts, the one with the highest
        # IoU is kept
        gt_idx_pre_prior, fg_mask_pre_prior, pos_mask = \
            select_highest_overlaps(pos_mask, overlaps, num_gt)

        assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
            gt_labels, gt_bboxes, gt_idx_pre_prior, fg_mask_pre_prior,
            num_priors, batch_size, num_gt)

        # soft label: scale the one-hot scores by the max over gts of the
        # gt / predicted-box IoU, masked by pos_mask
        if pred_bboxes is not None:
            ious = yolov6_iou_calculator(gt_bboxes, pred_bboxes) * pos_mask
            ious = ious.max(axis=-2)[0].unsqueeze(-1)
            assigned_scores *= ious

        return {
            'assigned_labels': assigned_labels.long(),
            'assigned_bboxes': assigned_bboxes,
            'assigned_scores': assigned_scores,
            'fg_mask_pre_prior': fg_mask_pre_prior.bool()
        }

    def select_topk_candidates(self, distances: Tensor,
                               num_level_priors: List[int],
                               pad_bbox_flag: Tensor) -> Tuple[Tensor, Tensor]:
        """Selecting candidates based on the center distance.

        Args:
            distances (Tensor): Distance between all priors and gts,
                shape(batch_size, num_gt, num_priors)
            num_level_priors (List[int]): Number of priors in each level
            pad_bbox_flag (Tensor): Ground truth bbox mask,
                shape(batch_size, num_gt, 1)

        Return:
            is_in_candidate_list (Tensor): 1 where a prior is a candidate of
                a gt, else 0, shape(batch_size, num_gt, num_priors)
            candidate_idxs (Tensor): Global prior indexes of the candidates,
                shape(batch_size, num_gt, K) where K is the sum over levels
                of min(topk, priors_per_level)
        """
        is_in_candidate_list = []
        candidate_idxs = []
        start_idx = 0
        distances_dtype = distances.dtype
        distances = torch.split(distances, num_level_priors, dim=-1)
        # Keep the flag as (batch_size, num_gt, 1) so it broadcasts against
        # the per-level top-k indexes, whose width may be smaller than topk
        # on levels with fewer than topk priors.
        pad_bbox_flag = pad_bbox_flag.bool()
        for distances_per_level, priors_per_level in zip(
                distances, num_level_priors):
            # on each pyramid level, for each gt,
            # select k priors whose centers are closest to the gt center
            selected_k = min(self.topk, priors_per_level)
            _, topk_idxs_per_level = distances_per_level.topk(
                selected_k, dim=-1, largest=False)
            candidate_idxs.append(topk_idxs_per_level + start_idx)
            # padded gts point all their candidates at index 0, so (when
            # more than one is selected) the one-hot count below exceeds 1
            # and is cleared
            topk_idxs_per_level = torch.where(
                pad_bbox_flag, topk_idxs_per_level,
                torch.zeros_like(topk_idxs_per_level))
            is_in_candidate = F.one_hot(topk_idxs_per_level,
                                        priors_per_level).sum(dim=-2)
            is_in_candidate = torch.where(is_in_candidate > 1,
                                          torch.zeros_like(is_in_candidate),
                                          is_in_candidate)
            is_in_candidate_list.append(is_in_candidate.to(distances_dtype))
            start_idx += priors_per_level
        is_in_candidate_list = torch.cat(is_in_candidate_list, dim=-1)
        candidate_idxs = torch.cat(candidate_idxs, dim=-1)
        return is_in_candidate_list, candidate_idxs

    @staticmethod
    def threshold_calculator(is_in_candidate: Tensor, candidate_idxs: Tensor,
                             overlaps: Tensor, num_priors: int,
                             batch_size: int,
                             num_gt: int) -> Tuple[Tensor, Tensor]:
        """Get corresponding iou for the these candidates, and compute the mean
        and std, set mean + std as the iou threshold.

        Args:
            is_in_candidate (Tensor): 1 where a prior is a candidate of a gt,
                shape(batch_size, num_gt, num_priors).
            candidate_idxs (Tensor): Global prior indexes of the candidates,
                shape(batch_size, num_gt, K)
            overlaps (Tensor): Overlaps area,
                shape(batch_size, num_gt, num_priors).
            num_priors (int): Number of priors.
            batch_size (int): Batch size.
            num_gt (int): Number of ground truth.

        Return:
            overlaps_thr_per_gt (Tensor): Overlap threshold of
                per ground truth, shape(batch_size, num_gt, 1).
            candidate_overlaps (Tensor): Candidate overlaps (zero outside
                the candidates), shape(batch_size, num_gt, num_priors).
        """
        batch_size_num_gt = batch_size * num_gt
        candidate_overlaps = torch.where(is_in_candidate > 0, overlaps,
                                         torch.zeros_like(overlaps))
        # Gather each gt's candidate ious from the flattened overlaps:
        # row r of candidate_idxs is offset by r * num_priors.
        candidate_idxs = candidate_idxs.reshape([batch_size_num_gt, -1])
        assist_indexes = num_priors * torch.arange(
            batch_size_num_gt, device=candidate_idxs.device)
        flatten_indexes = candidate_idxs + assist_indexes[:, None]
        candidate_overlaps_reshape = candidate_overlaps.reshape(
            -1)[flatten_indexes]
        candidate_overlaps_reshape = candidate_overlaps_reshape.reshape(
            [batch_size, num_gt, -1])

        overlaps_mean_per_gt = candidate_overlaps_reshape.mean(
            axis=-1, keepdim=True)
        overlaps_std_per_gt = candidate_overlaps_reshape.std(
            axis=-1, keepdim=True)
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
        return overlaps_thr_per_gt, candidate_overlaps

    def get_targets(self, gt_labels: Tensor, gt_bboxes: Tensor,
                    assigned_gt_inds: Tensor, fg_mask_pre_prior: Tensor,
                    num_priors: int, batch_size: int,
                    num_gt: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Get target info.

        Args:
            gt_labels (Tensor): Ground true labels,
                shape(batch_size, num_gt, 1)
            gt_bboxes (Tensor): Ground true bboxes,
                shape(batch_size, num_gt, 4)
            assigned_gt_inds (Tensor): Assigned ground truth indexes,
                shape(batch_size, num_priors)
            fg_mask_pre_prior (Tensor): Force ground truth matching mask,
                shape(batch_size, num_priors)
            num_priors (int): Number of priors.
            batch_size (int): Batch size.
            num_gt (int): Number of ground truth.

        Return:
            assigned_labels (Tensor): Assigned labels (``num_classes`` where
                the foreground mask is not set), shape(batch_size,
                num_priors)
            assigned_bboxes (Tensor): Assigned bboxes (not masked),
                shape(batch_size, num_priors, 4)
            assigned_scores (Tensor): Assigned one-hot scores (all zero for
                background), shape(batch_size, num_priors, num_classes)
        """
        # Turn per-image gt indexes into indexes into the flattened
        # (batch_size * num_gt) gt tensors.
        batch_index = torch.arange(
            batch_size, dtype=gt_labels.dtype, device=gt_labels.device)
        batch_index = batch_index[..., None]
        assigned_gt_inds = (assigned_gt_inds + batch_index * num_gt).long()

        # assigned target labels; background where the fg mask is not set
        assigned_labels = gt_labels.flatten()[assigned_gt_inds.flatten()]
        assigned_labels = assigned_labels.reshape([batch_size, num_priors])
        assigned_labels = torch.where(
            fg_mask_pre_prior > 0, assigned_labels,
            torch.full_like(assigned_labels, self.num_classes))

        # assigned target boxes
        assigned_bboxes = gt_bboxes.reshape(
            [-1, 4])[assigned_gt_inds.flatten()]
        assigned_bboxes = assigned_bboxes.reshape([batch_size, num_priors, 4])

        # assigned target scores: one-hot over num_classes + 1, then the
        # background column is dropped so background rows are all zero
        assigned_scores = F.one_hot(assigned_labels.long(),
                                    self.num_classes + 1).float()
        assigned_scores = assigned_scores[:, :, :self.num_classes]
        return assigned_labels, assigned_bboxes, assigned_scores