# Copyright (c) OpenMMLab. All rights reserved. from typing import Sequence import torch import torch.nn as nn import torch.nn.functional as F from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_overlaps def _cat_multi_level_tensor_in_place(*multi_level_tensor, place_hold_var): """concat multi-level tensor in place.""" for level_tensor in multi_level_tensor: for i, var in enumerate(level_tensor): if len(var) > 0: level_tensor[i] = torch.cat(var, dim=0) else: level_tensor[i] = place_hold_var class BatchYOLOv7Assigner(nn.Module): """Batch YOLOv7 Assigner. It consists of two assigning steps: 1. YOLOv5 cross-grid sample assigning 2. SimOTA assigning This code referenced to https://github.com/WongKinYiu/yolov7/blob/main/utils/loss.py. Args: num_classes (int): Number of classes. num_base_priors (int): Number of base priors. featmap_strides (Sequence[int]): Feature map strides. prior_match_thr (float): Threshold to match priors. Defaults to 4.0. candidate_topk (int): Number of topk candidates to assign. Defaults to 10. iou_weight (float): IOU weight. Defaults to 3.0. cls_weight (float): Class weight. Defaults to 1.0. """ def __init__(self, num_classes: int, num_base_priors: int, featmap_strides: Sequence[int], prior_match_thr: float = 4.0, candidate_topk: int = 10, iou_weight: float = 3.0, cls_weight: float = 1.0): super().__init__() self.num_classes = num_classes self.num_base_priors = num_base_priors self.featmap_strides = featmap_strides # yolov5 param self.prior_match_thr = prior_match_thr # simota param self.candidate_topk = candidate_topk self.iou_weight = iou_weight self.cls_weight = cls_weight @torch.no_grad() def forward(self, pred_results, batch_targets_normed, batch_input_shape, priors_base_sizes, grid_offset, near_neighbor_thr=0.5) -> dict: """Forward function.""" # (num_base_priors, num_batch_gt, 7) # 7 is mean (batch_idx, cls_id, x_norm, y_norm, # w_norm, h_norm, prior_idx) # mlvl is mean multi_level if batch_targets_normed.shape[1] == 0: # empty gt of batch num_levels = len(pred_results) return dict( mlvl_positive_infos=[pred_results[0].new_empty( (0, 4))] * num_levels, mlvl_priors=[] * num_levels, mlvl_targets_normed=[] * num_levels) # if near_neighbor_thr = 0.5 are mean the nearest # 3 neighbors are also considered positive samples. # if near_neighbor_thr = 1.0 are mean the nearest # 5 neighbors are also considered positive samples. mlvl_positive_infos, mlvl_priors = self.yolov5_assigner( pred_results, batch_targets_normed, priors_base_sizes, grid_offset, near_neighbor_thr=near_neighbor_thr) mlvl_positive_infos, mlvl_priors, \ mlvl_targets_normed = self.simota_assigner( pred_results, batch_targets_normed, mlvl_positive_infos, mlvl_priors, batch_input_shape) place_hold_var = batch_targets_normed.new_empty((0, 4)) _cat_multi_level_tensor_in_place( mlvl_positive_infos, mlvl_priors, mlvl_targets_normed, place_hold_var=place_hold_var) return dict( mlvl_positive_infos=mlvl_positive_infos, mlvl_priors=mlvl_priors, mlvl_targets_normed=mlvl_targets_normed) def yolov5_assigner(self, pred_results, batch_targets_normed, priors_base_sizes, grid_offset, near_neighbor_thr=0.5): """YOLOv5 cross-grid sample assigner.""" num_batch_gts = batch_targets_normed.shape[1] assert num_batch_gts > 0 mlvl_positive_infos, mlvl_priors = [], [] scaled_factor = torch.ones(7, device=pred_results[0].device) for i in range(len(pred_results)): # lever priors_base_sizes_i = priors_base_sizes[i] # (1, 1, feat_shape_w, feat_shape_h, feat_shape_w, feat_shape_h) scaled_factor[2:6] = torch.tensor( pred_results[i].shape)[[3, 2, 3, 2]] # Scale batch_targets from range 0-1 to range 0-features_maps size. # (num_base_priors, num_batch_gts, 7) batch_targets_scaled = batch_targets_normed * scaled_factor # Shape match wh_ratio = batch_targets_scaled[..., 4:6] / priors_base_sizes_i[:, None] match_inds = torch.max( wh_ratio, 1. / wh_ratio).max(2)[0] < self.prior_match_thr batch_targets_scaled = batch_targets_scaled[ match_inds] # (num_matched_target, 7) # no gt bbox matches anchor if batch_targets_scaled.shape[0] == 0: mlvl_positive_infos.append( batch_targets_scaled.new_empty((0, 4))) mlvl_priors.append([]) continue # Positive samples with additional neighbors batch_targets_cxcy = batch_targets_scaled[:, 2:4] grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy left, up = ((batch_targets_cxcy % 1 < near_neighbor_thr) & (batch_targets_cxcy > 1)).T right, bottom = ((grid_xy % 1 < near_neighbor_thr) & (grid_xy > 1)).T offset_inds = torch.stack( (torch.ones_like(left), left, up, right, bottom)) batch_targets_scaled = batch_targets_scaled.repeat( (5, 1, 1))[offset_inds] # () retained_offsets = grid_offset.repeat(1, offset_inds.shape[1], 1)[offset_inds] # batch_targets_scaled: (num_matched_target, 7) # 7 is mean (batch_idx, cls_id, x_scaled, # y_scaled, w_scaled, h_scaled, prior_idx) # mlvl_positive_info: (num_matched_target, 4) # 4 is mean (batch_idx, prior_idx, x_scaled, y_scaled) mlvl_positive_info = batch_targets_scaled[:, [0, 6, 2, 3]] retained_offsets = retained_offsets * near_neighbor_thr mlvl_positive_info[:, 2:] = mlvl_positive_info[:, 2:] - retained_offsets mlvl_positive_info[:, 2].clamp_(0, scaled_factor[2] - 1) mlvl_positive_info[:, 3].clamp_(0, scaled_factor[3] - 1) mlvl_positive_info = mlvl_positive_info.long() priors_inds = mlvl_positive_info[:, 1] mlvl_positive_infos.append(mlvl_positive_info) mlvl_priors.append(priors_base_sizes_i[priors_inds]) return mlvl_positive_infos, mlvl_priors def simota_assigner(self, pred_results, batch_targets_normed, mlvl_positive_infos, mlvl_priors, batch_input_shape): """SimOTA assigner.""" num_batch_gts = batch_targets_normed.shape[1] assert num_batch_gts > 0 num_levels = len(mlvl_positive_infos) mlvl_positive_infos_matched = [[] for _ in range(num_levels)] mlvl_priors_matched = [[] for _ in range(num_levels)] mlvl_targets_normed_matched = [[] for _ in range(num_levels)] for batch_idx in range(pred_results[0].shape[0]): # (num_batch_gt, 7) # 7 is mean (batch_idx, cls_id, x_norm, y_norm, # w_norm, h_norm, prior_idx) targets_normed = batch_targets_normed[0] # (num_gt, 7) targets_normed = targets_normed[targets_normed[:, 0] == batch_idx] num_gts = targets_normed.shape[0] if num_gts == 0: continue _mlvl_decoderd_bboxes = [] _mlvl_obj_cls = [] _mlvl_priors = [] _mlvl_positive_infos = [] _from_which_layer = [] for i, head_pred in enumerate(pred_results): # (num_matched_target, 4) # 4 is mean (batch_idx, prior_idx, grid_x, grid_y) _mlvl_positive_info = mlvl_positive_infos[i] if _mlvl_positive_info.shape[0] == 0: continue idx = (_mlvl_positive_info[:, 0] == batch_idx) _mlvl_positive_info = _mlvl_positive_info[idx] _mlvl_positive_infos.append(_mlvl_positive_info) priors = mlvl_priors[i][idx] _mlvl_priors.append(priors) _from_which_layer.append( _mlvl_positive_info.new_full( size=(_mlvl_positive_info.shape[0], ), fill_value=i)) # (n,85) level_batch_idx, prior_ind, \ grid_x, grid_y = _mlvl_positive_info.T pred_positive = head_pred[level_batch_idx, prior_ind, grid_y, grid_x] _mlvl_obj_cls.append(pred_positive[:, 4:]) # decoded grid = torch.stack([grid_x, grid_y], dim=1) pred_positive_cxcy = (pred_positive[:, :2].sigmoid() * 2. - 0.5 + grid) * self.featmap_strides[i] pred_positive_wh = (pred_positive[:, 2:4].sigmoid() * 2) ** 2 \ * priors * self.featmap_strides[i] pred_positive_xywh = torch.cat( [pred_positive_cxcy, pred_positive_wh], dim=-1) _mlvl_decoderd_bboxes.append(pred_positive_xywh) if len(_mlvl_decoderd_bboxes) == 0: continue # 1 calc pair_wise_iou_loss _mlvl_decoderd_bboxes = torch.cat(_mlvl_decoderd_bboxes, dim=0) num_pred_positive = _mlvl_decoderd_bboxes.shape[0] if num_pred_positive == 0: continue # scaled xywh batch_input_shape_wh = pred_results[0].new_tensor( batch_input_shape[::-1]).repeat((1, 2)) targets_scaled_bbox = targets_normed[:, 2:6] * batch_input_shape_wh targets_scaled_bbox = bbox_cxcywh_to_xyxy(targets_scaled_bbox) _mlvl_decoderd_bboxes = bbox_cxcywh_to_xyxy(_mlvl_decoderd_bboxes) pair_wise_iou = bbox_overlaps(targets_scaled_bbox, _mlvl_decoderd_bboxes) pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8) # 2 calc pair_wise_cls_loss _mlvl_obj_cls = torch.cat(_mlvl_obj_cls, dim=0).float().sigmoid() _mlvl_positive_infos = torch.cat(_mlvl_positive_infos, dim=0) _from_which_layer = torch.cat(_from_which_layer, dim=0) _mlvl_priors = torch.cat(_mlvl_priors, dim=0) gt_cls_per_image = ( F.one_hot(targets_normed[:, 1].to(torch.int64), self.num_classes).float().unsqueeze(1).repeat( 1, num_pred_positive, 1)) # cls_score * obj cls_preds_ = _mlvl_obj_cls[:, 1:]\ .unsqueeze(0)\ .repeat(num_gts, 1, 1) \ * _mlvl_obj_cls[:, 0:1]\ .unsqueeze(0).repeat(num_gts, 1, 1) y = cls_preds_.sqrt_() pair_wise_cls_loss = F.binary_cross_entropy_with_logits( torch.log(y / (1 - y)), gt_cls_per_image, reduction='none').sum(-1) del cls_preds_ # calc cost cost = ( self.cls_weight * pair_wise_cls_loss + self.iou_weight * pair_wise_iou_loss) # num_gt, num_match_pred matching_matrix = torch.zeros_like(cost) top_k, _ = torch.topk( pair_wise_iou, min(self.candidate_topk, pair_wise_iou.shape[1]), dim=1) dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1) # Select only topk matches per gt for gt_idx in range(num_gts): _, pos_idx = torch.topk( cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False) matching_matrix[gt_idx][pos_idx] = 1.0 del top_k, dynamic_ks # Each prediction box can match at most one gt box, # and if there are more than one, # only the least costly one can be taken anchor_matching_gt = matching_matrix.sum(0) if (anchor_matching_gt > 1).sum() > 0: _, cost_argmin = torch.min( cost[:, anchor_matching_gt > 1], dim=0) matching_matrix[:, anchor_matching_gt > 1] *= 0.0 matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0 fg_mask_inboxes = matching_matrix.sum(0) > 0.0 matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0) targets_normed = targets_normed[matched_gt_inds] _mlvl_positive_infos = _mlvl_positive_infos[fg_mask_inboxes] _from_which_layer = _from_which_layer[fg_mask_inboxes] _mlvl_priors = _mlvl_priors[fg_mask_inboxes] # Rearranged in the order of the prediction layers # to facilitate loss for i in range(num_levels): layer_idx = _from_which_layer == i mlvl_positive_infos_matched[i].append( _mlvl_positive_infos[layer_idx]) mlvl_priors_matched[i].append(_mlvl_priors[layer_idx]) mlvl_targets_normed_matched[i].append( targets_normed[layer_idx]) results = mlvl_positive_infos_matched, \ mlvl_priors_matched, \ mlvl_targets_normed_matched return results