# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.utils import multi_apply
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
OptMultiConfig, reduce_mean)
from mmengine import MessageHub
from mmengine.model import BaseModule, bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS
from ..layers.yolo_bricks import PPYOLOESELayer
from ..utils import gt_instances_preprocess
from .yolov6_head import YOLOv6Head


@MODELS.register_module()
class PPYOLOEHeadModule(BaseModule):
"""PPYOLOEHead head module used in `PPYOLOE.
<https://arxiv.org/abs/2203.16250>`_.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at each position
            on the feature grid. Defaults to 1.
featmap_strides (Sequence[int]): Downsample factor of each feature map.
Defaults to (8, 16, 32).
        reg_max (int): Max value of integral set :math:`{0, ..., reg_max}`
            in QFL setting. Defaults to 16.
norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.1, eps=1e-5).
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='SiLU', inplace=True).
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
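
    Example:
        A minimal forward sketch; the channel counts and input sizes below
        are illustrative, not canonical PPYOLOE settings::

            >>> import torch
            >>> head_module = PPYOLOEHeadModule(
            ...     num_classes=80, in_channels=[192, 384, 768])
            >>> feats = [torch.rand(1, c, 640 // s, 640 // s)
            ...          for c, s in zip([192, 384, 768], (8, 16, 32))]
            >>> outs = head_module(feats)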
"""

    def __init__(self,
num_classes: int,
in_channels: Union[int, Sequence],
widen_factor: float = 1.0,
num_base_priors: int = 1,
featmap_strides: Sequence[int] = (8, 16, 32),
reg_max: int = 16,
norm_cfg: ConfigType = dict(
type='BN', momentum=0.1, eps=1e-5),
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.featmap_strides = featmap_strides
self.num_levels = len(self.featmap_strides)
self.num_base_priors = num_base_priors
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.reg_max = reg_max
if isinstance(in_channels, int):
self.in_channels = [int(in_channels * widen_factor)
] * self.num_levels
else:
self.in_channels = [int(i * widen_factor) for i in in_channels]
self._init_layers()

    def init_weights(self, prior_prob=0.01):
"""Initialize the weight and bias of PPYOLOE head."""
super().init_weights()
for conv in self.cls_preds:
conv.bias.data.fill_(bias_init_with_prob(prior_prob))
conv.weight.data.fill_(0.)
for conv in self.reg_preds:
conv.bias.data.fill_(1.0)
conv.weight.data.fill_(0.)

    def _init_layers(self):
        """Initialize conv layers in PPYOLOE head."""
self.cls_preds = nn.ModuleList()
self.reg_preds = nn.ModuleList()
self.cls_stems = nn.ModuleList()
self.reg_stems = nn.ModuleList()
for in_channel in self.in_channels:
self.cls_stems.append(
PPYOLOESELayer(
in_channel, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
self.reg_stems.append(
PPYOLOESELayer(
in_channel, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
for in_channel in self.in_channels:
self.cls_preds.append(
nn.Conv2d(in_channel, self.num_classes, 3, padding=1))
self.reg_preds.append(
nn.Conv2d(in_channel, 4 * (self.reg_max + 1), 3, padding=1))
        # Projection used to convert the discrete distance distribution
        # into its expectation (see `forward_single`)
proj = torch.linspace(0, self.reg_max, self.reg_max + 1).view(
[1, self.reg_max + 1, 1, 1])
self.register_buffer('proj', proj, persistent=False)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
"""Forward features from the upstream network.

        Args:
x (Tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.

        Returns:
Tuple[List]: A tuple of multi-level classification scores, bbox
predictions.
"""
assert len(x) == self.num_levels
return multi_apply(self.forward_single, x, self.cls_stems,
self.cls_preds, self.reg_stems, self.reg_preds)

    def forward_single(self, x: Tensor, cls_stem: nn.Module,
                       cls_pred: nn.Module, reg_stem: nn.Module,
                       reg_pred: nn.Module) -> Tuple[Tensor, ...]:
"""Forward feature of a single scale level."""
b, _, h, w = x.shape
hw = h * w
avg_feat = F.adaptive_avg_pool2d(x, (1, 1))
cls_logit = cls_pred(cls_stem(x, avg_feat) + x)
bbox_dist_preds = reg_pred(reg_stem(x, avg_feat))
        # TODO: Test whether using matmul instead of conv can speed up
        # training.
bbox_dist_preds = bbox_dist_preds.reshape(
[-1, 4, self.reg_max + 1, hw]).permute(0, 2, 3, 1)
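        # `proj` holds the bin values {0, ..., reg_max}, so the 1x1 conv
        # below takes the expectation of the softmaxed distance distribution.
        # An equivalent (illustrative, not original) formulation:
        #   (F.softmax(bbox_dist_preds, dim=1) * self.proj).sum(1, keepdim=True)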
bbox_preds = F.conv2d(F.softmax(bbox_dist_preds, dim=1), self.proj)
if self.training:
return cls_logit, bbox_preds, bbox_dist_preds
else:
return cls_logit, bbox_preds


@MODELS.register_module()
class PPYOLOEHead(YOLOv6Head):
"""PPYOLOEHead head used in `PPYOLOE <https://arxiv.org/abs/2203.16250>`_.

    The YOLOv6 head and the PPYOLOE head differ only slightly: PPYOLOE
    additionally uses a distribution focal loss, which YOLOv6 does not.

Args:
        head_module (ConfigType): Base module used for PPYOLOEHead.
        prior_generator (dict): Points generator for feature maps in
            2D points-based detectors.
bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
loss_dfl (:obj:`ConfigDict` or dict): Config of distribution focal
loss.
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
anchor head. Defaults to None.
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
anchor head. Defaults to None.
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
list[dict], optional): Initialization config dict.
Defaults to None.
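
    Example:
        A minimal construction sketch; the ``head_module`` settings below
        are illustrative, not canonical PPYOLOE values::

            >>> head = PPYOLOEHead(
            ...     head_module=dict(
            ...         type='PPYOLOEHeadModule',
            ...         num_classes=80,
            ...         in_channels=[192, 384, 768]))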
"""

    def __init__(self,
head_module: ConfigType,
prior_generator: ConfigType = dict(
type='mmdet.MlvlPointGenerator',
offset=0.5,
strides=[8, 16, 32]),
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
loss_cls: ConfigType = dict(
type='mmdet.VarifocalLoss',
use_sigmoid=True,
alpha=0.75,
gamma=2.0,
iou_weighted=True,
reduction='sum',
loss_weight=1.0),
loss_bbox: ConfigType = dict(
type='IoULoss',
iou_mode='giou',
bbox_format='xyxy',
reduction='mean',
loss_weight=2.5,
return_iou=False),
loss_dfl: ConfigType = dict(
type='mmdet.DistributionFocalLoss',
reduction='mean',
loss_weight=0.5 / 4),
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None):
super().__init__(
head_module=head_module,
prior_generator=prior_generator,
bbox_coder=bbox_coder,
loss_cls=loss_cls,
loss_bbox=loss_bbox,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
self.loss_dfl = MODELS.build(loss_dfl)
        # PPYOLOE does not use an objectness loss
self.loss_obj = None

    def loss_by_feat(
self,
cls_scores: Sequence[Tensor],
bbox_preds: Sequence[Tensor],
bbox_dist_preds: Sequence[Tensor],
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
"""Calculate the loss based on the features extracted by the detection
head.

        Args:
cls_scores (Sequence[Tensor]): Box scores for each scale level,
each is a 4D-tensor, the channel number is
num_priors * num_classes.
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
level, each is a 4D-tensor, the channel number is
num_priors * 4.
bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
each scale level with shape (bs, reg_max + 1, H*W, 4).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.

        Returns:
dict[str, Tensor]: A dictionary of losses.
"""
# get epoch information from message hub
message_hub = MessageHub.get_current_instance()
current_epoch = message_hub.get_info('epoch')
num_imgs = len(batch_img_metas)
current_featmap_sizes = [
cls_score.shape[2:] for cls_score in cls_scores
]
        # If the feature map sizes have changed, regenerate the priors
if current_featmap_sizes != self.featmap_sizes_train:
self.featmap_sizes_train = current_featmap_sizes
mlvl_priors_with_stride = self.prior_generator.grid_priors(
self.featmap_sizes_train,
dtype=cls_scores[0].dtype,
device=cls_scores[0].device,
with_stride=True)
self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
self.flatten_priors_train = torch.cat(
mlvl_priors_with_stride, dim=0)
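            # stride of each prior, shape (num_priors, 1), for converting
            # between image scale and feature-grid scale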
self.stride_tensor = self.flatten_priors_train[..., [2]]
# gt info
gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
gt_labels = gt_info[:, :, :1]
gt_bboxes = gt_info[:, :, 1:] # xyxy
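        # flag real gt boxes; rows zero-padded during batching get 0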
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
# pred info
flatten_cls_preds = [
cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_pred in cls_scores
]
flatten_pred_bboxes = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
        # (bs, reg_max+1, hw, 4) -> (bs, hw, 4, reg_max+1)
        #                        -> (bs, hw, 4*(reg_max+1))
flatten_pred_dists = [
bbox_pred_org.permute(0, 2, 3, 1).reshape(
num_imgs, -1, (self.head_module.reg_max + 1) * 4)
for bbox_pred_org in bbox_dist_preds
]
flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
flatten_pred_bboxes = self.bbox_coder.decode(
self.flatten_priors_train[..., :2], flatten_pred_bboxes,
self.stride_tensor[..., 0])
pred_scores = torch.sigmoid(flatten_cls_preds)
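        # Before `initial_epoch` (configured through train_cfg in
        # YOLOv6Head), use the initial assigner; afterwards, switch to the
        # task-aligned assigner, which also conditions on predicted scores.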
if current_epoch < self.initial_epoch:
assigned_result = self.initial_assigner(
flatten_pred_bboxes.detach(), self.flatten_priors_train,
self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
else:
assigned_result = self.assigner(flatten_pred_bboxes.detach(),
pred_scores.detach(),
self.flatten_priors_train,
gt_labels, gt_bboxes,
pad_bbox_flag)
assigned_bboxes = assigned_result['assigned_bboxes']
assigned_scores = assigned_result['assigned_scores']
fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
# cls loss
with torch.cuda.amp.autocast(enabled=False):
loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)
        # rescale bboxes from image scale to feature-grid scale
assigned_bboxes /= self.stride_tensor
flatten_pred_bboxes /= self.stride_tensor
assigned_scores_sum = assigned_scores.sum()
        # average across all GPUs to keep the loss normalizer consistent
assigned_scores_sum = torch.clamp(
reduce_mean(assigned_scores_sum), min=1)
loss_cls /= assigned_scores_sum
# select positive samples mask
num_pos = fg_mask_pre_prior.sum()
if num_pos > 0:
            # When num_pos > 0, assigned_scores_sum is guaranteed to be
            # positive, so loss_bbox will not raise an error.
# iou loss
prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
pred_bboxes_pos = torch.masked_select(
flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = torch.masked_select(
assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
bbox_weight = torch.masked_select(
assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
loss_bbox = self.loss_bbox(
pred_bboxes_pos,
assigned_bboxes_pos,
weight=bbox_weight,
avg_factor=assigned_scores_sum)
# dfl loss
dist_mask = fg_mask_pre_prior.unsqueeze(-1).repeat(
[1, 1, (self.head_module.reg_max + 1) * 4])
pred_dist_pos = torch.masked_select(
flatten_dist_preds,
dist_mask).reshape([-1, 4, self.head_module.reg_max + 1])
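            # Encode assigned boxes back into (l, t, r, b) distances in grid
            # units to serve as DFL targets; `max_dis` keeps them within the
            # valid bin range.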
assigned_ltrb = self.bbox_coder.encode(
self.flatten_priors_train[..., :2] / self.stride_tensor,
assigned_bboxes,
max_dis=self.head_module.reg_max,
eps=0.01)
assigned_ltrb_pos = torch.masked_select(
assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
loss_dfl = self.loss_dfl(
pred_dist_pos.reshape(-1, self.head_module.reg_max + 1),
assigned_ltrb_pos.reshape(-1),
weight=bbox_weight.expand(-1, 4).reshape(-1),
avg_factor=assigned_scores_sum)
else:
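            # No positive samples: return zero losses that still keep the
            # computation graph connected.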
loss_bbox = flatten_pred_bboxes.sum() * 0
loss_dfl = flatten_pred_bboxes.sum() * 0
return dict(loss_cls=loss_cls, loss_bbox=loss_bbox, loss_dfl=loss_dfl)