# All rights reserved.
import copy
import math
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmdet.models.dense_heads.base_dense_head import BaseDenseHead
from mmdet.models.utils import filter_scores_and_topk, multi_apply
from mmdet.structures.bbox import bbox_overlaps
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
                         OptMultiConfig)
from mmengine.config import ConfigDict
from mmengine.dist import get_dist_info
from mmengine.logging import print_log
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import make_divisible


def get_prior_xy_info(
    index: Union[int, Tensor], num_base_priors: int,
    featmap_sizes: Sequence[int]
) -> Tuple[Union[int, Tensor], Union[int, Tensor], Union[int, Tensor]]:
    """Get prior index and xy index in feature map by flatten index.

    The flat index is decomposed as
    ``(grid_y * featmap_w + grid_x) * num_base_priors + prior``.

    Args:
        index (int or Tensor): Flattened prior index (or indices).
        num_base_priors (int): Number of priors per grid cell.
        featmap_sizes (Sequence[int]): Two-element size of the feature
            map; only the second element (width) is used.

    Returns:
        tuple: ``(prior, grid_x, grid_y)`` with the same type as ``index``.
    """
    _, featmap_w = featmap_sizes
    priors = index % num_base_priors
    xy_index = index // num_base_priors
    grid_y = xy_index // featmap_w
    grid_x = xy_index % featmap_w
    return priors, grid_x, grid_y


