"""senet in pytorch

[1] Jie Hu, Li Shen, Samuel Albanie, Gang Sun, Enhua Wu

    Squeeze-and-Excitation Networks
    https://arxiv.org/abs/1709.01507
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicResidualSEBlock(nn.Module):
    """Basic residual block (two 3x3 convs) gated by a Squeeze-and-Excitation module.

    Args:
        in_channels: channels of the input feature map.
        out_channels: base width of the block (output width is
            ``out_channels * expansion``).
        stride: stride of the first conv; >1 downsamples spatially.
        r: SE reduction ratio for the excitation bottleneck MLP.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride, r=16):
        super().__init__()
        width = out_channels * self.expansion

        # Main path: two conv-BN-ReLU stages; only the first one strides.
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, width, 3, padding=1),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )

        # Projection shortcut only when spatial size or channel count changes;
        # otherwise the identity (empty Sequential) is used.
        if stride != 1 or in_channels != width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, width, 1, stride=stride),
                nn.BatchNorm2d(width),
            )
        else:
            self.shortcut = nn.Sequential()

        # SE gate: global average pool, then a bottleneck MLP ending in sigmoid
        # producing one scale factor per channel.
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(width, width // r),
            nn.ReLU(inplace=True),
            nn.Linear(width // r, width),
            nn.Sigmoid(),
        )

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.residual(x)

        # Per-channel weights in [0, 1] from the globally pooled descriptor.
        scale = self.squeeze(out).flatten(1)
        scale = self.excitation(scale).reshape(out.size(0), out.size(1), 1, 1)

        return F.relu(out * scale.expand_as(out) + identity)
class BottleneckResidualSEBlock(nn.Module):
    """Bottleneck residual block (1x1 reduce, 3x3, 1x1 expand) with an SE channel gate.

    Args:
        in_channels: channels of the input feature map.
        out_channels: bottleneck width (output width is
            ``out_channels * expansion``).
        stride: stride of the middle 3x3 conv; >1 downsamples spatially.
        r: SE reduction ratio for the excitation bottleneck MLP.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride, r=16):
        super().__init__()
        width = out_channels * self.expansion

        # 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand, each conv
        # followed by BN + ReLU.
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, width, 1),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )

        # SE gate: pooled descriptor -> bottleneck MLP -> per-channel sigmoid.
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(width, width // r),
            nn.ReLU(inplace=True),
            nn.Linear(width // r, width),
            nn.Sigmoid(),
        )

        # Project the identity path when shape changes; otherwise pass through.
        if stride != 1 or in_channels != width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, width, 1, stride=stride),
                nn.BatchNorm2d(width),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.residual(x)

        # Per-channel weights in [0, 1] from the globally pooled descriptor.
        scale = self.squeeze(out).flatten(1)
        scale = self.excitation(scale).reshape(out.size(0), out.size(1), 1, 1)

        return F.relu(out * scale.expand_as(out) + identity)
class SEResNet(nn.Module):
    """SE-ResNet for 32x32-style inputs (single 3x3 conv stem, no max-pool).

    Args:
        block: residual block class (BasicResidualSEBlock or
            BottleneckResidualSEBlock); must expose an ``expansion`` attribute
            and accept ``(in_channels, out_channels, stride)``.
        block_num: four ints — number of blocks in stages 1-4.
        class_num: number of output classes. Default fixed from 1 to 100
            (a one-logit classifier was clearly unintended; this is a
            CIFAR-100-style SENet).
    """

    def __init__(self, block, block_num, class_num=100):
        super().__init__()
        self.in_channels = 64

        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Stage widths follow the ResNet convention 64/128/256/512.
        # (fixed: stage 4 width was 516, a typo for 512)
        self.stage1 = self._make_stage(block, block_num[0], 64, 1)
        self.stage2 = self._make_stage(block, block_num[1], 128, 2)
        self.stage3 = self._make_stage(block, block_num[2], 256, 2)
        self.stage4 = self._make_stage(block, block_num[3], 512, 2)

        # _make_stage leaves self.in_channels at the final stage's output width.
        self.linear = nn.Linear(self.in_channels, class_num)

    def forward(self, x):
        x = self.pre(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        # Global average pool -> flatten -> classifier logits.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

    def _make_stage(self, block, num, out_channels, stride):
        """Build one stage: first block may downsample, the rest keep stride 1."""
        layers = [block(self.in_channels, out_channels, stride)]
        # Subsequent blocks consume the expanded width the first block produced.
        self.in_channels = out_channels * block.expansion
        for _ in range(num - 1):
            layers.append(block(self.in_channels, out_channels, 1))
        return nn.Sequential(*layers)
def seresnet18():
    """Build an SE-ResNet-18 (basic SE blocks, two per stage)."""
    stage_depths = [2, 2, 2, 2]
    return SEResNet(BasicResidualSEBlock, stage_depths)
def seresnet34():
    """Build an SE-ResNet-34 (basic SE blocks, depths 3/4/6/3)."""
    stage_depths = [3, 4, 6, 3]
    return SEResNet(BasicResidualSEBlock, stage_depths)
def seresnet50():
    """Build an SE-ResNet-50 (bottleneck SE blocks, depths 3/4/6/3)."""
    stage_depths = [3, 4, 6, 3]
    return SEResNet(BottleneckResidualSEBlock, stage_depths)
def seresnet101():
    """Build an SE-ResNet-101 (bottleneck SE blocks, depths 3/4/23/3)."""
    stage_depths = [3, 4, 23, 3]
    return SEResNet(BottleneckResidualSEBlock, stage_depths)
def seresnet152():
    """Build an SE-ResNet-152 (bottleneck SE blocks, depths 3/8/36/3)."""
    stage_depths = [3, 8, 36, 3]
    return SEResNet(BottleneckResidualSEBlock, stage_depths)