Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch.nn as nn | |
from mmcv.cnn import build_conv_layer, build_norm_layer | |
from mmengine.model import BaseModule, ModuleList, Sequential | |
from torch.nn.modules.batchnorm import _BatchNorm | |
from mmpretrain.registry import MODELS | |
from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion | |
class HRModule(BaseModule):
    """High-Resolution Module for HRNet.

    Each branch is a stack of ``num_blocks[i]`` BasicBlocks/Bottlenecks.
    After the branches run, a fusion (exchange) step mixes information
    across branches. Output branch ``i`` is the ReLU of the sum of every
    branch feature, each resampled to branch ``i``'s resolution and channel
    count. Branch ``j`` is expected to have half the spatial size of branch
    ``j - 1``:

    - Lower-resolution features are upsampled by ``2**(j - i)``.
    - Higher-resolution features are downsampled by stacked stride-2 convs.

    Args:
        num_branches (int): The number of branches.
        block (``BaseModule``): Convolution block class used to build every
            branch (e.g. ``BasicBlock`` or ``Bottleneck``).
        num_blocks (tuple): The number of blocks in each branch.
            The length must be equal to ``num_branches``.
        in_channels (tuple): The number of input channels of each branch.
            The length must be equal to ``num_branches``. The fusion layers
            also use these values as the channel counts of the branch
            outputs. So ``in_channels[i]`` should equal ``num_channels[i]``
            multiplied by the block expansion when ``num_branches > 1``.
        num_channels (tuple): The number of base channels in each branch.
            Branch ``i`` outputs ``num_channels[i] * expansion`` channels.
            The length must be equal to ``num_branches``.
        multiscale_output (bool): Whether to output multi-level features
            produced by multiple branches. If False, only the first level
            feature will be output. Defaults to True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        conv_cfg (dict, optional): Dictionary to construct and config conv
            layer. Defaults to None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to ``dict(type='BN')``.
        block_init_cfg (dict, optional): The initialization configs of every
            blocks. Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.

    Raises:
        ValueError: If ``num_blocks``, ``num_channels`` or ``in_channels``
            does not have ``num_branches`` entries.
    """

    def __init__(self,
                 num_branches,
                 block,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=True,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 block_init_cfg=None,
                 init_cfg=None):
        super(HRModule, self).__init__(init_cfg)
        self.block_init_cfg = block_init_cfg
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(num_branches, block, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, num_blocks, in_channels,
                        num_channels):
        """Check that every per-branch setting has one entry per branch.

        Raises:
            ValueError: On the first sequence whose length differs from
                ``num_branches``.
        """
        if num_branches != len(num_blocks):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_BLOCKS({len(num_blocks)})'
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_CHANNELS({len(num_channels)})'
            raise ValueError(error_msg)

        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                        f'!= NUM_INCHANNELS({len(in_channels)})'
            raise ValueError(error_msg)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build one ``ResLayer`` per branch.

        Branch ``i`` maps ``self.in_channels[i]`` channels to
        ``num_channels[i] * get_expansion(block)`` channels, using
        ``num_blocks[i]`` blocks.
        """
        branches = []
        for i in range(num_branches):
            out_channels = num_channels[i] * get_expansion(block)
            branches.append(
                ResLayer(
                    block=block,
                    num_blocks=num_blocks[i],
                    in_channels=self.in_channels[i],
                    out_channels=out_channels,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp,
                    init_cfg=self.block_init_cfg,
                ))

        return ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the cross-branch fusion layers.

        Returns:
            nn.ModuleList | None: ``None`` when there is only one branch.
            Otherwise ``fuse_layers[i][j]`` maps branch ``j`` to the
            resolution and channel count of output branch ``i``:

            - ``j > i``: 1x1 conv + norm + nearest upsample by ``2**(j - i)``.
            - ``j == i``: ``None`` (the feature is added unchanged).
            - ``j < i``: ``i - j`` stacked 3x3 stride-2 conv + norm blocks.
              All but the last keep ``in_channels[j]`` channels and end with
              a ReLU. The last maps to ``in_channels[i]`` with no ReLU.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        # With multi-scale output disabled, only the first
        # (highest-resolution) output branch is fused.
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Upsample the feature maps of smaller scales.
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    # Keep the feature map with the same scale.
                    fuse_layer.append(None)
                else:
                    # Downsample the feature maps of larger scales.
                    conv_downsamples = []
                    for k in range(i - j):
                        # Use stacked convolution layers to downsample.
                        if k == i - j - 1:
                            # Last step: change channels to the target
                            # branch. No ReLU, because it is applied after
                            # the sum in ``forward``.
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (list): One input feature map per branch, ordered from the
                highest resolution (index 0) to the lowest. The list itself
                is not modified.

        Returns:
            list: The fused feature maps. There are ``num_branches`` of them
            when ``multiscale_output`` is True, otherwise only one.
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        # Collect branch outputs in a new list rather than writing back into
        # ``x``, so the caller's list is left untouched.
        feats = [branch(feat) for branch, feat in zip(self.branches, x)]

        x_fuse = []
        for i, fuse_layer in enumerate(self.fuse_layers):
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += feats[j]
                else:
                    y += fuse_layer[j](feats[j])
            x_fuse.append(self.relu(y))
        return x_fuse
class HRNet(BaseModule):
    """HRNet backbone.

    `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_.

    Args:
        arch (str): The preset HRNet architecture, includes 'w18', 'w30',
            'w32', 'w40', 'w44', 'w48', 'w64'. It will only be used if
            extra is ``None``. Defaults to 'w32'.
        extra (dict, optional): Detailed configuration for each stage of
            HRNet. There must be 4 stages, and the configuration for each
            stage must have 5 keys:

            - num_modules (int): The number of HRModule in this stage.
            - num_branches (int): The number of branches in the HRModule.
            - block (str): The type of convolution block. Please choose
              between 'BOTTLENECK' and 'BASIC'.
            - num_blocks (tuple): The number of blocks in each branch.
              The length must be equal to num_branches.
            - num_channels (tuple): The number of base channels in each
              branch. The length must be equal to num_branches.

            Defaults to None.
        in_channels (int): Number of input image channels. Defaults to 3.
        conv_cfg (dict, optional): Dictionary to construct and config conv
            layer. Defaults to None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to ``dict(type='BN')``.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        zero_init_residual (bool): Whether to use zero init for the last norm
            layer in the resblocks of stages 2-4, to let them behave as
            identity. Defaults to False.
        multiscale_output (bool): Whether to output multi-level features
            produced by multiple branches. If False, only the first level
            feature will be output. Defaults to True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to Kaiming init for ``Conv2d`` layers and constant
            (``val=1``) init for ``_BatchNorm`` and ``GroupNorm`` layers.

    Example:
        >>> import torch
        >>> from mmpretrain.models import HRNet
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(4, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=3,
        >>>         num_branches=4,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4, 4),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRNet(extra=extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}

    arch_zoo = {
        # num_modules, num_branches, block, num_blocks, num_channels
        'w18': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (18, 36)],
                [4, 3, 'BASIC', (4, 4, 4), (18, 36, 72)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (18, 36, 72, 144)]],
        'w30': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (30, 60)],
                [4, 3, 'BASIC', (4, 4, 4), (30, 60, 120)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (30, 60, 120, 240)]],
        'w32': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (32, 64)],
                [4, 3, 'BASIC', (4, 4, 4), (32, 64, 128)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (32, 64, 128, 256)]],
        'w40': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (40, 80)],
                [4, 3, 'BASIC', (4, 4, 4), (40, 80, 160)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (40, 80, 160, 320)]],
        'w44': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (44, 88)],
                [4, 3, 'BASIC', (4, 4, 4), (44, 88, 176)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (44, 88, 176, 352)]],
        'w48': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (48, 96)],
                [4, 3, 'BASIC', (4, 4, 4), (48, 96, 192)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (48, 96, 192, 384)]],
        'w64': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (64, 128)],
                [4, 3, 'BASIC', (4, 4, 4), (64, 128, 256)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (64, 128, 256, 512)]],
    }  # yapf:disable

    def __init__(self,
                 arch='w32',
                 extra=None,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=False,
                 multiscale_output=True,
                 init_cfg=[
                     dict(type='Kaiming', layer='Conv2d'),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(HRNet, self).__init__(init_cfg)

        extra = self.parse_arch(arch, extra)

        # Assert configurations of 4 stages are in extra
        for i in range(1, 5):
            assert f'stage{i}' in extra, f'Missing stage{i} config in "extra".'
            # Assert whether the length of `num_blocks` and `num_channels` are
            # equal to `num_branches`
            cfg = extra[f'stage{i}']
            assert len(cfg['num_blocks']) == cfg['num_branches'] and \
                len(cfg['num_channels']) == cfg['num_branches']

        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # -------------------- stem net --------------------
        # Two 3x3 stride-2 convs: the stem reduces spatial size by 4x.
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.conv2 = build_conv_layer(
            self.conv_cfg,
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # -------------------- stage 1 --------------------
        self.stage1_cfg = self.extra['stage1']
        base_channels = self.stage1_cfg['num_channels']
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks']

        block = self.blocks_dict[block_type]
        num_channels = [
            channel * get_expansion(block) for channel in base_channels
        ]
        # To align with the original code, use layer1 instead of stage1 here.
        # NOTE(review): unlike the stage 2~4 branches, layer1 is built without
        # ``conv_cfg``, ``norm_cfg``, ``with_cp`` or the zero-init block
        # config. It therefore uses whatever defaults ResLayer defines,
        # regardless of the arguments given to HRNet. Left unchanged so
        # existing checkpoints keep loading; confirm whether this is intended.
        self.layer1 = ResLayer(
            block,
            in_channels=64,
            out_channels=num_channels[0],
            num_blocks=num_blocks[0])

        pre_num_channels = num_channels

        # -------------------- stage 2~4 --------------------
        for i in range(2, 5):
            stage_cfg = self.extra[f'stage{i}']
            base_channels = stage_cfg['num_channels']
            block = self.blocks_dict[stage_cfg['block']]
            # Only the last stage may collapse to a single output branch.
            multiscale_output_ = multiscale_output if i == 4 else True

            num_channels = [
                channel * get_expansion(block) for channel in base_channels
            ]
            # The transition layer from the previous stage (layer1 for i == 2)
            # to stage i.
            transition = self._make_transition_layer(pre_num_channels,
                                                     num_channels)
            self.add_module(f'transition{i-1}', transition)
            stage = self._make_stage(
                stage_cfg, num_channels, multiscale_output=multiscale_output_)
            self.add_module(f'stage{i}', stage)

            pre_num_channels = num_channels

    @property
    def norm1(self):
        """nn.Module: The normalization layer after ``conv1``.

        It is registered under the name stored in ``self.norm1_name``. This
        must be a property because ``forward`` calls ``self.norm1(x)``.
        """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: The normalization layer after ``conv2``.

        It is registered under the name stored in ``self.norm2_name``. This
        must be a property because ``forward`` calls ``self.norm2(x)``.
        """
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Build the layers that adapt one stage's outputs to the next stage.

        Args:
            num_channels_pre_layer (list[int]): Output channels of each
                branch of the previous stage.
            num_channels_cur_layer (list[int]): Input channels of each branch
                of the next stage.

        Returns:
            nn.ModuleList: One module per branch of the next stage:

            - Existing branch, same channels: ``nn.Identity``.
            - Existing branch, different channels: 3x3 conv + norm + ReLU.
            - New branch: stacked 3x3 stride-2 conv + norm + ReLU blocks,
              applied to the last branch of the previous stage.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                # For existing scale branches,
                # add conv block when the channels are not the same.
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(nn.Identity())
            else:
                # For new scale branches, add stacked downsample conv blocks.
                # For example, num_branches_pre = 2, for the 4th branch, add
                # stacked two downsample conv blocks.
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    # Only the last downsample step changes the channel count.
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        """Build one stage as a sequence of ``HRModule`` instances.

        Args:
            layer_config (dict): Stage config with the keys ``num_modules``,
                ``num_branches``, ``num_blocks``, ``num_channels`` and
                ``block``.
            in_channels (list[int]): Input channels of each branch.
            multiscale_output (bool): If False, the last module of the stage
                only outputs the first (highest-resolution) branch.
                Defaults to True.

        Returns:
            Sequential: The stacked HRModules.
        """
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        hr_modules = []
        block_init_cfg = None
        if self.zero_init_residual:
            # Zero the last norm layer of each residual block.
            if block is BasicBlock:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm2'))
            elif block is Bottleneck:
                block_init_cfg = dict(
                    type='Constant', val=0, override=dict(name='norm3'))

        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            hr_modules.append(
                HRModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    block_init_cfg=block_init_cfg))

        return Sequential(*hr_modules)

    def forward(self, x):
        """Forward function.

        Args:
            x: The input image batch, fed to ``conv1``.

        Returns:
            tuple: The feature maps produced by the last stage, one per
            output branch (only one when ``multiscale_output`` is False).
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = [x]

        for i in range(2, 5):
            # Apply transition
            transition = getattr(self, f'transition{i-1}')
            inputs = []
            for j, layer in enumerate(transition):
                if j < len(x_list):
                    inputs.append(layer(x_list[j]))
                else:
                    # New branches are derived from the lowest-resolution
                    # output of the previous stage.
                    inputs.append(layer(x_list[-1]))
            # Forward HRModule
            stage = getattr(self, f'stage{i}')
            x_list = stage(inputs)

        return tuple(x_list)

    def train(self, mode=True):
        """Set the training mode, optionally keeping BatchNorm layers frozen.

        When ``mode`` is True and ``norm_eval`` is True, every ``_BatchNorm``
        submodule is switched back to eval mode, so its running mean and var
        are not updated.

        Args:
            mode (bool): Whether to set training mode (True) or evaluation
                mode (False). Defaults to True.

        Returns:
            HRNet: ``self``, matching ``torch.nn.Module.train`` so calls can
            be chained.
        """
        super(HRNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval only has an effect on BatchNorm layers
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    def parse_arch(self, arch, extra=None):
        """Resolve the per-stage configuration.

        Args:
            arch (str): Key into ``arch_zoo``. Ignored when ``extra`` is
                given.
            extra (dict, optional): Explicit stage configuration. If not
                None, it is returned unchanged.

        Returns:
            dict: A mapping ``stage1`` .. ``stage4`` to dicts with the keys
            ``num_modules``, ``num_branches``, ``block``, ``num_blocks`` and
            ``num_channels``.

        Raises:
            AssertionError: If ``extra`` is None and ``arch`` is not a key of
                ``arch_zoo``.
        """
        if extra is not None:
            return extra

        assert arch in self.arch_zoo, \
            ('Invalid arch, please choose arch from '
             f'{list(self.arch_zoo.keys())}, or specify `extra` '
             'argument directly.')

        extra = dict()
        for i, stage_setting in enumerate(self.arch_zoo[arch], start=1):
            extra[f'stage{i}'] = dict(
                num_modules=stage_setting[0],
                num_branches=stage_setting[1],
                block=stage_setting[2],
                num_blocks=stage_setting[3],
                num_channels=stage_setting[4],
            )

        return extra