Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Sequence, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmdet.models.utils import multi_apply | |
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList, | |
OptMultiConfig, reduce_mean) | |
from mmengine import MessageHub | |
from mmengine.model import BaseModule, bias_init_with_prob | |
from mmengine.structures import InstanceData | |
from torch import Tensor | |
from mmyolo.registry import MODELS | |
from ..layers.yolo_bricks import PPYOLOESELayer | |
from ..utils import gt_instances_preprocess | |
from .yolov6_head import YOLOv6Head | |
class PPYOLOEHeadModule(BaseModule):
    """PPYOLOEHead head module used in
    `PPYOLOE <https://arxiv.org/abs/2203.16250>`_.

    For every feature level the module owns a classification stem/predictor
    pair and a regression stem/predictor pair. The regression predictor
    emits ``reg_max + 1`` logits per box side, which are turned into a
    single value per side by a softmax-weighted sum over the bin indices.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int or Sequence[int]): Number of channels in the input
            feature map. A single int is reused for every level; a sequence
            gives one value per level. Each value is scaled by
            ``widen_factor``.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid. Defaults to 1.
        featmap_strides (Sequence[int]): Downsample factor of each feature
            map. Defaults to (8, 16, 32).
        reg_max (int): Max value of integral set :math: ``{0, ..., reg_max}``
            in QFL setting. Defaults to 16.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to dict(type='BN', momentum=0.1, eps=1e-5).
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='SiLU', inplace=True).
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: Union[int, Sequence],
                 widen_factor: float = 1.0,
                 num_base_priors: int = 1,
                 featmap_strides: Sequence[int] = (8, 16, 32),
                 reg_max: int = 16,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.1, eps=1e-5),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)

        self.num_classes = num_classes
        self.featmap_strides = featmap_strides
        self.num_levels = len(self.featmap_strides)
        self.num_base_priors = num_base_priors
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.reg_max = reg_max

        # Normalise ``in_channels`` to one (width-scaled) entry per level.
        if isinstance(in_channels, int):
            self.in_channels = [int(in_channels * widen_factor)
                                ] * self.num_levels
        else:
            self.in_channels = [int(i * widen_factor) for i in in_channels]

        self._init_layers()

    def init_weights(self, prior_prob=0.01):
        """Initialize the weight and bias of PPYOLOE head.

        Args:
            prior_prob (float): Prior probability passed to
                ``bias_init_with_prob`` to set the classification bias.
                Defaults to 0.01.
        """
        super().init_weights()
        # Both predictors start with zero weights, so their initial output
        # is the bias alone.
        for conv in self.cls_preds:
            conv.bias.data.fill_(bias_init_with_prob(prior_prob))
            conv.weight.data.fill_(0.)

        for conv in self.reg_preds:
            conv.bias.data.fill_(1.0)
            conv.weight.data.fill_(0.)

    def _init_layers(self):
        """initialize conv layers in PPYOLOE head."""
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.cls_stems = nn.ModuleList()
        self.reg_stems = nn.ModuleList()

        # The stems and the predictors are built in two separate passes;
        # the construction order is kept as-is so that the sequence in which
        # layers are created (and randomly initialised) does not change.
        for in_channel in self.in_channels:
            self.cls_stems.append(
                PPYOLOESELayer(
                    in_channel, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
            self.reg_stems.append(
                PPYOLOESELayer(
                    in_channel, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))

        for in_channel in self.in_channels:
            self.cls_preds.append(
                nn.Conv2d(in_channel, self.num_classes, 3, padding=1))
            self.reg_preds.append(
                nn.Conv2d(in_channel, 4 * (self.reg_max + 1), 3, padding=1))

        # init proj
        # ``proj`` holds the bin indices 0..reg_max shaped as a 1x1 conv
        # kernel with one output channel; it is a non-persistent buffer, so
        # it is not written to the state dict.
        proj = torch.linspace(0, self.reg_max, self.reg_max + 1).view(
            [1, self.reg_max + 1, 1, 1])
        self.register_buffer('proj', proj, persistent=False)

    def forward(self, x: Tuple[Tensor]) -> Tuple[list, ...]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor. Must contain exactly ``num_levels`` entries.

        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions and, in training mode only, bbox distribution
            logits (see :meth:`forward_single`).
        """
        assert len(x) == self.num_levels

        return multi_apply(self.forward_single, x, self.cls_stems,
                           self.cls_preds, self.reg_stems, self.reg_preds)

    def forward_single(self, x: Tensor, cls_stem: nn.Module,
                       cls_pred: nn.Module, reg_stem: nn.Module,
                       reg_pred: nn.Module) -> Tuple[Tensor, ...]:
        """Forward feature of a single scale level.

        Args:
            x (Tensor): 4D feature map of one level.
            cls_stem (nn.Module): Classification stem of this level, called
                as ``cls_stem(x, avg_feat)``.
            cls_pred (nn.Module): Classification predictor of this level.
            reg_stem (nn.Module): Regression stem of this level, called as
                ``reg_stem(x, avg_feat)``.
            reg_pred (nn.Module): Regression predictor of this level.

        Returns:
            Tuple[Tensor, ...]: ``(cls_logit, bbox_preds)`` in eval mode and
            ``(cls_logit, bbox_preds, bbox_dist_preds)`` in training mode.
            ``bbox_dist_preds`` has shape (bs, reg_max + 1, H*W, 4) and
            ``bbox_preds`` has shape (bs, 1, H*W, 4).
        """
        _, _, h, w = x.shape
        hw = h * w
        # Globally pooled descriptor shared by both stems.
        avg_feat = F.adaptive_avg_pool2d(x, (1, 1))
        # Only the classification branch adds the input back as a residual.
        cls_logit = cls_pred(cls_stem(x, avg_feat) + x)
        bbox_dist_preds = reg_pred(reg_stem(x, avg_feat))

        # TODO: Test whether use matmul instead of conv can speed up training.
        # (bs, 4 * (reg_max + 1), H, W) -> (bs, reg_max + 1, H*W, 4)
        bbox_dist_preds = bbox_dist_preds.reshape(
            [-1, 4, self.reg_max + 1, hw]).permute(0, 2, 3, 1)

        # Softmax over the bin axis, then a 1x1 conv with the bin indices as
        # weights: the expected bin index for every box side.
        bbox_preds = F.conv2d(F.softmax(bbox_dist_preds, dim=1), self.proj)

        if self.training:
            return cls_logit, bbox_preds, bbox_dist_preds
        else:
            return cls_logit, bbox_preds
