# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn | |
from mmcv.cnn import build_plugin_layer | |
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with a padding of 1.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution. Defaults to 1.

    Returns:
        nn.Conv2d: The constructed convolution layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def conv1x1(in_planes, out_planes):
    """Build a bias-free 1x1 convolution with stride 1 and no padding.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.

    Returns:
        nn.Conv2d: The constructed convolution layer.
    """
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=False,
    )
    return pointwise
class BasicBlock(nn.Module):
    """Residual block made of two convolutions with optional plugin layers.

    The main path is ``conv1 -> bn1 -> relu -> conv2 -> bn2``; its output is
    added to the (optionally downsampled) input and passed through ``relu``.
    Plugin layers can be inserted at four positions: ``'before_conv1'``,
    ``'after_conv1'``, ``'after_conv2'`` and ``'after_shortcut'``.

    Args:
        inplanes (int): Number of input channels of ``conv1``.
        planes (int): Number of output channels of ``conv1``. ``conv2``
            outputs ``planes * expansion`` channels.
        stride (int): Stride of the strided convolution. It is applied by
            ``conv2`` when ``use_conv1x1`` is True, otherwise by ``conv1``.
            Defaults to 1.
        downsample (callable | None): If not None, it is called on the block
            input (after any ``'before_conv1'`` plugins) to produce the
            residual. Defaults to None.
        use_conv1x1 (bool): If True, ``conv1`` is a 1x1 convolution instead
            of a 3x3 one. Defaults to False.
        plugins (dict | list[dict] | None): Plugin specification(s). Each
            dict must provide the keys ``'cfg'`` (config handed to
            ``build_plugin_layer``) and ``'position'`` (one of the four
            positions listed above). Entries with any other position are
            not built. Defaults to None.
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 use_conv1x1=False,
                 plugins=None):
        super().__init__()
        if use_conv1x1:
            self.conv1 = conv1x1(inplanes, planes)
            self.conv2 = conv3x3(planes, planes * self.expansion, stride)
        else:
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.conv2 = conv3x3(planes, planes * self.expansion)

        self.with_plugins = False
        if plugins:
            # A single plugin may be given as a bare dict.
            if isinstance(plugins, dict):
                plugins = [plugins]
            self.with_plugins = True
            # Group the plugin cfgs by the position they are attached to.
            self.before_conv1_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'before_conv1'
            ]
            self.after_conv1_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv1'
            ]
            self.after_conv2_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv2'
            ]
            self.after_shortcut_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_shortcut'
            ]

        self.planes = planes
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm2d(planes * self.expansion)
        self.downsample = downsample
        self.stride = stride

        if self.with_plugins:
            # Channel count seen by plugins placed after conv2/bn2; equal to
            # ``planes`` for this class because ``expansion`` is 1.
            out_planes = planes * self.expansion
            self.before_conv1_plugin_names = self.make_block_plugins(
                inplanes, self.before_conv1_plugin)
            self.after_conv1_plugin_names = self.make_block_plugins(
                planes, self.after_conv1_plugin)
            self.after_conv2_plugin_names = self.make_block_plugins(
                out_planes, self.after_conv2_plugin)
            self.after_shortcut_plugin_names = self.make_block_plugins(
                out_planes, self.after_shortcut_plugin)

    def make_block_plugins(self, in_channels, plugins):
        """Build plugin layers and register them as submodules.

        Args:
            in_channels (int): Channel count passed to each plugin as both
                ``in_channels`` and ``out_channels``.
            plugins (list[dict]): List of plugin cfgs to build. An optional
                ``'postfix'`` key is removed from the cfg and forwarded to
                ``build_plugin_layer`` separately.

        Returns:
            list[str]: Names of the registered plugin modules, in build
            order.
        """
        assert isinstance(plugins, list)
        plugin_names = []
        for plugin in plugins:
            # Copy so that popping 'postfix' does not mutate the caller's cfg.
            plugin = plugin.copy()
            name, layer = build_plugin_layer(
                plugin,
                in_channels=in_channels,
                out_channels=in_channels,
                postfix=plugin.pop('postfix', ''))
            assert not hasattr(self, name), f'duplicate plugin {name}'
            self.add_module(name, layer)
            plugin_names.append(name)
        return plugin_names

    def forward_plugin(self, x, plugin_names):
        """Apply the named plugin modules one after another.

        Args:
            x (Tensor): Input of the first plugin.
            plugin_names (list[str]): Attribute names of the plugin modules,
                applied in the given order.

        Returns:
            Tensor: Output of the last plugin, or ``x`` itself when
            ``plugin_names`` is empty.
        """
        out = x
        for name in plugin_names:
            # Feed each plugin the previous plugin's output so that several
            # plugins at the same position are chained instead of all but
            # the last being discarded.
            out = getattr(self, name)(out)
        return out

    def forward(self, x):
        """Run the residual block.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: Output feature map of the block.
        """
        if self.with_plugins:
            x = self.forward_plugin(x, self.before_conv1_plugin_names)
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_conv1_plugin_names)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_conv2_plugin_names)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_shortcut_plugin_names)

        return out
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions.

    Every convolution is bias-free and followed by batch normalization. The
    block outputs ``planes * expansion`` channels.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels of the two inner convolutions.
        stride (int): Stride of the 3x3 convolution and of the shortcut
            convolution (when present). Defaults to 1.
        downsample (bool): If True, the shortcut is a strided 1x1
            convolution followed by batch normalization; otherwise it is an
            empty ``nn.Sequential`` that returns its input. Defaults to
            False.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super().__init__()
        out_planes = planes * self.expansion

        # Main path; layers are registered in execution order.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut path: projection when requested, otherwise an empty
        # container.
        shortcut_layers = []
        if downsample:
            shortcut_layers = [
                nn.Conv2d(
                    inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.downsample = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        """Run the bottleneck block.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ``relu(main_path(x) + downsample(x))``.
        """
        shortcut = self.downsample(x)

        feat = self.relu(self.bn1(self.conv1(x)))
        feat = self.relu(self.bn2(self.conv2(feat)))
        feat = self.bn3(self.conv3(feat))

        feat += shortcut
        return self.relu(feat)