# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, Scale
from mmdet.models.utils import multi_apply

from mmocr.models.textdet.heads.base import BaseTextDetHead
from mmocr.registry import MODELS

INF = 1e8


@MODELS.register_module()
class ABCNetDetHead(BaseTextDetHead):
    """Detection head with classification, box, centerness and Bezier outputs.

    Two parallel conv towers (classification and regression) are applied to
    every input feature level. On top of them sit a class-score conv, a
    4-channel box conv, a 1-channel centerness conv and, optionally, a
    16-channel Bezier conv.

    Args:
        in_channels (int): Channels of the feature maps fed to the first conv
            of each tower.
        module_loss (dict): Config forwarded to ``BaseTextDetHead``.
        postprocessor (dict): Config forwarded to ``BaseTextDetHead``.
        num_classes (int): Number of classes; determines
            ``cls_out_channels`` together with ``use_sigmoid_cls``.
        strides (tuple[int]): One entry per feature level; one ``Scale``
            module is created per entry and the values are handed to
            :meth:`forward_single`.
        feat_channels (int): Channels of the hidden tower layers.
        stacked_convs (int): Number of ``ConvModule`` layers per tower.
        dcn_on_last_conv (bool): If True, the last conv of each tower is
            built with ``dict(type='DCNv2')`` instead of ``conv_cfg``.
        conv_bias (bool | str): ``'auto'`` or a bool, passed as ``bias`` to
            every tower ``ConvModule``.
        norm_on_bbox (bool): If True, box predictions are clamped to be
            non-negative; otherwise they are exponentiated.
        centerness_on_reg (bool): If True, centerness is predicted from the
            regression tower features, otherwise from the classification
            tower features.
        use_sigmoid_cls (bool): If True, ``cls_out_channels`` equals
            ``num_classes``; otherwise ``num_classes + 1``.
        with_bezier (bool): If True, a Bezier branch is created and its
            prediction is returned as a fourth output.
        use_scale (bool): If True, the per-level ``Scale`` module is applied
            to the box prediction.
        conv_cfg (dict, optional): Conv config for the tower layers.
        norm_cfg (dict): Norm config for the tower layers.
        init_cfg (dict): Initialization config forwarded to
            ``BaseTextDetHead``.
    """

    def __init__(self,
                 in_channels,
                 module_loss=dict(type='ABCNetLoss'),
                 postprocessor=dict(type='ABCNetDetPostprocessor'),
                 num_classes=1,
                 strides=(4, 8, 16, 32, 64),
                 feat_channels=256,
                 stacked_convs=4,
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 norm_on_bbox=False,
                 centerness_on_reg=False,
                 use_sigmoid_cls=True,
                 with_bezier=False,
                 use_scale=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_cls',
                         std=0.01,
                         bias_prob=0.01))):
        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            init_cfg=init_cfg)
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)

        # Tower geometry.
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides

        # Layer construction options.
        self.dcn_on_last_conv = dcn_on_last_conv
        self.conv_bias = conv_bias
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # Prediction behaviour switches.
        self.norm_on_bbox = norm_on_bbox
        self.centerness_on_reg = centerness_on_reg
        self.with_bezier = with_bezier
        self.use_scale = use_scale

        # Classification output width.
        self.num_classes = num_classes
        self.use_sigmoid_cls = use_sigmoid_cls
        self.cls_out_channels = (
            num_classes if use_sigmoid_cls else num_classes + 1)

        self._init_layers()

    def _init_layers(self):
        """Create every sub-module of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()
        self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
        # The scales are built regardless of ``use_scale`` because
        # ``forward`` always passes ``self.scales`` to ``multi_apply``.
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def _build_conv_tower(self):
        """Return a ``ModuleList`` of ``stacked_convs`` 3x3 ``ConvModule``s."""
        last_idx = self.stacked_convs - 1
        tower = []
        for idx in range(self.stacked_convs):
            # Only the first layer sees the raw input channels.
            in_ch = self.feat_channels if idx else self.in_channels
            if self.dcn_on_last_conv and idx == last_idx:
                layer_conv_cfg = dict(type='DCNv2')
            else:
                layer_conv_cfg = self.conv_cfg
            tower.append(
                ConvModule(
                    in_ch,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=layer_conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
        return nn.ModuleList(tower)

    def _init_cls_convs(self):
        """Create the classification conv tower."""
        self.cls_convs = self._build_conv_tower()

    def _init_reg_convs(self):
        """Create the bbox regression conv tower."""
        self.reg_convs = self._build_conv_tower()

    def _init_predictor(self):
        """Create the class, box and (optional) Bezier prediction convs."""
        hidden = self.feat_channels
        self.conv_cls = nn.Conv2d(hidden, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(hidden, 4, 3, padding=1)
        if not self.with_bezier:
            return
        self.conv_bezier = nn.Conv2d(
            hidden, 16, kernel_size=3, stride=1, padding=1)

    def forward(self, feats, data_samples=None):
        """Run the head on the upstream features.

        The first entry of ``feats`` is skipped; the remaining levels are
        paired with ``self.scales`` and ``self.strides`` and passed to
        :meth:`forward_single`.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.
            data_samples (optional): Accepted but not read by this method.

        Returns:
            tuple: The result of ``multi_apply`` over the levels:

            - cls_scores (list[Tensor]): Box scores for each scale level,
              each is a 4D-tensor.
            - bbox_preds (list[Tensor]): Box energies / deltas for each
              scale level, each is a 4D-tensor with 4 channels.
            - centernesses (list[Tensor]): Centerness for each scale level,
              each is a 4D-tensor with 1 channel.
            - bezier_preds (list[Tensor]): Only when ``with_bezier`` is
              True; each is a 4D-tensor with 16 channels.
        """
        levels = feats[1:]
        return multi_apply(self.forward_single, levels, self.scales,
                           self.strides)

    def forward_single(self, x, scale, stride):
        """Run the head on one feature level.

        Args:
            x (Tensor): Feature map of a single level.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module, applied
                to the box prediction only when ``self.use_scale`` is True.
            stride (int): Stride associated with this level. It is accepted
                for interface compatibility and is not read by this method.

        Returns:
            tuple: Class scores, bbox predictions and centerness predictions
            of the input feature map. If ``with_bezier`` is True, the Bezier
            prediction is appended as a fourth element.
        """
        # Classification branch.
        cls_feat = x
        for conv in self.cls_convs:
            cls_feat = conv(cls_feat)
        cls_score = self.conv_cls(cls_feat)

        # Regression branch.
        reg_feat = x
        for conv in self.reg_convs:
            reg_feat = conv(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        if self.with_bezier:
            bezier_pred = self.conv_bezier(reg_feat)

        centerness_feat = reg_feat if self.centerness_on_reg else cls_feat
        centerness = self.conv_centerness(centerness_feat)

        # Scale the bbox_pred of different levels; cast to float to avoid
        # overflow when enabling FP16.
        if self.use_scale:
            bbox_pred = scale(bbox_pred)
        bbox_pred = bbox_pred.float()

        # ``clamp(min=0)`` is used instead of ``F.relu`` because bbox_pred
        # needed for gradient computation was modified by F.relu(bbox_pred)
        # when run with PyTorch 1.10.
        bbox_pred = (
            bbox_pred.clamp(min=0) if self.norm_on_bbox else bbox_pred.exp())

        outputs = (cls_score, bbox_pred, centerness)
        if self.with_bezier:
            outputs += (bezier_pred, )
        return outputs