# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Sequence, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_plugin_layer
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from mmyolo.registry import MODELS


@MODELS.register_module()
class BaseBackbone(BaseModule, metaclass=ABCMeta):
"""BaseBackbone backbone used in YOLO series.
.. code:: text
Backbone model structure diagram
+-----------+
| input |
+-----------+
v
+-----------+
| stem |
| layer |
+-----------+
v
+-----------+
| stage |
| layer 1 |
+-----------+
v
+-----------+
| stage |
| layer 2 |
+-----------+
v
......
v
+-----------+
| stage |
| layer n |
+-----------+
In P5 model, n=4
In P6 model, n=5
Args:
arch_setting (list): Architecture of BaseBackbone.
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
channels in each layer by this amount. Defaults to 1.0.
        input_channels (int): Number of input image channels. Defaults to 3.
out_indices (Sequence[int]): Output from which stages.
Defaults to (2, 3, 4).
frozen_stages (int): Stages to be frozen (stop grad and set eval
mode). -1 means not freezing any parameters. Defaults to -1.
        norm_cfg (dict): Dictionary to construct and configure the norm
            layer. Defaults to None.
act_cfg (dict): Config dict for activation layer.
Defaults to None.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
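
    Examples:
        A minimal sketch of a concrete subclass. ``ToyBackbone`` and its
        plain conv layers are illustrative only; real subclasses such as
        ``YOLOv5CSPDarknet`` build CSP stages from ``arch_setting``,
        ``deepen_factor`` and ``widen_factor``.

        >>> import torch
        >>> import torch.nn as nn
        >>> class ToyBackbone(BaseBackbone):
        ...     def build_stem_layer(self):
        ...         return nn.Conv2d(self.input_channels, 32, 3, 2, 1)
        ...     def build_stage_layer(self, stage_idx, setting):
        ...         in_c, out_c = setting
        ...         return [nn.Conv2d(in_c, out_c, 3, 2, 1), nn.ReLU()]
        >>> model = ToyBackbone([[32, 64], [64, 128]], out_indices=(1, 2))
        >>> outs = model(torch.rand(1, 3, 64, 64))
        >>> tuple(out.shape[1] for out in outs)
        (64, 128)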
"""
def __init__(self,
arch_setting: list,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
out_indices: Sequence[int] = (2, 3, 4),
frozen_stages: int = -1,
plugins: Union[dict, List[dict]] = None,
norm_cfg: ConfigType = None,
act_cfg: ConfigType = None,
norm_eval: bool = False,
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg)
self.num_stages = len(arch_setting)
self.arch_setting = arch_setting
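        # Valid output indices run from 0 (the stem) to ``num_stages``
        # (the last stage), hence ``len(arch_setting) + 1`` below.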
assert set(out_indices).issubset(
i for i in range(len(arch_setting) + 1))
if frozen_stages not in range(-1, len(arch_setting) + 1):
raise ValueError('"frozen_stages" must be in range(-1, '
'len(arch_setting) + 1). But received '
f'{frozen_stages}')
self.input_channels = input_channels
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.widen_factor = widen_factor
self.deepen_factor = deepen_factor
self.norm_eval = norm_eval
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.plugins = plugins
self.stem = self.build_stem_layer()
self.layers = ['stem']
for idx, setting in enumerate(arch_setting):
stage = []
stage += self.build_stage_layer(idx, setting)
if plugins is not None:
stage += self.make_stage_plugins(plugins, idx, setting)
self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
            self.layers.append(f'stage{idx + 1}')

    @abstractmethod
def build_stem_layer(self):
"""Build a stem layer."""
        pass

    @abstractmethod
def build_stage_layer(self, stage_idx: int, setting: list):
"""Build a stage layer.
Args:
stage_idx (int): The index of a stage layer.
setting (list): The architecture setting of a stage layer.
"""
        pass

    def make_stage_plugins(self, plugins, stage_idx, setting):
"""Make plugins for backbone ``stage_idx`` th stage.
Currently we support to insert ``context_block``,
``empirical_attention_block``, ``nonlocal_block``, ``dropout_block``
into the backbone.
An example of plugins format could be:
Examples:
>>> plugins=[
... dict(cfg=dict(type='xxx', arg1='xxx'),
... stages=(False, True, True, True)),
... dict(cfg=dict(type='yyy'),
... stages=(True, True, True, True)),
... ]
>>> model = YOLOv5CSPDarknet()
>>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
>>> assert len(stage_plugins) == 1
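
        A concrete cfg might look like the following (a sketch assuming
        mmcv's ``ContextBlock`` is available in its plugin registry; any
        registered plugin layer is configured the same way):

        .. code-block:: python

            plugins = [
                dict(cfg=dict(type='ContextBlock', ratio=1. / 16),
                     stages=(False, False, True, True))
            ]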

        Suppose ``stage_idx=0``, the structure of blocks in the stage would
        be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> yyy

        Suppose ``stage_idx=1``, the structure of blocks in the stage would
        be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> xxx -> yyy

        Args:
            plugins (list[dict]): List of plugin cfgs to build. The postfix
                is required if multiple plugins of the same type are
                inserted.
            stage_idx (int): Index of the stage to build. If ``stages`` is
                missing from a plugin, that plugin is applied to all stages.
            setting (list): The architecture setting of a stage layer.

        Returns:
            list[nn.Module]: Plugins for the current stage.
"""
# TODO: It is not general enough to support any channel and needs
# to be refactored
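        # Plugins are appended after a stage's own blocks, so they take the
        # stage's output channels: ``setting[1]`` scaled by ``widen_factor``.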
in_channels = int(setting[1] * self.widen_factor)
plugin_layers = []
for plugin in plugins:
plugin = plugin.copy()
stages = plugin.pop('stages', None)
assert stages is None or len(stages) == self.num_stages
if stages is None or stages[stage_idx]:
name, layer = build_plugin_layer(
plugin['cfg'], in_channels=in_channels)
plugin_layers.append(layer)
        return plugin_layers

    def _freeze_stages(self):
"""Freeze the parameters of the specified stage so that they are no
longer updated."""
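        # ``self.layers`` starts with 'stem', so ``frozen_stages = i``
        # freezes the stem plus the first ``i`` stage layers.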
if self.frozen_stages >= 0:
for i in range(self.frozen_stages + 1):
m = getattr(self, self.layers[i])
m.eval()
for param in m.parameters():
                param.requires_grad = False

    def train(self, mode: bool = True):
"""Convert the model into training mode while keep normalization layer
frozen."""
super().train(mode)
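        # ``nn.Module.train`` flips every submodule back to training mode,
        # so the frozen stages have to be re-frozen after each call.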
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
                    m.eval()

    def forward(self, x: torch.Tensor) -> tuple:
"""Forward batch_inputs from the data_preprocessor."""
outs = []
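        # ``self.layers`` is ordered as ['stem', 'stage1', ...], so an
        # ``out_indices`` entry of 0 would return the stem output.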
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)