"""senet in pytorch [1] Jie Hu, Li Shen, Samuel Albanie, Gang Sun, Enhua Wu Squeeze-and-Excitation Networks https://arxiv.org/abs/1709.01507 """ import torch import torch.nn as nn import torch.nn.functional as F class BasicResidualSEBlock(nn.Module): expansion = 1 def __init__(self, in_channels, out_channels, stride, r=16): super().__init__() self.residual = nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels * self.expansion, 3, padding=1), nn.BatchNorm2d(out_channels * self.expansion), nn.ReLU(inplace=True) ) self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels * self.expansion: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride), nn.BatchNorm2d(out_channels * self.expansion) ) self.squeeze = nn.AdaptiveAvgPool2d(1) self.excitation = nn.Sequential( nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r), nn.ReLU(inplace=True), nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion), nn.Sigmoid() ) def forward(self, x): shortcut = self.shortcut(x) residual = self.residual(x) squeeze = self.squeeze(residual) squeeze = squeeze.view(squeeze.size(0), -1) excitation = self.excitation(squeeze) excitation = excitation.view(residual.size(0), residual.size(1), 1, 1) x = residual * excitation.expand_as(residual) + shortcut return F.relu(x) class BottleneckResidualSEBlock(nn.Module): expansion = 4 def __init__(self, in_channels, out_channels, stride, r=16): super().__init__() self.residual = nn.Sequential( nn.Conv2d(in_channels, out_channels, 1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels * self.expansion, 1), nn.BatchNorm2d(out_channels * self.expansion), nn.ReLU(inplace=True) ) self.squeeze = nn.AdaptiveAvgPool2d(1) self.excitation = nn.Sequential( nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r), nn.ReLU(inplace=True), nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion), nn.Sigmoid() ) self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels * self.expansion: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride), nn.BatchNorm2d(out_channels * self.expansion) ) def forward(self, x): shortcut = self.shortcut(x) residual = self.residual(x) squeeze = self.squeeze(residual) squeeze = squeeze.view(squeeze.size(0), -1) excitation = self.excitation(squeeze) excitation = excitation.view(residual.size(0), residual.size(1), 1, 1) x = residual * excitation.expand_as(residual) + shortcut return F.relu(x) class SEResNet(nn.Module): def __init__(self, block, block_num, class_num=1): super().__init__() self.in_channels = 64 self.pre = nn.Sequential( nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True) ) self.stage1 = self._make_stage(block, block_num[0], 64, 1) self.stage2 = self._make_stage(block, block_num[1], 128, 2) self.stage3 = self._make_stage(block, block_num[2], 256, 2) self.stage4 = self._make_stage(block, block_num[3], 516, 2) self.linear = nn.Linear(self.in_channels, class_num) def forward(self, x): x = self.pre(x) x = self.stage1(x) x = self.stage2(x) x = self.stage3(x) x = self.stage4(x) x = F.adaptive_avg_pool2d(x, 1) x = x.view(x.size(0), -1) x = 
self.linear(x) return x def _make_stage(self, block, num, out_channels, stride): layers = [] layers.append(block(self.in_channels, out_channels, stride)) self.in_channels = out_channels * block.expansion while num - 1: layers.append(block(self.in_channels, out_channels, 1)) num -= 1 return nn.Sequential(*layers) def seresnet18(): return SEResNet(BasicResidualSEBlock, [2, 2, 2, 2]) def seresnet34(): return SEResNet(BasicResidualSEBlock, [3, 4, 6, 3]) def seresnet50(): return SEResNet(BottleneckResidualSEBlock, [3, 4, 6, 3]) def seresnet101(): return SEResNet(BottleneckResidualSEBlock, [3, 4, 23, 3]) def seresnet152(): return SEResNet(BottleneckResidualSEBlock, [3, 8, 36, 3])
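

# Minimal smoke test (not part of the original module): it assumes the
# CIFAR-style 32x32 RGB input that the 3x3 stem convolution implies, and
# only checks that a forward pass produces logits of the expected shape.
if __name__ == '__main__':
    model = seresnet18()
    dummy = torch.randn(2, 3, 32, 32)   # batch of 2 fake 32x32 RGB images
    logits = model(dummy)
    print(logits.shape)                 # expected: torch.Size([2, 100])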