# KyanChen's picture
# Upload 89 files
# 3094730
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from mmdet.models.dense_heads.base_dense_head import BaseDenseHead
from mmdet.models.utils import filter_scores_and_topk, multi_apply
from mmdet.structures.bbox import bbox_overlaps
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
OptMultiConfig)
from mmengine.config import ConfigDict
from mmengine.dist import get_dist_info
from mmengine.logging import print_log
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import make_divisible
def get_prior_xy_info(
    index: Union[int, Tensor], num_base_priors: int,
    featmap_sizes: Sequence[int]
) -> Tuple[Union[int, Tensor], Union[int, Tensor], Union[int, Tensor]]:
    """Get prior index and xy index in feature map by flatten index.

    The flat index is decomposed as
    ``(grid_y * featmap_w + grid_x) * num_base_priors + prior``.

    Args:
        index (int or Tensor): Flattened prior index (or a tensor of such
            indices; the arithmetic is applied element-wise).
        num_base_priors (int): Number of base priors per grid cell.
        featmap_sizes (Sequence[int]): Two-element size of the feature map.
            Only the second element (the width) is used.

    Returns:
        tuple: ``(prior, grid_x, grid_y)`` with the same type as ``index``.
    """
    # Only the width is needed to split the cell index into (y, x).
    _, featmap_w = featmap_sizes
    priors = index % num_base_priors
    xy_index = index // num_base_priors
    grid_y = xy_index // featmap_w
    grid_x = xy_index % featmap_w
    return priors, grid_x, grid_y
