RSPrompter / mmyolo /models /dense_heads /rtmdet_rotated_head.py
KyanChen's picture
Upload 89 files
3094730
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
from mmdet.models.utils import filter_scores_and_topk
from mmdet.structures.bbox import HorizontalBoxes, distance2bbox
from mmdet.structures.bbox.transforms import bbox_cxcywh_to_xyxy, scale_boxes
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList, OptMultiConfig, reduce_mean)
from mmengine.config import ConfigDict
from mmengine.model import normal_init
from mmengine.structures import InstanceData
from torch import Tensor
from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import gt_instances_preprocess
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
try:
from mmrotate.structures.bbox import RotatedBoxes, distance2obb
MMROTATE_AVAILABLE = True
except ImportError:
RotatedBoxes = None
distance2obb = None
MMROTATE_AVAILABLE = False
@MODELS.register_module()
class RTMDetRotatedSepBNHeadModule(RTMDetSepBNHeadModule):
    """Head module of RTMDet-R with an extra angle branch.

    On top of the layers created by ``RTMDetSepBNHeadModule``, this module
    owns one additional prediction conv per feature level (``rtm_ang``)
    that outputs the encoded angle. The width of that output is
    ``num_base_priors * angle_out_dim``; ``angle_out_dim`` is normally
    supplied by the detection head from its angle coder.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid. Defaults to 1.
        feat_channels (int): Number of hidden channels. Used in child
            classes. Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 2.
        featmap_strides (Sequence[int]): Downsample factor of each feature
            map. Defaults to [8, 16, 32].
        share_conv (bool): Whether to share conv layers between stages.
            Defaults to True.
        pred_kernel_size (int): Kernel size of ``nn.Conv2d``. Defaults to 1.
        angle_out_dim (int): Encoded length of the angle, passed in by the
            head. Defaults to 1.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to ``dict(type='BN')``.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
            layer. Defaults to ``dict(type='SiLU', inplace=True)``.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        widen_factor: float = 1.0,
        num_base_priors: int = 1,
        feat_channels: int = 256,
        stacked_convs: int = 2,
        featmap_strides: Sequence[int] = [8, 16, 32],
        share_conv: bool = True,
        pred_kernel_size: int = 1,
        angle_out_dim: int = 1,
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN'),
        act_cfg: ConfigType = dict(type='SiLU', inplace=True),
        init_cfg: OptMultiConfig = None,
    ):
        # ``_init_layers`` reads ``self.angle_out_dim``; it is stored before
        # delegating to the parent constructor, which presumably triggers
        # layer construction (parent not visible here -- TODO confirm).
        self.angle_out_dim = angle_out_dim

        parent_kwargs = dict(
            num_classes=num_classes,
            in_channels=in_channels,
            widen_factor=widen_factor,
            num_base_priors=num_base_priors,
            feat_channels=feat_channels,
            stacked_convs=stacked_convs,
            featmap_strides=featmap_strides,
            share_conv=share_conv,
            pred_kernel_size=pred_kernel_size,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg,
        )
        super().__init__(**parent_kwargs)

    def _init_layers(self):
        """Build the parent's layers, then one angle conv per level."""
        super()._init_layers()

        ang_channels = self.num_base_priors * self.angle_out_dim
        kernel = self.pred_kernel_size
        self.rtm_ang = nn.ModuleList([
            nn.Conv2d(
                self.feat_channels,
                ang_channels,
                kernel,
                padding=kernel // 2) for _ in self.featmap_strides
        ])

    def init_weights(self) -> None:
        """Initialize the parent's weights, then the angle convs."""
        super().init_weights()
        # Angle prediction convs start from a narrow normal distribution.
        for ang_conv in self.rtm_ang:
            normal_init(ang_conv, std=0.01)

    def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
        """Run the classification, regression and angle branches per level.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.

        Returns:
            tuple: Three tuples, each holding one tensor per input level.

            - cls_scores (tuple[Tensor]): Classification scores for all
              scale levels, each is a 4D-tensor, the channels number is
              num_base_priors * num_classes.
            - bbox_preds (tuple[Tensor]): Box energies / deltas for all
              scale levels, each is a 4D-tensor, the channels number is
              num_base_priors * 4.
            - angle_preds (tuple[Tensor]): Angle prediction for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * angle_out_dim.
        """
        level_cls, level_reg, level_ang = [], [], []

        for level, feat in enumerate(feats):
            # Classification branch.
            cls_branch = feat
            for conv in self.cls_convs[level]:
                cls_branch = conv(cls_branch)
            level_cls.append(self.rtm_cls[level](cls_branch))

            # Regression branch; the angle conv shares its features.
            reg_branch = feat
            for conv in self.reg_convs[level]:
                reg_branch = conv(reg_branch)
            level_reg.append(self.rtm_reg[level](reg_branch))
            level_ang.append(self.rtm_ang[level](reg_branch))

        return tuple(level_cls), tuple(level_reg), tuple(level_ang)
@MODELS.register_module()
class RTMDetRotatedHead(RTMDetHead):
    """RTMDet-R head.

    Compared with RTMDetHead, RTMDetRotatedHead adds some args to support
    rotated object detection.

    - `angle_version` used to limit angle_range during training.
    - `angle_coder` used to encode and decode angle, which is similar
      to bbox_coder.
    - `use_hbbox_loss` and `loss_angle` allow custom regression loss
      calculation for rotated box.

      There are three combination options for regression:

      1. `use_hbbox_loss=False` and loss_angle is None.

      .. code:: text

        bbox_pred────(tblr)───┐
                              ▼
        angle_pred          decode──►rbox_pred──(xywha)─►loss_bbox
            │                 ▲
            └────►decode──(a)─┘

      2. `use_hbbox_loss=False` and loss_angle is specified.
         An angle loss is added on angle_pred.

      .. code:: text

        bbox_pred────(tblr)───┐
                              ▼
        angle_pred          decode──►rbox_pred──(xywha)─►loss_bbox
            │                 ▲
            ├────►decode──(a)─┘
            │
            └───────────────────────────────────────────►loss_angle

      3. `use_hbbox_loss=True` and loss_angle is specified.
         In this case the loss_angle must be set.

      .. code:: text

        bbox_pred──(tblr)──►decode──►hbox_pred──(xyxy)──►loss_bbox

        angle_pred──────────────────────────────────────►loss_angle

    - There's a `decode_with_angle` flag in test_cfg (read with
      ``cfg.get('decode_with_angle', True)``), which is similar
      to the training process.

      When `decode_with_angle=True`:

      .. code:: text

        bbox_pred────(tblr)───┐
                              ▼
        angle_pred          decode──(xywha)──►rbox_pred
            │                 ▲
            └────►decode──(a)─┘

      When `decode_with_angle=False`:

      .. code:: text

        bbox_pred──(tblr)─►decode
                              │ (xyxy)
                              ▼
                           format───(xywh)──►concat──(xywha)──►rbox_pred
                                               ▲
        angle_pred────────►decode────(a)───────┘

    Args:
        head_module(ConfigType): Base module used for RTMDetRotatedHead.
            The passed config is not modified; ``angle_out_dim`` is written
            into a copy of it.
        prior_generator: Points generator feature maps in
            2D points-based detectors.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        angle_version (str): Angle representations. Defaults to 'le90'.
        use_hbbox_loss (bool): If true, use horizontal bbox loss and
            loss_angle should not be None. Default to False.
        angle_coder (:obj:`ConfigDict` or dict): Config of angle coder.
        loss_angle (:obj:`ConfigDict` or dict, optional): Config of angle
            loss. Defaults to None.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.

    Raises:
        ImportError: If mmrotate could not be imported.
        AssertionError: If ``use_hbbox_loss`` is True but ``loss_angle`` is
            None.
    """

    def __init__(
            self,
            head_module: ConfigType,
            prior_generator: ConfigType = dict(
                type='mmdet.MlvlPointGenerator', strides=[8, 16, 32],
                offset=0),
            bbox_coder: ConfigType = dict(type='DistanceAnglePointCoder'),
            loss_cls: ConfigType = dict(
                type='mmdet.QualityFocalLoss',
                use_sigmoid=True,
                beta=2.0,
                loss_weight=1.0),
            loss_bbox: ConfigType = dict(
                type='mmrotate.RotatedIoULoss', mode='linear',
                loss_weight=2.0),
            angle_version: str = 'le90',
            use_hbbox_loss: bool = False,
            angle_coder: ConfigType = dict(type='mmrotate.PseudoAngleCoder'),
            loss_angle: OptConfigType = None,
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            init_cfg: OptMultiConfig = None):
        if not MMROTATE_AVAILABLE:
            raise ImportError(
                'Please run "mim install -r requirements/mmrotate.txt" '
                'to install mmrotate first for rotated detection.')

        self.angle_version = angle_version
        self.use_hbbox_loss = use_hbbox_loss
        if self.use_hbbox_loss:
            assert loss_angle is not None, \
                ('When use hbbox loss, loss_angle needs to be specified')

        # The angle coder decides how many channels the angle branch of the
        # head module must output.
        self.angle_coder = TASK_UTILS.build(angle_coder)
        self.angle_out_dim = self.angle_coder.encode_size
        if head_module.get('angle_out_dim') is not None:
            warnings.warn('angle_out_dim will be overridden by angle_coder '
                          'and does not need to be set manually')

        # Write into a shallow copy so the caller's config stays untouched.
        # Mutating it in place would make a second build from the same
        # config hit the warning above even though the user never set
        # ``angle_out_dim``.
        head_module = copy.copy(head_module)
        head_module['angle_out_dim'] = self.angle_out_dim

        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

        if loss_angle is not None:
            self.loss_angle = MODELS.build(loss_angle)
        else:
            self.loss_angle = None

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        angle_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted by the head into
        rotated bbox results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            angle_preds (list[Tensor]): Box angle for each scale level
                with shape (N, num_points * angle_dim, H, W).
            objectnesses (list[Tensor], Optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, 1, H, W). Defaults to None.
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Although the default is None, the list is required in
                practice: its length is used as the batch size.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used. A deep copy
                is taken, so the passed config is not modified.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 5),
              the last dimension 5 arrange as (x, y, w, h, angle).
        """
        # Every prediction list must provide one tensor per feature level.
        assert len(cls_scores) == len(bbox_preds) == len(angle_preds)
        with_objectnesses = objectnesses is not None
        if with_objectnesses:
            assert len(cls_scores) == len(objectnesses)

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        # Multi-label post-processing only makes sense with >1 classes.
        multi_label = cfg.multi_label
        multi_label &= self.num_classes > 1
        cfg.multi_label = multi_label

        # Whether to decode rbox with angle.
        # different setting lead to different final results.
        # Defaults to True.
        decode_with_angle = cfg.get('decode_with_angle', True)

        # Loop-invariant post-processing options.
        score_thr = cfg.get('score_thr', -1)
        nms_pre = cfg.get('nms_pre', 100000)
        yolox_style = cfg.get('yolox_style', False)

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        # If the shape does not change, use the previous mlvl_priors
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes
        flatten_priors = torch.cat(self.mlvl_priors)

        # One stride value per prior, in the same order as flatten_priors.
        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size.numel() * self.num_base_priors, ), stride) for
            featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and angle_preds to
        # (num_imgs, num_priors_total, C)
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_angle_preds = [
            angle_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                   self.angle_out_dim)
            for angle_pred in angle_preds
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_angle_preds = torch.cat(flatten_angle_preds, dim=1)
        flatten_angle_preds = self.angle_coder.decode(
            flatten_angle_preds, keepdim=True)

        if decode_with_angle:
            # Decode distances and angle together into rotated boxes.
            flatten_rbbox_preds = torch.cat(
                [flatten_bbox_preds, flatten_angle_preds], dim=-1)
            flatten_decoded_bboxes = self.bbox_coder.decode(
                flatten_priors[None], flatten_rbbox_preds, flatten_stride)
        else:
            # Decode horizontal boxes first, convert them to cxcywh and
            # attach the decoded angle afterwards.
            flatten_decoded_hbboxes = self.bbox_coder.decode(
                flatten_priors[None], flatten_bbox_preds, flatten_stride)
            flatten_decoded_hbboxes = HorizontalBoxes.xyxy_to_cxcywh(
                flatten_decoded_hbboxes)
            flatten_decoded_bboxes = torch.cat(
                [flatten_decoded_hbboxes, flatten_angle_preds], dim=-1)

        if with_objectnesses:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        else:
            # Placeholder so the per-image zip below still lines up.
            flatten_objectness = [None for _ in range(num_imgs)]

        results_list = []
        for (bboxes, scores, objectness,
             img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
                              flatten_objectness, batch_img_metas):
            pad_param = img_meta.get('pad_param')

            # yolox_style does not require the following operations
            if objectness is not None and score_thr > 0 and not yolox_style:
                conf_inds = objectness > score_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]

            if objectness is not None:
                # conf = obj_conf * cls_conf
                scores *= objectness[:, None]

            if scores.shape[0] == 0:
                # Nothing survived the objectness filter: emit empty but
                # correctly typed results for this image.
                empty_results = InstanceData()
                empty_results.bboxes = RotatedBoxes(bboxes)
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                results_list.append(empty_results)
                continue

            if cfg.multi_label is False:
                # Single-label: keep only the best class of every prior.
                scores, labels = scores.max(1, keepdim=True)
                scores, _, keep_idxs, results = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(labels=labels[:, 0]))
                labels = results['labels']
            else:
                scores, labels, keep_idxs, _ = filter_scores_and_topk(
                    scores, score_thr, nms_pre)

            results = InstanceData(
                scores=scores,
                labels=labels,
                bboxes=RotatedBoxes(bboxes[keep_idxs]))

            if rescale:
                # Undo padding first, then undo the resize.
                if pad_param is not None:
                    results.bboxes.translate_([-pad_param[2], -pad_param[0]])

                scale_factor = [1 / s for s in img_meta['scale_factor']]
                results.bboxes = scale_boxes(results.bboxes, scale_factor)

            if yolox_style:
                # do not need max_per_img
                cfg.max_per_img = len(results)

            # Boxes were already rescaled above, hence rescale=False here.
            results = self._bbox_post_process(
                results=results,
                cfg=cfg,
                rescale=False,
                with_nms=with_nms,
                img_meta=img_meta)

            results_list.append(results)
        return results_list

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            angle_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box regression output for each scale
                level with shape (N, num_anchors * 4, H, W). The four
                values are used here as distances from the prior point and
                are multiplied by the prior's stride before decoding.
            angle_preds (list[Tensor]): Angle prediction for each scale
                level with shape (N, num_anchors * angle_out_dim, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Not used by this method. Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components. Always
            contains ``loss_cls`` and ``loss_bbox``; ``loss_angle`` is
            included only when an angle loss was configured.
        """
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        # Column 0 holds the label, the remaining columns the box.
        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xywha
        # Rows whose box components do not sum to a positive value are
        # flagged 0 (presumably the padding rows added by the preprocessing
        # helper -- TODO confirm against gt_instances_preprocess).
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        device = cls_scores[0].device

        # If the shape does not equal, generate new one
        if featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = featmap_sizes
            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                featmap_sizes, device=device, with_stride=True)
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)

        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1).contiguous()

        flatten_tblrs = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ], 1)
        # Scale the distances by the last prior column (the stride).
        flatten_tblrs = flatten_tblrs * self.flatten_priors_train[..., -1,
                                                                  None]
        flatten_angles = torch.cat([
            angle_pred.permute(0, 2, 3, 1).reshape(
                num_imgs, -1, self.angle_out_dim) for angle_pred in angle_preds
        ], 1)
        flatten_decoded_angle = self.angle_coder.decode(
            flatten_angles, keepdim=True)
        flatten_tblra = torch.cat([flatten_tblrs, flatten_decoded_angle],
                                  dim=-1)
        # Rotated boxes are always decoded: the assigner works on them even
        # when the regression loss itself is computed on horizontal boxes.
        flatten_rbboxes = distance2obb(
            self.flatten_priors_train[..., :2],
            flatten_tblra,
            angle_version=self.angle_version)
        if self.use_hbbox_loss:
            flatten_hbboxes = distance2bbox(self.flatten_priors_train[..., :2],
                                            flatten_tblrs)

        # Assignment must not back-propagate into the predictions.
        assigned_result = self.assigner(flatten_rbboxes.detach(),
                                        flatten_cls_scores.detach(),
                                        self.flatten_priors_train, gt_labels,
                                        gt_bboxes, pad_bbox_flag)

        labels = assigned_result['assigned_labels'].reshape(-1)
        label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
        bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 5)
        assign_metrics = assigned_result['assign_metrics'].reshape(-1)
        cls_preds = flatten_cls_scores.reshape(-1, self.num_classes)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()

        loss_cls = self.loss_cls(
            cls_preds, (labels, assign_metrics),
            label_weights,
            avg_factor=avg_factor)

        pos_bbox_targets = bbox_targets[pos_inds]

        if self.use_hbbox_loss:
            # Horizontal loss: compare xyxy predictions with the xyxy form
            # of the target's first four (cx, cy, w, h) components.
            flat_bbox_preds = flatten_hbboxes.reshape(-1, 4)
            pos_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets[:, :4])
        else:
            flat_bbox_preds = flatten_rbboxes.reshape(-1, 5)
        flat_angle_preds = flatten_angles.reshape(-1, self.angle_out_dim)

        # Zero-valued default that is still connected to the angle
        # predictions; it is replaced below when an angle loss applies.
        loss_angle = flat_angle_preds.sum() * 0
        if len(pos_inds) > 0:
            loss_bbox = self.loss_bbox(
                flat_bbox_preds[pos_inds],
                pos_bbox_targets,
                weight=assign_metrics[pos_inds],
                avg_factor=avg_factor)
            if self.loss_angle is not None:
                # Angle targets come from the untouched 5-dim targets, not
                # from pos_bbox_targets (which may have been converted).
                pos_angle_targets = bbox_targets[pos_inds][:, 4:5]
                pos_angle_targets = self.angle_coder.encode(pos_angle_targets)
                loss_angle = self.loss_angle(
                    flat_angle_preds[pos_inds],
                    pos_angle_targets,
                    weight=assign_metrics[pos_inds],
                    avg_factor=avg_factor)
        else:
            # No positive sample: emit a zero loss tied to the predictions.
            loss_bbox = flat_bbox_preds.sum() * 0

        losses = dict()
        losses['loss_cls'] = loss_cls
        losses['loss_bbox'] = loss_bbox
        if self.loss_angle is not None:
            losses['loss_angle'] = loss_angle

        return losses