@MODELS.register_module()
class YOLOv5HeadModule(BaseModule):
    """YOLOv5Head head module used in `YOLOv5`.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (Union[int, Sequence]): Number of channels in the input
            feature map.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid. Defaults to 3.
        featmap_strides (Sequence[int]): Downsample factor of each feature
            map. Defaults to (8, 16, 32).
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: Union[int, Sequence],
                 widen_factor: float = 1.0,
                 num_base_priors: int = 3,
                 featmap_strides: Sequence[int] = (8, 16, 32),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.widen_factor = widen_factor

        self.featmap_strides = featmap_strides
        # Per prior: 4 box values + 1 objectness + num_classes scores.
        self.num_out_attrib = 5 + self.num_classes
        self.num_levels = len(self.featmap_strides)
        self.num_base_priors = num_base_priors

        if isinstance(in_channels, int):
            # A single int is shared by every level.
            self.in_channels = [make_divisible(in_channels, widen_factor)
                                ] * self.num_levels
        else:
            self.in_channels = [
                make_divisible(i, widen_factor) for i in in_channels
            ]

        self._init_layers()

    def _init_layers(self):
        """Initialize conv layers in YOLOv5 head.

        One 1x1 conv per level, producing
        ``num_base_priors * num_out_attrib`` channels.
        """
        self.convs_pred = nn.ModuleList()
        for i in range(self.num_levels):
            conv_pred = nn.Conv2d(self.in_channels[i],
                                  self.num_base_priors * self.num_out_attrib,
                                  1)

            self.convs_pred.append(conv_pred)

    def init_weights(self):
        """Initialize the bias of YOLOv5 head."""
        super().init_weights()
        for mi, s in zip(self.convs_pred, self.featmap_strides):
            # View the bias as (num_base_priors, num_out_attrib) so that the
            # objectness and class columns can be addressed directly.
            b = mi.bias.data.view(self.num_base_priors, -1)
            # obj (8 objects per 640 image)
            b.data[:, 4] += math.log(8 / (640 / s)**2)
            b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.999999))

            mi.bias.data = b.view(-1)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions, and objectnesses.
        """
        assert len(x) == self.num_levels
        return multi_apply(self.forward_single, x, self.convs_pred)

    def forward_single(self, x: Tensor,
                       convs: nn.Module) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Feature map of one level.
            convs (nn.Module): Prediction conv of that level.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(cls_score, bbox_pred,
            objectness)`` with channel counts
            ``num_base_priors * num_classes``, ``num_base_priors * 4`` and
            ``num_base_priors`` respectively.
        """

        pred_map = convs(x)
        bs, _, ny, nx = pred_map.shape
        pred_map = pred_map.view(bs, self.num_base_priors, self.num_out_attrib,
                                 ny, nx)

        # Attribute layout per prior: [0:4] box, [4] objectness, [5:] classes.
        cls_score = pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)
        bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
        objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)

        return cls_score, bbox_pred, objectness


@MODELS.register_module()
class YOLOv5Head(BaseDenseHead):
    """YOLOv5Head head used in `YOLOv5`.

    Args:
        head_module(ConfigType): Base module used for YOLOv5Head
        prior_generator(dict): Points generator feature maps in
            2D points-based detectors.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
        prior_match_thr (float): A gt box is matched to a prior only when
            the larger of their width and height ratios (in either
            direction) is below this value. Defaults to 4.0.
        near_neighbor_thr (float): A neighbouring grid cell is also taken
            as a positive sample when the gt center lies within this
            fraction of a cell from the shared border. Defaults to 0.5.
        ignore_iof_thr (float): Priors whose IoF with an ignored gt box is
            above this value get zero loss weight. ``-1.0`` disables the
            ignore handling. Defaults to -1.0.
        obj_level_weights (List[float]): Per-level multiplier of the
            objectness loss. Defaults to [4.0, 1.0, 0.4].
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.YOLOAnchorGenerator',
                     base_sizes=[[(10, 13), (16, 30), (33, 23)],
                                 [(30, 61), (62, 45), (59, 119)],
                                 [(116, 90), (156, 198), (373, 326)]],
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='YOLOv5BBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=0.5),
                 loss_bbox: ConfigType = dict(
                     type='IoULoss',
                     iou_mode='ciou',
                     bbox_format='xywh',
                     eps=1e-7,
                     reduction='mean',
                     loss_weight=0.05,
                     return_iou=True),
                 loss_obj: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=1.0),
                 prior_match_thr: float = 4.0,
                 near_neighbor_thr: float = 0.5,
                 ignore_iof_thr: float = -1.0,
                 obj_level_weights: List[float] = [4.0, 1.0, 0.4],
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)

        self.head_module = MODELS.build(head_module)
        self.num_classes = self.head_module.num_classes
        self.featmap_strides = self.head_module.featmap_strides
        self.num_levels = len(self.featmap_strides)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.loss_cls: nn.Module = MODELS.build(loss_cls)
        self.loss_bbox: nn.Module = MODELS.build(loss_bbox)
        self.loss_obj: nn.Module = MODELS.build(loss_obj)

        self.prior_generator = TASK_UTILS.build(prior_generator)
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.num_base_priors = self.prior_generator.num_base_priors[0]

        # Placeholder that never equals a real list of feature map sizes,
        # so the priors are generated on the first call.
        self.featmap_sizes = [torch.empty(1)] * self.num_levels

        self.prior_match_thr = prior_match_thr
        self.near_neighbor_thr = near_neighbor_thr
        self.obj_level_weights = obj_level_weights
        self.ignore_iof_thr = ignore_iof_thr

        self.special_init()

    def special_init(self):
        """Since YOLO series algorithms will inherit from YOLOv5Head, but
        different algorithms have special initialization process.

        The special_init function is designed to deal with this situation.
        """
        assert len(self.obj_level_weights) == len(
            self.featmap_strides) == self.num_levels
        if self.prior_match_thr != 4.0:
            print_log(
                "!!!Now, you've changed the prior_match_thr "
                'parameter to something other than 4.0. Please make sure '
                'that you have modified both the regression formula in '
                'bbox_coder and before loss_box computation, '
                'otherwise the accuracy may be degraded!!!')

        if self.num_classes == 1:
            print_log('!!!You are using `YOLOv5Head` with num_classes == 1.'
                      ' The loss_cls will be 0. This is a normal phenomenon.')

        # Prior sizes expressed in units of each level's stride.
        priors_base_sizes = torch.tensor(
            self.prior_generator.base_sizes, dtype=torch.float)
        featmap_strides = torch.tensor(
            self.featmap_strides, dtype=torch.float)[:, None, None]
        self.register_buffer(
            'priors_base_sizes',
            priors_base_sizes / featmap_strides,
            persistent=False)

        # The offsets are *subtracted* (scaled by near_neighbor_thr) from the
        # gt center in the loss, hence [1, 0] selects the left neighbour.
        grid_offset = torch.tensor([
            [0, 0],  # center
            [1, 0],  # left
            [0, 1],  # up
            [-1, 0],  # right
            [0, -1],  # bottom
        ]).float()
        self.register_buffer(
            'grid_offset', grid_offset[:, None], persistent=False)

        prior_inds = torch.arange(self.num_base_priors).float().view(
            self.num_base_priors, 1)
        self.register_buffer('prior_inds', prior_inds, persistent=False)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions, and objectnesses.
        """
        return self.head_module(x)

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted by the head into
        bbox results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], Optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors, H, W).
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Although the default is None, the list is required: its
                length gives the number of images.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)
        with_objectnesses = objectnesses is not None
        if with_objectnesses:
            assert len(cls_scores) == len(objectnesses)

        cfg = self.test_cfg if cfg is None else cfg
        # cfg is modified below, so work on a private copy.
        cfg = copy.deepcopy(cfg)

        multi_label = cfg.multi_label
        multi_label &= self.num_classes > 1
        cfg.multi_label = multi_label

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        # If the shape does not change, use the previous mlvl_priors
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes
        flatten_priors = torch.cat(self.mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size.numel() * self.num_base_priors, ), stride) for
            featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_decoded_bboxes = self.bbox_coder.decode(
            flatten_priors[None], flatten_bbox_preds, flatten_stride)

        if with_objectnesses:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        else:
            flatten_objectness = [None for _ in range(num_imgs)]

        results_list = []
        for (bboxes, scores, objectness,
             img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
                              flatten_objectness, batch_img_metas):
            ori_shape = img_meta['ori_shape']
            scale_factor = img_meta['scale_factor']
            pad_param = img_meta.get('pad_param')

            score_thr = cfg.get('score_thr', -1)
            # yolox_style does not require the following operations
            if objectness is not None and score_thr > 0 and not cfg.get(
                    'yolox_style', False):
                conf_inds = objectness > score_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]

            if objectness is not None:
                # conf = obj_conf * cls_conf
                scores *= objectness[:, None]

            if scores.shape[0] == 0:
                empty_results = InstanceData()
                empty_results.bboxes = bboxes
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                results_list.append(empty_results)
                continue

            nms_pre = cfg.get('nms_pre', 100000)
            if cfg.multi_label is False:
                # Single label per box: keep only the best class.
                scores, labels = scores.max(1, keepdim=True)
                scores, _, keep_idxs, results = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(labels=labels[:, 0]))
                labels = results['labels']
            else:
                scores, labels, keep_idxs, _ = filter_scores_and_topk(
                    scores, score_thr, nms_pre)

            results = InstanceData(
                scores=scores, labels=labels, bboxes=bboxes[keep_idxs])

            if rescale:
                if pad_param is not None:
                    # pad_param[2] is subtracted from x, pad_param[0] from y.
                    results.bboxes -= results.bboxes.new_tensor([
                        pad_param[2], pad_param[0], pad_param[2], pad_param[0]
                    ])
                results.bboxes /= results.bboxes.new_tensor(
                    scale_factor).repeat((1, 2))

            if cfg.get('yolox_style', False):
                # do not need max_per_img
                cfg.max_per_img = len(results)

            # Boxes were already rescaled above, hence rescale=False here.
            results = self._bbox_post_process(
                results=results,
                cfg=cfg,
                rescale=False,
                with_nms=with_nms,
                img_meta=img_meta)
            results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
            results.bboxes[:, 1::2].clamp_(0, ori_shape[0])

            results_list.append(results)
        return results_list

    def loss(self, x: Tuple[Tensor], batch_data_samples: Union[list,
                                                               dict]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. When a
                dict is given (fast version), its ``bboxes_labels`` and
                ``img_metas`` entries are used.

        Returns:
            dict: A dictionary of loss components.
        """

        if isinstance(batch_data_samples, list):
            losses = super().loss(x, batch_data_samples)
        else:
            outs = self(x)
            # Fast version
            loss_inputs = outs + (batch_data_samples['bboxes_labels'],
                                  batch_data_samples['img_metas'])
            losses = self.loss_by_feat(*loss_inputs)

        return losses

    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            objectnesses: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (Sequence[Tensor]): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors, H, W).
            batch_gt_instances (Sequence[InstanceData] or Tensor): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes. In the fast version it is a single 2D-tensor
                with 6 columns (batch_idx, label, x1, y1, x2, y2).
            batch_img_metas (Sequence[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Required when ``ignore_iof_thr`` is not -1.
                Defaults to None.
        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        if self.ignore_iof_thr != -1:
            # TODO: Support fast version
            # convert ignore gt
            batch_target_ignore_list = []
            for i, gt_instances_ignore in enumerate(batch_gt_instances_ignore):
                bboxes = gt_instances_ignore.bboxes
                labels = gt_instances_ignore.labels
                index = bboxes.new_full((len(bboxes), 1), i)
                # (batch_idx, label, bboxes)
                target = torch.cat((index, labels[:, None].float(), bboxes),
                                   dim=1)
                batch_target_ignore_list.append(target)

            # (num_bboxes, 6)
            batch_gt_targets_ignore = torch.cat(
                batch_target_ignore_list, dim=0)
            if batch_gt_targets_ignore.shape[0] != 0:
                # Consider regions with ignore in annotations
                return self._loss_by_feat_with_ignore(
                    cls_scores,
                    bbox_preds,
                    objectnesses,
                    batch_gt_instances=batch_gt_instances,
                    batch_img_metas=batch_img_metas,
                    batch_gt_instances_ignore=batch_gt_targets_ignore)

        # 1. Convert gt to norm format
        batch_targets_normed = self._convert_gt_to_norm_format(
            batch_gt_instances, batch_img_metas)

        device = cls_scores[0].device
        loss_cls = torch.zeros(1, device=device)
        loss_box = torch.zeros(1, device=device)
        loss_obj = torch.zeros(1, device=device)
        scaled_factor = torch.ones(7, device=device)

        for i in range(self.num_levels):
            batch_size, _, h, w = bbox_preds[i].shape
            target_obj = torch.zeros_like(objectnesses[i])

            # empty gt bboxes
            if batch_targets_normed.shape[1] == 0:
                # "* 0" keeps the predictions in the graph without
                # contributing any loss value.
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i], target_obj) * self.obj_level_weights[i]
                continue

            priors_base_sizes_i = self.priors_base_sizes[i]
            # feature map scale whwh
            scaled_factor[2:6] = torch.tensor(
                bbox_preds[i].shape)[[3, 2, 3, 2]]
            # Scale batch_targets from range 0-1 to range 0-features_maps size.
            # (num_base_priors, num_bboxes, 7)
            batch_targets_scaled = batch_targets_normed * scaled_factor

            # 2. Shape match
            wh_ratio = batch_targets_scaled[...,
                                            4:6] / priors_base_sizes_i[:, None]
            match_inds = torch.max(
                wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
            batch_targets_scaled = batch_targets_scaled[match_inds]

            # no gt bbox matches anchor
            if batch_targets_scaled.shape[0] == 0:
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i], target_obj) * self.obj_level_weights[i]
                continue

            # 3. Positive samples with additional neighbors

            # check the left, up, right, bottom sides of the
            # targets grid, and determine whether assigned
            # them as positive samples as well.
            batch_targets_cxcy = batch_targets_scaled[:, 2:4]
            grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
            left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
                        (batch_targets_cxcy > 1)).T
            right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
                             (grid_xy > 1)).T
            offset_inds = torch.stack(
                (torch.ones_like(left), left, up, right, bottom))

            batch_targets_scaled = batch_targets_scaled.repeat(
                (5, 1, 1))[offset_inds]
            retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
                                                       1)[offset_inds]

            # prepare pred results and positive sample indexes to
            # calculate class loss and bbox loss
            _chunk_targets = batch_targets_scaled.chunk(4, 1)
            img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
            priors_inds, (img_inds, class_inds) = priors_inds.long().view(
                -1), img_class_inds.long().T

            grid_xy_long = (grid_xy -
                            retained_offsets * self.near_neighbor_thr).long()
            grid_x_inds, grid_y_inds = grid_xy_long.T
            bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)

            # 4. Calculate loss
            # bbox loss
            retained_bbox_pred = bbox_preds[i].reshape(
                batch_size, self.num_base_priors, -1, h,
                w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
            priors_base_sizes_i = priors_base_sizes_i[priors_inds]
            decoded_bbox_pred = self._decode_bbox_to_xywh(
                retained_bbox_pred, priors_base_sizes_i)
            loss_box_i, iou = self.loss_bbox(decoded_bbox_pred, bboxes_targets)
            loss_box += loss_box_i

            # obj loss
            iou = iou.detach().clamp(0)
            target_obj[img_inds, priors_inds, grid_y_inds,
                       grid_x_inds] = iou.type(target_obj.dtype)
            loss_obj += self.loss_obj(objectnesses[i],
                                      target_obj) * self.obj_level_weights[i]

            # cls loss
            if self.num_classes > 1:
                pred_cls_scores = cls_scores[i].reshape(
                    batch_size, self.num_base_priors, -1, h,
                    w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]

                target_class = torch.full_like(pred_cls_scores, 0.)
                target_class[range(batch_targets_scaled.shape[0]),
                             class_inds] = 1.
                loss_cls += self.loss_cls(pred_cls_scores, target_class)
            else:
                loss_cls += cls_scores[i].sum() * 0

        _, world_size = get_dist_info()
        return dict(
            loss_cls=loss_cls * batch_size * world_size,
            loss_obj=loss_obj * batch_size * world_size,
            loss_bbox=loss_box * batch_size * world_size)

    def _convert_gt_to_norm_format(self,
                                   batch_gt_instances: Sequence[InstanceData],
                                   batch_img_metas: Sequence[dict]) -> Tensor:
        """Convert gt boxes to normalized (cx, cy, w, h) training targets.

        Args:
            batch_gt_instances (Sequence[InstanceData] or Tensor): Either a
                list with one ``InstanceData`` (``bboxes`` in xyxy and
                ``labels``) per image, or, in the fast version, a 2D-tensor
                with 6 columns (batch_idx, label, x1, y1, x2, y2). The
                input is not modified.
            batch_img_metas (Sequence[dict]): Meta information of each
                image; ``batch_input_shape`` is used for normalization.

        Returns:
            Tensor: Targets of shape (num_base_priors, num_bboxes, 7) with
            last dimension (img_ind, labels, bbox_cx, bbox_cy, bbox_w,
            bbox_h, prior_ind).
        """
        if isinstance(batch_gt_instances, torch.Tensor):
            # fast version
            img_shape = batch_img_metas[0]['batch_input_shape']
            gt_bboxes_xyxy = batch_gt_instances[:, 2:]
            xy1, xy2 = gt_bboxes_xyxy.split((2, 2), dim=-1)
            gt_bboxes_xywh = torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
            gt_bboxes_xywh[:, 1::2] /= img_shape[0]
            gt_bboxes_xywh[:, 0::2] /= img_shape[1]
            # Build a new tensor instead of writing the converted boxes back
            # into ``batch_gt_instances``, so that the caller's xyxy
            # annotations are left untouched.
            batch_targets = torch.cat(
                (batch_gt_instances[:, :2], gt_bboxes_xywh), dim=-1)

            # (num_base_priors, num_bboxes, 6)
            batch_targets_normed = batch_targets.repeat(
                self.num_base_priors, 1, 1)
        else:
            batch_target_list = []
            # Convert xyxy bbox to yolo format.
            for i, gt_instances in enumerate(batch_gt_instances):
                img_shape = batch_img_metas[i]['batch_input_shape']
                bboxes = gt_instances.bboxes
                labels = gt_instances.labels

                xy1, xy2 = bboxes.split((2, 2), dim=-1)
                bboxes = torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
                # normalized to 0-1
                bboxes[:, 1::2] /= img_shape[0]
                bboxes[:, 0::2] /= img_shape[1]

                index = bboxes.new_full((len(bboxes), 1), i)
                # (batch_idx, label, normed_bbox)
                target = torch.cat((index, labels[:, None].float(), bboxes),
                                   dim=1)
                batch_target_list.append(target)

            # (num_base_priors, num_bboxes, 6)
            batch_targets_normed = torch.cat(
                batch_target_list, dim=0).repeat(self.num_base_priors, 1, 1)

        # (num_base_priors, num_bboxes, 1)
        batch_targets_prior_inds = self.prior_inds.repeat(
            1, batch_targets_normed.shape[1])[..., None]
        # (num_base_priors, num_bboxes, 7)
        # (img_ind, labels, bbox_cx, bbox_cy, bbox_w, bbox_h, prior_ind)
        batch_targets_normed = torch.cat(
            (batch_targets_normed, batch_targets_prior_inds), 2)
        return batch_targets_normed

    def _decode_bbox_to_xywh(self, bbox_pred, priors_base_sizes) -> Tensor:
        """Decode raw box predictions to (x, y, w, h) for the loss.

        Args:
            bbox_pred (Tensor): Raw predictions, 4 values per row.
            priors_base_sizes (Tensor): Prior sizes multiplied onto the
                decoded width and height.

        Returns:
            Tensor: Decoded boxes. After the sigmoid, xy lies in
            (-0.5, 1.5) and wh in (0, 4) times the prior size.
        """
        bbox_pred = bbox_pred.sigmoid()
        pred_xy = bbox_pred[:, :2] * 2 - 0.5
        pred_wh = (bbox_pred[:, 2:] * 2)**2 * priors_base_sizes
        decoded_bbox_pred = torch.cat((pred_xy, pred_wh), dim=-1)
        return decoded_bbox_pred

    def _loss_by_feat_with_ignore(
            self, cls_scores: Sequence[Tensor], bbox_preds: Sequence[Tensor],
            objectnesses: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: Sequence[Tensor]) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Priors that overlap an ignored box by more than ``ignore_iof_thr``
        (IoF) get zero weight in every loss term.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (Sequence[Tensor]): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors, H, W).
            batch_gt_instances (Sequence[InstanceData]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (Sequence[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (Sequence[Tensor]): Ignore boxes with
                batch_ids and labels, each is a 2D-tensor, the channel number
                is 6, means that (batch_id, label, xmin, ymin, xmax, ymax).
        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        # 1. Convert gt to norm format
        batch_targets_normed = self._convert_gt_to_norm_format(
            batch_gt_instances, batch_img_metas)

        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        # Regenerate the priors only when the feature map sizes changed.
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes

        device = cls_scores[0].device
        loss_cls = torch.zeros(1, device=device)
        loss_box = torch.zeros(1, device=device)
        loss_obj = torch.zeros(1, device=device)
        scaled_factor = torch.ones(7, device=device)

        for i in range(self.num_levels):
            batch_size, _, h, w = bbox_preds[i].shape
            target_obj = torch.zeros_like(objectnesses[i])

            # 1 for positions that take part in the loss, 0 for ignored ones.
            not_ignore_flags = bbox_preds[i].new_ones(batch_size,
                                                      self.num_base_priors, h,
                                                      w)

            ignore_overlaps = bbox_overlaps(self.mlvl_priors[i],
                                            batch_gt_instances_ignore[..., 2:],
                                            'iof')
            ignore_max_overlaps, ignore_max_ignore_index = ignore_overlaps.max(
                dim=1)

            # Image index of the ignore box each prior overlaps most.
            batch_inds = batch_gt_instances_ignore[:,
                                                   0][ignore_max_ignore_index]
            ignore_inds = (ignore_max_overlaps > self.ignore_iof_thr).nonzero(
                as_tuple=True)[0]
            batch_inds = batch_inds[ignore_inds].long()
            ignore_priors, ignore_grid_xs, ignore_grid_ys = get_prior_xy_info(
                ignore_inds, self.num_base_priors, self.featmap_sizes[i])
            not_ignore_flags[batch_inds, ignore_priors, ignore_grid_ys,
                             ignore_grid_xs] = 0

            # empty gt bboxes
            if batch_targets_normed.shape[1] == 0:
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i],
                    target_obj,
                    weight=not_ignore_flags,
                    avg_factor=max(not_ignore_flags.sum(),
                                   1)) * self.obj_level_weights[i]
                continue

            priors_base_sizes_i = self.priors_base_sizes[i]
            # feature map scale whwh
            scaled_factor[2:6] = torch.tensor(
                bbox_preds[i].shape)[[3, 2, 3, 2]]
            # Scale batch_targets from range 0-1 to range 0-features_maps size.
            # (num_base_priors, num_bboxes, 7)
            batch_targets_scaled = batch_targets_normed * scaled_factor

            # 2. Shape match
            wh_ratio = batch_targets_scaled[...,
                                            4:6] / priors_base_sizes_i[:, None]
            match_inds = torch.max(
                wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
            batch_targets_scaled = batch_targets_scaled[match_inds]

            # no gt bbox matches anchor
            if batch_targets_scaled.shape[0] == 0:
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i],
                    target_obj,
                    weight=not_ignore_flags,
                    avg_factor=max(not_ignore_flags.sum(),
                                   1)) * self.obj_level_weights[i]
                continue

            # 3. Positive samples with additional neighbors

            # check the left, up, right, bottom sides of the
            # targets grid, and determine whether assigned
            # them as positive samples as well.
            batch_targets_cxcy = batch_targets_scaled[:, 2:4]
            grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
            left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
                        (batch_targets_cxcy > 1)).T
            right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
                             (grid_xy > 1)).T
            offset_inds = torch.stack(
                (torch.ones_like(left), left, up, right, bottom))

            batch_targets_scaled = batch_targets_scaled.repeat(
                (5, 1, 1))[offset_inds]
            retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
                                                       1)[offset_inds]

            # prepare pred results and positive sample indexes to
            # calculate class loss and bbox loss
            _chunk_targets = batch_targets_scaled.chunk(4, 1)
            img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
            priors_inds, (img_inds, class_inds) = priors_inds.long().view(
                -1), img_class_inds.long().T

            grid_xy_long = (grid_xy -
                            retained_offsets * self.near_neighbor_thr).long()
            grid_x_inds, grid_y_inds = grid_xy_long.T
            bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)

            # 4. Calculate loss
            # bbox loss
            retained_bbox_pred = bbox_preds[i].reshape(
                batch_size, self.num_base_priors, -1, h,
                w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
            priors_base_sizes_i = priors_base_sizes_i[priors_inds]
            decoded_bbox_pred = self._decode_bbox_to_xywh(
                retained_bbox_pred, priors_base_sizes_i)

            not_ignore_weights = not_ignore_flags[img_inds, priors_inds,
                                                  grid_y_inds, grid_x_inds]
            loss_box_i, iou = self.loss_bbox(
                decoded_bbox_pred,
                bboxes_targets,
                weight=not_ignore_weights,
                avg_factor=max(not_ignore_weights.sum(), 1))
            loss_box += loss_box_i

            # obj loss
            iou = iou.detach().clamp(0)
            target_obj[img_inds, priors_inds, grid_y_inds,
                       grid_x_inds] = iou.type(target_obj.dtype)
            loss_obj += self.loss_obj(
                objectnesses[i],
                target_obj,
                weight=not_ignore_flags,
                avg_factor=max(not_ignore_flags.sum(),
                               1)) * self.obj_level_weights[i]

            # cls loss
            if self.num_classes > 1:
                pred_cls_scores = cls_scores[i].reshape(
                    batch_size, self.num_base_priors, -1, h,
                    w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]

                target_class = torch.full_like(pred_cls_scores, 0.)
                target_class[range(batch_targets_scaled.shape[0]),
                             class_inds] = 1.
                loss_cls += self.loss_cls(
                    pred_cls_scores,
                    target_class,
                    weight=not_ignore_weights[:, None].repeat(
                        1, self.num_classes),
                    avg_factor=max(not_ignore_weights.sum(), 1))
            else:
                loss_cls += cls_scores[i].sum() * 0

        _, world_size = get_dist_info()
        return dict(
            loss_cls=loss_cls * batch_size * world_size,
            loss_obj=loss_obj * batch_size * world_size,
            loss_bbox=loss_box * batch_size * world_size)