@MODELS.register_module()
class YOLOv5HeadModule(BaseModule):
    """Prediction convolutions of the `YOLOv5` head.

    Each feature level owns one 1x1 convolution producing
    ``num_base_priors * (5 + num_classes)`` channels, which are then split
    into class scores, box predictions and objectness.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (Union[int, Sequence]): Number of channels in the input
            feature map. An int is shared by every level.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid. Defaults to 3.
        featmap_strides (Sequence[int]): Downsample factor of each feature
            map. Defaults to (8, 16, 32).
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: Union[int, Sequence],
                 widen_factor: float = 1.0,
                 num_base_priors: int = 3,
                 featmap_strides: Sequence[int] = (8, 16, 32),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.widen_factor = widen_factor
        self.featmap_strides = featmap_strides
        self.num_base_priors = num_base_priors
        # 4 box terms + 1 objectness + one score per class.
        self.num_out_attrib = 5 + self.num_classes
        self.num_levels = len(self.featmap_strides)

        # Normalise ``in_channels`` to one entry per level before widening.
        if isinstance(in_channels, int):
            level_channels = [in_channels] * self.num_levels
        else:
            level_channels = in_channels
        self.in_channels = [
            make_divisible(channels, widen_factor)
            for channels in level_channels
        ]

        self._init_layers()

    def _init_layers(self):
        """Build one 1x1 prediction conv per feature level."""
        out_channels = self.num_base_priors * self.num_out_attrib
        self.convs_pred = nn.ModuleList([
            nn.Conv2d(self.in_channels[level], out_channels, 1)
            for level in range(self.num_levels)
        ])

    def init_weights(self):
        """Initialize the bias of YOLOv5 head."""
        super().init_weights()
        for conv, stride in zip(self.convs_pred, self.featmap_strides):
            # View the flat bias as (prior, attribute) for slicing.
            bias = conv.bias.data.view(self.num_base_priors, -1)
            # obj (8 objects per 640 image)
            bias.data[:, 4] += math.log(8 / (640 / stride)**2)
            bias.data[:, 5:] += math.log(0.6 /
                                         (self.num_classes - 0.999999))
            conv.bias.data = bias.view(-1)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Run the prediction conv of every level.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor. Must contain exactly ``num_levels`` entries.

        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions, and objectnesses.
        """
        assert len(x) == self.num_levels
        return multi_apply(self.forward_single, x, self.convs_pred)

    def forward_single(self, x: Tensor,
                       convs: nn.Module) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward feature of a single scale level.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(cls_score, bbox_pred,
            objectness)``, each reshaped back to a 4D-tensor with the prior
            and attribute axes merged into the channel axis.
        """
        raw = convs(x)
        batch, _, height, width = raw.shape
        per_prior = raw.view(batch, self.num_base_priors,
                             self.num_out_attrib, height, width)

        def merge(chunk: Tensor) -> Tensor:
            # Fold (prior, attribute) back into a single channel axis.
            return chunk.reshape(batch, -1, height, width)

        bbox_pred = merge(per_prior[:, :, :4])
        objectness = merge(per_prior[:, :, 4:5])
        cls_score = merge(per_prior[:, :, 5:])
        return cls_score, bbox_pred, objectness
@MODELS.register_module()
class YOLOv5Head(BaseDenseHead):
    """YOLOv5Head head used in `YOLOv5`.

    Args:
        head_module(ConfigType): Base module used for YOLOv5Head
        prior_generator(dict): Points generator feature maps in
            2D points-based detectors.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
        prior_match_thr (float): A gt box is matched to a base prior only
            when the larger of their width/height ratios (taken in either
            direction) is below this value. Defaults to 4.0.
        near_neighbor_thr (float): When the gt centre lies closer than this
            fraction of a cell to a cell border, the neighbouring cell on
            that side is also used as a positive sample. Defaults to 0.5.
        ignore_iof_thr (float): Priors whose IoF with an ignored gt box is
            above this value get zero loss weight. ``-1`` disables the
            ignore handling. Defaults to -1.0.
        obj_level_weights (List[float]): Per-level weights applied to the
            objectness loss. Defaults to [4.0, 1.0, 0.4].
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.YOLOAnchorGenerator',
                     base_sizes=[[(10, 13), (16, 30), (33, 23)],
                                 [(30, 61), (62, 45), (59, 119)],
                                 [(116, 90), (156, 198), (373, 326)]],
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='YOLOv5BBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=0.5),
                 loss_bbox: ConfigType = dict(
                     type='IoULoss',
                     iou_mode='ciou',
                     bbox_format='xywh',
                     eps=1e-7,
                     reduction='mean',
                     loss_weight=0.05,
                     return_iou=True),
                 loss_obj: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=1.0),
                 prior_match_thr: float = 4.0,
                 near_neighbor_thr: float = 0.5,
                 ignore_iof_thr: float = -1.0,
                 obj_level_weights: List[float] = [4.0, 1.0, 0.4],
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)

        self.head_module = MODELS.build(head_module)
        self.num_classes = self.head_module.num_classes
        self.featmap_strides = self.head_module.featmap_strides
        self.num_levels = len(self.featmap_strides)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.loss_cls: nn.Module = MODELS.build(loss_cls)
        self.loss_bbox: nn.Module = MODELS.build(loss_bbox)
        self.loss_obj: nn.Module = MODELS.build(loss_obj)

        self.prior_generator = TASK_UTILS.build(prior_generator)
        self.bbox_coder = TASK_UTILS.build(bbox_coder)
        self.num_base_priors = self.prior_generator.num_base_priors[0]

        # Placeholder sizes; replaced by the real feature-map sizes the
        # first time priors are generated.
        self.featmap_sizes = [torch.empty(1)] * self.num_levels

        self.prior_match_thr = prior_match_thr
        self.near_neighbor_thr = near_neighbor_thr
        # Copy so the instance never aliases the shared default list.
        self.obj_level_weights = list(obj_level_weights)
        self.ignore_iof_thr = ignore_iof_thr

        self.special_init()

    def special_init(self):
        """Since YOLO series algorithms will inherit from YOLOv5Head, but
        different algorithms have special initialization process.

        The special_init function is designed to deal with this situation.
        Here it validates the per-level settings and registers the
        non-persistent buffers used for target assignment.
        """
        assert len(self.obj_level_weights) == len(
            self.featmap_strides) == self.num_levels
        if self.prior_match_thr != 4.0:
            print_log(
                "!!!Now, you've changed the prior_match_thr "
                'parameter to something other than 4.0. Please make sure '
                'that you have modified both the regression formula in '
                'bbox_coder and before loss_box computation, '
                'otherwise the accuracy may be degraded!!!')

        if self.num_classes == 1:
            print_log('!!!You are using `YOLOv5Head` with num_classes == 1.'
                      ' The loss_cls will be 0. This is a normal phenomenon.')

        # Base prior sizes expressed in units of each level's stride.
        priors_base_sizes = torch.tensor(
            self.prior_generator.base_sizes, dtype=torch.float)
        featmap_strides = torch.tensor(
            self.featmap_strides, dtype=torch.float)[:, None, None]
        self.register_buffer(
            'priors_base_sizes',
            priors_base_sizes / featmap_strides,
            persistent=False)

        grid_offset = torch.tensor([
            [0, 0],  # center
            [1, 0],  # left
            [0, 1],  # up
            [-1, 0],  # right
            [0, -1],  # bottom
        ]).float()
        self.register_buffer(
            'grid_offset', grid_offset[:, None], persistent=False)

        prior_inds = torch.arange(self.num_base_priors).float().view(
            self.num_base_priors, 1)
        self.register_buffer('prior_inds', prior_inds, persistent=False)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions, and objectnesses.
        """
        return self.head_module(x)

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted by the head into
        bbox results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], Optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors, H, W). Defaults to None.
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Although the default is None, the list is required: its
                length gives the number of images.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)
        if objectnesses is None:
            with_objectnesses = False
        else:
            with_objectnesses = True
            assert len(cls_scores) == len(objectnesses)

        # Work on a copy: the config is modified below.
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        multi_label = cfg.multi_label
        multi_label &= self.num_classes > 1
        cfg.multi_label = multi_label

        # These settings do not change between images.
        score_thr = cfg.get('score_thr', -1)
        nms_pre = cfg.get('nms_pre', 100000)
        yolox_style = cfg.get('yolox_style', False)

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        # If the shape does not change, use the previous mlvl_priors
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes
        flatten_priors = torch.cat(self.mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size.numel() * self.num_base_priors, ), stride) for
            featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_decoded_bboxes = self.bbox_coder.decode(
            flatten_priors[None], flatten_bbox_preds, flatten_stride)

        if with_objectnesses:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        else:
            flatten_objectness = [None for _ in range(num_imgs)]

        results_list = []
        for (bboxes, scores, objectness,
             img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
                              flatten_objectness, batch_img_metas):
            ori_shape = img_meta['ori_shape']
            scale_factor = img_meta['scale_factor']
            pad_param = img_meta.get('pad_param')

            # yolox_style does not require the following operations
            if (objectness is not None and score_thr > 0
                    and not yolox_style):
                conf_inds = objectness > score_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]

            if objectness is not None:
                # conf = obj_conf * cls_conf
                scores *= objectness[:, None]

            if scores.shape[0] == 0:
                empty_results = InstanceData()
                empty_results.bboxes = bboxes
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                results_list.append(empty_results)
                continue

            if cfg.multi_label is False:
                scores, labels = scores.max(1, keepdim=True)
                scores, _, keep_idxs, results = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(labels=labels[:, 0]))
                labels = results['labels']
            else:
                scores, labels, keep_idxs, _ = filter_scores_and_topk(
                    scores, score_thr, nms_pre)

            results = InstanceData(
                scores=scores, labels=labels, bboxes=bboxes[keep_idxs])

            if rescale:
                if pad_param is not None:
                    # x coordinates use pad_param[2], y use pad_param[0].
                    results.bboxes -= results.bboxes.new_tensor([
                        pad_param[2], pad_param[0], pad_param[2], pad_param[0]
                    ])
                results.bboxes /= results.bboxes.new_tensor(
                    scale_factor).repeat((1, 2))

            if yolox_style:
                # do not need max_per_img
                cfg.max_per_img = len(results)

            results = self._bbox_post_process(
                results=results,
                cfg=cfg,
                rescale=False,
                with_nms=with_nms,
                img_meta=img_meta)
            # Clip to the original image: ori_shape is indexed as (h, w).
            results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
            results.bboxes[:, 1::2].clamp_(0, ori_shape[0])

            results_list.append(results)
        return results_list

    def loss(self, x: Tuple[Tensor], batch_data_samples: Union[list,
                                                               dict]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. A dict
                selects the fast path and must provide the keys
                ``bboxes_labels`` and ``img_metas``.

        Returns:
            dict: A dictionary of loss components.
        """
        if isinstance(batch_data_samples, list):
            losses = super().loss(x, batch_data_samples)
        else:
            outs = self(x)
            # Fast version
            loss_inputs = outs + (batch_data_samples['bboxes_labels'],
                                  batch_data_samples['img_metas'])
            losses = self.loss_by_feat(*loss_inputs)

        return losses

    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            objectnesses: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (Sequence[Tensor]): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors, H, W).
            batch_gt_instances (Sequence[InstanceData]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes. In the fast path this is a single 2D-tensor
                instead (see ``_convert_gt_to_norm_format``).
            batch_img_metas (Sequence[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing. It is
                iterated whenever ``ignore_iof_thr`` is not -1.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        if self.ignore_iof_thr != -1:
            # TODO: Support fast version
            # convert ignore gt
            batch_target_ignore_list = []
            for i, gt_instances_ignore in enumerate(batch_gt_instances_ignore):
                bboxes = gt_instances_ignore.bboxes
                labels = gt_instances_ignore.labels
                index = bboxes.new_full((len(bboxes), 1), i)
                # (batch_idx, label, bboxes)
                target = torch.cat((index, labels[:, None].float(), bboxes),
                                   dim=1)
                batch_target_ignore_list.append(target)

            # (num_bboxes, 6)
            batch_gt_targets_ignore = torch.cat(
                batch_target_ignore_list, dim=0)

            if batch_gt_targets_ignore.shape[0] != 0:
                # Consider regions with ignore in annotations
                return self._loss_by_feat_with_ignore(
                    cls_scores,
                    bbox_preds,
                    objectnesses,
                    batch_gt_instances=batch_gt_instances,
                    batch_img_metas=batch_img_metas,
                    batch_gt_instances_ignore=batch_gt_targets_ignore)

        # 1. Convert gt to norm format
        batch_targets_normed = self._convert_gt_to_norm_format(
            batch_gt_instances, batch_img_metas)

        device = cls_scores[0].device
        loss_cls = torch.zeros(1, device=device)
        loss_box = torch.zeros(1, device=device)
        loss_obj = torch.zeros(1, device=device)
        scaled_factor = torch.ones(7, device=device)

        for i in range(self.num_levels):
            batch_size, _, h, w = bbox_preds[i].shape
            target_obj = torch.zeros_like(objectnesses[i])

            # empty gt bboxes
            if batch_targets_normed.shape[1] == 0:
                # ``* 0`` keeps the predictions in the graph with zero loss.
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i], target_obj) * self.obj_level_weights[i]
                continue

            priors_base_sizes_i = self.priors_base_sizes[i]
            # feature map scale whwh
            scaled_factor[2:6] = torch.tensor(
                bbox_preds[i].shape)[[3, 2, 3, 2]]
            # Scale batch_targets from range 0-1 to range 0-features_maps size.
            # (num_base_priors, num_bboxes, 7)
            batch_targets_scaled = batch_targets_normed * scaled_factor

            # 2. Shape match
            wh_ratio = batch_targets_scaled[...,
                                            4:6] / priors_base_sizes_i[:, None]
            match_inds = torch.max(
                wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
            batch_targets_scaled = batch_targets_scaled[match_inds]

            # no gt bbox matches anchor
            if batch_targets_scaled.shape[0] == 0:
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i], target_obj) * self.obj_level_weights[i]
                continue

            # 3. Positive samples with additional neighbors

            # check the left, up, right, bottom sides of the
            # targets grid, and determine whether assigned
            # them as positive samples as well.
            batch_targets_cxcy = batch_targets_scaled[:, 2:4]
            grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
            left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
                        (batch_targets_cxcy > 1)).T
            right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
                             (grid_xy > 1)).T
            offset_inds = torch.stack(
                (torch.ones_like(left), left, up, right, bottom))

            batch_targets_scaled = batch_targets_scaled.repeat(
                (5, 1, 1))[offset_inds]
            retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
                                                       1)[offset_inds]

            # prepare pred results and positive sample indexes to
            # calculate class loss and bbox loss
            _chunk_targets = batch_targets_scaled.chunk(4, 1)
            img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
            priors_inds, (img_inds, class_inds) = priors_inds.long().view(
                -1), img_class_inds.long().T

            grid_xy_long = (grid_xy -
                            retained_offsets * self.near_neighbor_thr).long()
            grid_x_inds, grid_y_inds = grid_xy_long.T
            bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)

            # 4. Calculate loss
            # bbox loss
            retained_bbox_pred = bbox_preds[i].reshape(
                batch_size, self.num_base_priors, -1, h,
                w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
            priors_base_sizes_i = priors_base_sizes_i[priors_inds]
            decoded_bbox_pred = self._decode_bbox_to_xywh(
                retained_bbox_pred, priors_base_sizes_i)
            loss_box_i, iou = self.loss_bbox(decoded_bbox_pred, bboxes_targets)
            loss_box += loss_box_i

            # obj loss
            iou = iou.detach().clamp(0)
            target_obj[img_inds, priors_inds, grid_y_inds,
                       grid_x_inds] = iou.type(target_obj.dtype)
            loss_obj += self.loss_obj(objectnesses[i],
                                      target_obj) * self.obj_level_weights[i]

            # cls loss
            if self.num_classes > 1:
                pred_cls_scores = cls_scores[i].reshape(
                    batch_size, self.num_base_priors, -1, h,
                    w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]

                target_class = torch.full_like(pred_cls_scores, 0.)
                target_class[range(batch_targets_scaled.shape[0]),
                             class_inds] = 1.
                loss_cls += self.loss_cls(pred_cls_scores, target_class)
            else:
                loss_cls += cls_scores[i].sum() * 0

        _, world_size = get_dist_info()
        return dict(
            loss_cls=loss_cls * batch_size * world_size,
            loss_obj=loss_obj * batch_size * world_size,
            loss_bbox=loss_box * batch_size * world_size)

    def _convert_gt_to_norm_format(self,
                                   batch_gt_instances: Sequence[InstanceData],
                                   batch_img_metas: Sequence[dict]) -> Tensor:
        """Convert gt boxes to normalized (cx, cy, w, h) training targets.

        Args:
            batch_gt_instances (Sequence[InstanceData] or Tensor): Either a
                list of per-image gt instances with ``bboxes`` (xyxy) and
                ``labels``, or (fast path) a 2D-tensor whose columns are
                (img_ind, label, x1, y1, x2, y2). The input is not modified.
            batch_img_metas (Sequence[dict]): Meta information of each
                image; ``batch_input_shape`` is used for normalization.

        Returns:
            Tensor: Targets of shape (num_base_priors, num_bboxes, 7) with
            columns (img_ind, labels, bbox_cx, bbox_cy, bbox_w, bbox_h,
            prior_ind).
        """
        if isinstance(batch_gt_instances, torch.Tensor):
            # fast version
            img_shape = batch_img_metas[0]['batch_input_shape']
            gt_bboxes_xyxy = batch_gt_instances[:, 2:]
            xy1, xy2 = gt_bboxes_xyxy.split((2, 2), dim=-1)
            gt_bboxes_xywh = torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
            gt_bboxes_xywh[:, 1::2] /= img_shape[0]
            gt_bboxes_xywh[:, 0::2] /= img_shape[1]
            # Build a new tensor instead of writing the converted boxes
            # back into the caller's tensor; otherwise a second call on the
            # same batch would convert already-converted boxes.
            batch_targets = torch.cat(
                (batch_gt_instances[:, :2], gt_bboxes_xywh), dim=-1)

            # (num_base_priors, num_bboxes, 6)
            batch_targets_normed = batch_targets.repeat(
                self.num_base_priors, 1, 1)
        else:
            batch_target_list = []
            # Convert xyxy bbox to yolo format.
            for i, gt_instances in enumerate(batch_gt_instances):
                img_shape = batch_img_metas[i]['batch_input_shape']
                bboxes = gt_instances.bboxes
                labels = gt_instances.labels

                xy1, xy2 = bboxes.split((2, 2), dim=-1)
                bboxes = torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
                # normalized to 0-1
                bboxes[:, 1::2] /= img_shape[0]
                bboxes[:, 0::2] /= img_shape[1]

                index = bboxes.new_full((len(bboxes), 1), i)
                # (batch_idx, label, normed_bbox)
                target = torch.cat((index, labels[:, None].float(), bboxes),
                                   dim=1)
                batch_target_list.append(target)

            # (num_base_priors, num_bboxes, 6)
            batch_targets_normed = torch.cat(
                batch_target_list, dim=0).repeat(self.num_base_priors, 1, 1)

        # (num_base_priors, num_bboxes, 1)
        batch_targets_prior_inds = self.prior_inds.repeat(
            1, batch_targets_normed.shape[1])[..., None]
        # (num_base_priors, num_bboxes, 7)
        # (img_ind, labels, bbox_cx, bbox_cy, bbox_w, bbox_h, prior_ind)
        batch_targets_normed = torch.cat(
            (batch_targets_normed, batch_targets_prior_inds), 2)
        return batch_targets_normed

    def _decode_bbox_to_xywh(self, bbox_pred: Tensor,
                             priors_base_sizes: Tensor) -> Tensor:
        """Decode raw box predictions into (x, y, w, h).

        After a sigmoid, xy is mapped to ``2 * s - 0.5`` and wh to
        ``(2 * s) ** 2`` times the matching prior size.

        Args:
            bbox_pred (Tensor): Raw predictions; the first two columns are
                xy and the remaining ones wh.
            priors_base_sizes (Tensor): Prior sizes multiplied onto wh.

        Returns:
            Tensor: Decoded boxes, xy and wh concatenated on the last dim.
        """
        bbox_pred = bbox_pred.sigmoid()
        pred_xy = bbox_pred[:, :2] * 2 - 0.5
        pred_wh = (bbox_pred[:, 2:] * 2)**2 * priors_base_sizes
        decoded_bbox_pred = torch.cat((pred_xy, pred_wh), dim=-1)
        return decoded_bbox_pred

    def _loss_by_feat_with_ignore(
            self, cls_scores: Sequence[Tensor], bbox_preds: Sequence[Tensor],
            objectnesses: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: Sequence[Tensor]) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head, giving zero weight to priors that fall in ignored regions.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (Sequence[Tensor]): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors, H, W).
            batch_gt_instances (Sequence[InstanceData]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (Sequence[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (Sequence[Tensor]): Ignore boxes with
                batch_ids and labels, each is a 2D-tensor, the channel number
                is 6, means that (batch_id, label, xmin, ymin, xmax, ymax).

        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        # 1. Convert gt to norm format
        batch_targets_normed = self._convert_gt_to_norm_format(
            batch_gt_instances, batch_img_metas)

        # Regenerate the cached priors only when the sizes changed.
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes

        device = cls_scores[0].device
        loss_cls = torch.zeros(1, device=device)
        loss_box = torch.zeros(1, device=device)
        loss_obj = torch.zeros(1, device=device)
        scaled_factor = torch.ones(7, device=device)

        for i in range(self.num_levels):
            batch_size, _, h, w = bbox_preds[i].shape
            target_obj = torch.zeros_like(objectnesses[i])

            # 1 = contributes to the loss, 0 = ignored.
            not_ignore_flags = bbox_preds[i].new_ones(batch_size,
                                                      self.num_base_priors, h,
                                                      w)

            ignore_overlaps = bbox_overlaps(self.mlvl_priors[i],
                                            batch_gt_instances_ignore[..., 2:],
                                            'iof')
            ignore_max_overlaps, ignore_max_ignore_index = ignore_overlaps.max(
                dim=1)

            batch_inds = batch_gt_instances_ignore[:,
                                                   0][ignore_max_ignore_index]
            ignore_inds = (ignore_max_overlaps > self.ignore_iof_thr).nonzero(
                as_tuple=True)[0]
            batch_inds = batch_inds[ignore_inds].long()
            ignore_priors, ignore_grid_xs, ignore_grid_ys = get_prior_xy_info(
                ignore_inds, self.num_base_priors, self.featmap_sizes[i])
            not_ignore_flags[batch_inds, ignore_priors, ignore_grid_ys,
                             ignore_grid_xs] = 0

            # empty gt bboxes
            if batch_targets_normed.shape[1] == 0:
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i],
                    target_obj,
                    weight=not_ignore_flags,
                    avg_factor=max(not_ignore_flags.sum(),
                                   1)) * self.obj_level_weights[i]
                continue

            priors_base_sizes_i = self.priors_base_sizes[i]
            # feature map scale whwh
            scaled_factor[2:6] = torch.tensor(
                bbox_preds[i].shape)[[3, 2, 3, 2]]
            # Scale batch_targets from range 0-1 to range 0-features_maps size.
            # (num_base_priors, num_bboxes, 7)
            batch_targets_scaled = batch_targets_normed * scaled_factor

            # 2. Shape match
            wh_ratio = batch_targets_scaled[...,
                                            4:6] / priors_base_sizes_i[:, None]
            match_inds = torch.max(
                wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
            batch_targets_scaled = batch_targets_scaled[match_inds]

            # no gt bbox matches anchor
            if batch_targets_scaled.shape[0] == 0:
                loss_box += bbox_preds[i].sum() * 0
                loss_cls += cls_scores[i].sum() * 0
                loss_obj += self.loss_obj(
                    objectnesses[i],
                    target_obj,
                    weight=not_ignore_flags,
                    avg_factor=max(not_ignore_flags.sum(),
                                   1)) * self.obj_level_weights[i]
                continue

            # 3. Positive samples with additional neighbors

            # check the left, up, right, bottom sides of the
            # targets grid, and determine whether assigned
            # them as positive samples as well.
            batch_targets_cxcy = batch_targets_scaled[:, 2:4]
            grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
            left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
                        (batch_targets_cxcy > 1)).T
            right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
                             (grid_xy > 1)).T
            offset_inds = torch.stack(
                (torch.ones_like(left), left, up, right, bottom))

            batch_targets_scaled = batch_targets_scaled.repeat(
                (5, 1, 1))[offset_inds]
            retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
                                                       1)[offset_inds]

            # prepare pred results and positive sample indexes to
            # calculate class loss and bbox loss
            _chunk_targets = batch_targets_scaled.chunk(4, 1)
            img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
            priors_inds, (img_inds, class_inds) = priors_inds.long().view(
                -1), img_class_inds.long().T

            grid_xy_long = (grid_xy -
                            retained_offsets * self.near_neighbor_thr).long()
            grid_x_inds, grid_y_inds = grid_xy_long.T
            bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)

            # 4. Calculate loss
            # bbox loss
            retained_bbox_pred = bbox_preds[i].reshape(
                batch_size, self.num_base_priors, -1, h,
                w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
            priors_base_sizes_i = priors_base_sizes_i[priors_inds]
            decoded_bbox_pred = self._decode_bbox_to_xywh(
                retained_bbox_pred, priors_base_sizes_i)

            not_ignore_weights = not_ignore_flags[img_inds, priors_inds,
                                                  grid_y_inds, grid_x_inds]
            loss_box_i, iou = self.loss_bbox(
                decoded_bbox_pred,
                bboxes_targets,
                weight=not_ignore_weights,
                avg_factor=max(not_ignore_weights.sum(), 1))
            loss_box += loss_box_i

            # obj loss
            iou = iou.detach().clamp(0)
            target_obj[img_inds, priors_inds, grid_y_inds,
                       grid_x_inds] = iou.type(target_obj.dtype)
            loss_obj += self.loss_obj(
                objectnesses[i],
                target_obj,
                weight=not_ignore_flags,
                avg_factor=max(not_ignore_flags.sum(),
                               1)) * self.obj_level_weights[i]

            # cls loss
            if self.num_classes > 1:
                pred_cls_scores = cls_scores[i].reshape(
                    batch_size, self.num_base_priors, -1, h,
                    w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]

                target_class = torch.full_like(pred_cls_scores, 0.)
                target_class[range(batch_targets_scaled.shape[0]),
                             class_inds] = 1.
                loss_cls += self.loss_cls(
                    pred_cls_scores,
                    target_class,
                    weight=not_ignore_weights[:, None].repeat(
                        1, self.num_classes),
                    avg_factor=max(not_ignore_weights.sum(), 1))
            else:
                loss_cls += cls_scores[i].sum() * 0

        _, world_size = get_dist_info()
        return dict(
            loss_cls=loss_cls * batch_size * world_size,
            loss_obj=loss_obj * batch_size * world_size,
            loss_bbox=loss_box * batch_size * world_size)