class PPYOLOEHead(YOLOv6Head):
    """PPYOLOEHead head used in `PPYOLOE <https://arxiv.org/abs/2203.16250>`_.

    The YOLOv6 head and the PPYOLOE head are only slightly different.
    Distribution focal loss is extra used in PPYOLOE, but not in YOLOv6.

    Args:
        head_module(ConfigType): Base module used for PPYOLOEHead
        prior_generator(dict): Points generator feature maps in
            2D points-based detectors.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_dfl (:obj:`ConfigDict` or dict): Config of distribution focal
            loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.MlvlPointGenerator',
                     offset=0.5,
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.VarifocalLoss',
                     use_sigmoid=True,
                     alpha=0.75,
                     gamma=2.0,
                     iou_weighted=True,
                     reduction='sum',
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='IoULoss',
                     iou_mode='giou',
                     bbox_format='xyxy',
                     reduction='mean',
                     loss_weight=2.5,
                     return_iou=False),
                 loss_dfl: ConfigType = dict(
                     type='mmdet.DistributionFocalLoss',
                     reduction='mean',
                     loss_weight=0.5 / 4),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        # The parent constructor does not receive ``loss_dfl``; it is built
        # here.
        self.loss_dfl = MODELS.build(loss_dfl)
        # ppyoloe doesn't need loss_obj
        self.loss_obj = None

    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            bbox_dist_preds: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Classification logits for each
                scale level, each is a 4D-tensor whose last two dimensions
                are the feature map size and whose channel number is
                ``num_classes``.
            bbox_preds (Sequence[Tensor]): Box predictions for each scale
                level, each is a 4D-tensor that is flattened here to
                (bs, n, 4). ``PPYOLOEHeadModule.forward_single`` emits them
                with shape (bs, 1, H*W, 4).
            bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
                each scale level with shape (bs, reg_max + 1, H*W, 4).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc. Only its length (the batch
                size) is used here.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None. Not used by this method.

        Returns:
            dict[str, Tensor]: A dictionary of losses with the keys
            ``loss_cls``, ``loss_bbox`` and ``loss_dfl``.
        """
        # get epoch information from message hub
        message_hub = MessageHub.get_current_instance()
        current_epoch = message_hub.get_info('epoch')

        num_imgs = len(batch_img_metas)
        # Number of distribution bins per box side.
        num_bins = self.head_module.reg_max + 1

        current_featmap_sizes = [
            cls_score.shape[2:] for cls_score in cls_scores
        ]
        # If the shape does not equal, generate new one
        if current_featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = current_featmap_sizes

            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                self.featmap_sizes_train,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device,
                with_stride=True)

            self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)
            # Column 2 of the priors is kept (with a trailing dim of 1) as
            # the per-prior stride.
            self.stride_tensor = self.flatten_priors_train[..., [2]]

        # gt info
        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        # A gt row counts as real only when its box coordinates sum to > 0.
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        # pred info
        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_classes)
            for cls_pred in cls_scores
        ]
        flatten_pred_bboxes = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        # (bs, reg_max+1, n, 4) -> (bs, n, 4, reg_max+1)
        flatten_pred_dists = [
            bbox_pred_org.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                      num_bins * 4)
            for bbox_pred_org in bbox_dist_preds
        ]

        flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
        flatten_pred_bboxes = self.bbox_coder.decode(
            self.flatten_priors_train[..., :2], flatten_pred_bboxes,
            self.stride_tensor[..., 0])
        pred_scores = torch.sigmoid(flatten_cls_preds)

        # Before ``initial_epoch`` the initial assigner is used (it takes no
        # predicted scores); afterwards the main assigner takes over.
        if current_epoch < self.initial_epoch:
            assigned_result = self.initial_assigner(
                flatten_pred_bboxes.detach(), self.flatten_priors_train,
                self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
        else:
            assigned_result = self.assigner(flatten_pred_bboxes.detach(),
                                            pred_scores.detach(),
                                            self.flatten_priors_train,
                                            gt_labels, gt_bboxes,
                                            pad_bbox_flag)

        assigned_bboxes = assigned_result['assigned_bboxes']
        assigned_scores = assigned_result['assigned_scores']
        fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']

        # cls loss
        # Autocast is explicitly disabled around the classification loss.
        with torch.cuda.amp.autocast(enabled=False):
            loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)

        # rescale bbox
        # Both divisions are in-place: the assigner's ``assigned_bboxes``
        # and the decoded predictions are now expressed in stride units.
        assigned_bboxes /= self.stride_tensor
        flatten_pred_bboxes /= self.stride_tensor

        assigned_scores_sum = assigned_scores.sum()
        # reduce_mean between all gpus
        assigned_scores_sum = torch.clamp(
            reduce_mean(assigned_scores_sum), min=1)
        loss_cls /= assigned_scores_sum

        # select positive samples mask
        num_pos = fg_mask_pre_prior.sum()
        if num_pos > 0:
            # when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
            # will not report an error
            # iou loss
            prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
            pred_bboxes_pos = torch.masked_select(
                flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
            assigned_bboxes_pos = torch.masked_select(
                assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
            # Each positive prior is weighted by the sum of its assigned
            # scores over the class axis.
            bbox_weight = torch.masked_select(
                assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
            loss_bbox = self.loss_bbox(
                pred_bboxes_pos,
                assigned_bboxes_pos,
                weight=bbox_weight,
                avg_factor=assigned_scores_sum)

            # dfl loss
            dist_mask = fg_mask_pre_prior.unsqueeze(-1).repeat(
                [1, 1, num_bins * 4])

            pred_dist_pos = torch.masked_select(
                flatten_dist_preds, dist_mask).reshape([-1, 4, num_bins])
            # Targets are encoded against the prior centres, which are
            # divided by the stride to match the already rescaled boxes.
            assigned_ltrb = self.bbox_coder.encode(
                self.flatten_priors_train[..., :2] / self.stride_tensor,
                assigned_bboxes,
                max_dis=self.head_module.reg_max,
                eps=0.01)
            assigned_ltrb_pos = torch.masked_select(
                assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
            # One DFL sample per box side: the per-box weight is repeated
            # for the four sides.
            loss_dfl = self.loss_dfl(
                pred_dist_pos.reshape(-1, num_bins),
                assigned_ltrb_pos.reshape(-1),
                weight=bbox_weight.expand(-1, 4).reshape(-1),
                avg_factor=assigned_scores_sum)
        else:
            # No positives: return zero losses that are still derived from
            # the predictions.
            loss_bbox = flatten_pred_bboxes.sum() * 0
            loss_dfl = flatten_pred_bboxes.sum() * 0

        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox, loss_dfl=loss_dfl)