# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmcv.cnn import ConvModule, build_plugin_layer
from mmengine.model import BaseModule, Sequential

import mmocr.utils as utils
from mmocr.models.textrecog.layers import BasicBlock
from mmocr.registry import MODELS


@MODELS.register_module()
class ResNet(BaseModule):
    """ResNet backbone for text recognition.

    Args:
        in_channels (int): Number of channels of the input image tensor.
        stem_channels (int or list[int]): List of channels in each stem layer.
            E.g., [64, 128] stands for 64 and 128 channels in the first and
            second stem layers. If int, only one stem layer will be created.
        block_cfgs (dict): Configs of the block.
        arch_layers (list[int]): List of block numbers for each stage.
        arch_channels (list[int]): List of channels for each stage.
        strides (Sequence[int] or Sequence[tuple]): Strides of the first block
            of each stage.
        out_indices (Sequence[int], optional): Indices of output stages. If
            not specified, only the last stage will be returned.
        plugins (list[dict], optional): Configs of stage plugins.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int,
                 stem_channels: Union[int, List[int]],
                 block_cfgs: dict,
                 arch_layers: List[int],
                 arch_channels: List[int],
                 strides: Union[List[int], List[Tuple]],
                 out_indices: Optional[List[int]] = None,
                 plugins: Optional[List[Dict]] = None,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = [
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, int)
        assert isinstance(stem_channels, int) or utils.is_type_list(
            stem_channels, int)
        assert utils.is_type_list(arch_layers, int)
        assert utils.is_type_list(arch_channels, int)
        assert utils.is_type_list(strides, tuple) or utils.is_type_list(
            strides, int)
        assert len(arch_layers) == len(arch_channels) == len(strides)
        assert out_indices is None or isinstance(out_indices, (list, tuple))

        self.out_indices = out_indices
        self._make_stem_layer(in_channels, stem_channels)
        self.num_stages = len(arch_layers)
        self.use_plugins = False
        self.arch_channels = arch_channels
        self.res_layers = []
        if plugins is not None:
            self.plugin_ahead_names = []
            self.plugin_after_names = []
            self.use_plugins = True
        for i, num_blocks in enumerate(arch_layers):
            stride = strides[i]
            channel = arch_channels[i]

            if self.use_plugins:
                self._make_stage_plugins(plugins, stage_idx=i)

            res_layer = self._make_layer(
                block_cfgs=block_cfgs,
                inplanes=self.inplanes,
                planes=channel,
                blocks=num_blocks,
                stride=stride,
            )
            self.inplanes = channel
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

    def _make_layer(self, block_cfgs: Dict, inplanes: int, planes: int,
                    blocks: int,
                    stride: Union[int, Tuple[int, int]]) -> Sequential:
        """Build a ResNet layer.

        Args:
            block_cfgs (dict): Configs of blocks.
            inplanes (int): Number of input channels.
            planes (int): Number of output channels.
            blocks (int): Number of blocks.
            stride (int or tuple[int]): Stride of the first block.

        Returns:
            Sequential: A sequence of blocks.
""" layers = [] downsample = None block_cfgs_ = block_cfgs.copy() if isinstance(stride, int): stride = (stride, stride) if stride[0] != 1 or stride[1] != 1 or inplanes != planes: downsample = ConvModule( inplanes, planes, 1, stride, norm_cfg=dict(type='BN'), act_cfg=None) if block_cfgs_['type'] == 'BasicBlock': block = BasicBlock block_cfgs_.pop('type') else: raise ValueError('{} not implement yet'.format(block['type'])) layers.append( block( inplanes, planes, stride=stride, downsample=downsample, **block_cfgs_)) inplanes = planes for _ in range(1, blocks): layers.append(block(inplanes, planes, **block_cfgs_)) return Sequential(*layers) def _make_stem_layer(self, in_channels: int, stem_channels: Union[int, List[int]]) -> None: """Make stem layers. Args: in_channels (int): Number of input channels. stem_channels (list[int] or int): List of channels in each stem layer. If int, only one stem layer will be created. """ if isinstance(stem_channels, int): stem_channels = [stem_channels] stem_layers = [] for _, channels in enumerate(stem_channels): stem_layer = ConvModule( in_channels, channels, kernel_size=3, stride=1, padding=1, bias=False, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU')) in_channels = channels stem_layers.append(stem_layer) self.stem_layers = Sequential(*stem_layers) self.inplanes = stem_channels[-1] def _make_stage_plugins(self, plugins: List[Dict], stage_idx: int) -> None: """Make plugins for ResNet ``stage_idx``th stage. Currently we support inserting ``nn.Maxpooling``, ``mmcv.cnn.Convmodule``into the backbone. Originally designed for ResNet31-like architectures. Examples: >>> plugins=[ ... dict(cfg=dict(type="Maxpooling", arg=(2,2)), ... stages=(True, True, False, False), ... position='before_stage'), ... dict(cfg=dict(type="Maxpooling", arg=(2,1)), ... stages=(False, False, True, Flase), ... position='before_stage'), ... dict(cfg=dict( ... type='ConvModule', ... kernel_size=3, ... stride=1, ... padding=1, ... norm_cfg=dict(type='BN'), ... act_cfg=dict(type='ReLU')), ... stages=(True, True, True, True), ... position='after_stage')] Suppose ``stage_idx=1``, the structure of stage would be: .. code-block:: none Maxpooling -> A set of Basicblocks -> ConvModule Args: plugins (list[dict]): List of plugin configs to build. stage_idx (int): Index of stage to build """ in_channels = self.arch_channels[stage_idx] self.plugin_ahead_names.append([]) self.plugin_after_names.append([]) for plugin in plugins: plugin = plugin.copy() stages = plugin.pop('stages', None) position = plugin.pop('position', None) assert stages is None or len(stages) == self.num_stages if stages[stage_idx]: if position == 'before_stage': name, layer = build_plugin_layer( plugin['cfg'], f'_before_stage_{stage_idx+1}', in_channels=in_channels, out_channels=in_channels) self.plugin_ahead_names[stage_idx].append(name) self.add_module(name, layer) elif position == 'after_stage': name, layer = build_plugin_layer( plugin['cfg'], f'_after_stage_{stage_idx+1}', in_channels=in_channels, out_channels=in_channels) self.plugin_after_names[stage_idx].append(name) self.add_module(name, layer) else: raise ValueError('uncorrect plugin position') def forward_plugin(self, x: torch.Tensor, plugin_name: List[str]) -> torch.Tensor: """Forward tensor through plugin. Args: x (torch.Tensor): Input tensor. plugin_name (list[str]): Name of plugins. Returns: torch.Tensor: Output tensor. 
""" out = x for name in plugin_name: out = getattr(self, name)(out) return out def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]: """ Args: x (Tensor): Image tensor of shape :math:`(N, 3, H, W)`. Returns: Tensor or list[Tensor]: Feature tensor. It can be a list of feature outputs at specific layers if ``out_indices`` is specified. """ x = self.stem_layers(x) outs = [] for i, layer_name in enumerate(self.res_layers): res_layer = getattr(self, layer_name) if not self.use_plugins: x = res_layer(x) if self.out_indices and i in self.out_indices: outs.append(x) else: x = self.forward_plugin(x, self.plugin_ahead_names[i]) x = res_layer(x) x = self.forward_plugin(x, self.plugin_after_names[i]) if self.out_indices and i in self.out_indices: outs.append(x) return tuple(outs) if self.out_indices else x