|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
import math |
|
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d |
|
import torch.utils.model_zoo as model_zoo |
|
|
|
def conv_bn(inp, oup, stride, BatchNorm):
    """Build the network stem: 3x3 conv -> BatchNorm -> ReLU6.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 convolution.
        BatchNorm: normalization layer class (e.g. nn.BatchNorm2d).

    Returns:
        nn.Sequential with modules indexed 0 (conv), 1 (norm), 2 (act).
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        BatchNorm(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
def fixed_padding(inputs, kernel_size, dilation):
    """Apply TF-style 'SAME' zero padding that is independent of input size.

    The effective receptive field of a dilated kernel is
    (kernel_size - 1) * dilation + 1; the required total padding is split
    as evenly as possible, putting the extra pixel (if any) at the end.

    Args:
        inputs: 4-D tensor (N, C, H, W).
        kernel_size: spatial size of the (square) kernel.
        dilation: dilation rate of the convolution.

    Returns:
        The zero-padded tensor.
    """
    effective_kernel = (kernel_size - 1) * dilation + 1
    total_pad = effective_kernel - 1
    front = total_pad // 2
    back = total_pad - front
    return F.pad(inputs, (front, back, front, back))
|
|
|
|
|
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block: expand -> depthwise -> project.

    Padding for the depthwise 3x3 conv is applied manually in forward()
    via fixed_padding (TF 'SAME' style), so every conv here uses padding=0.
    A residual shortcut is used only when stride is 1 and the channel count
    is unchanged.
    """

    def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]
        self.stride = stride
        self.kernel_size = 3
        self.dilation = dilation
        # Shortcut only when the block is shape-preserving.
        self.use_res_connect = self.stride == 1 and inp == oup

        hidden_dim = round(inp * expand_ratio)

        layers = []
        if expand_ratio != 1:
            # 1x1 pointwise expansion (skipped when expand_ratio == 1,
            # where hidden_dim == inp already).
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
                BatchNorm(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # 3x3 depthwise conv; padding=0 because input is pre-padded.
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation,
                      groups=hidden_dim, bias=False),
            BatchNorm(hidden_dim),
            nn.ReLU6(inplace=True),
            # 1x1 linear projection (no activation).
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
            BatchNorm(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(fixed_padding(x, self.kernel_size, dilation=self.dilation))
        return x + out if self.use_res_connect else out
|
|
|
|
|
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone for DeepLab-style segmentation heads.

    forward() returns (high_level_features, low_level_features): the
    low-level map is taken after the first four feature stages, the
    high-level map after the rest. Once the cumulative stride reaches
    `output_stride`, further downsampling is replaced by dilation.
    """

    def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
        super(MobileNetV2, self).__init__()
        # Per stage: t = expansion factor, c = output channels,
        # n = number of repeats, s = stride of the first repeat.
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        in_channels = int(32 * width_mult)
        layers = [conv_bn(3, in_channels, 2, BatchNorm)]
        cumulative_stride = 2  # the stem already downsamples once
        rate = 1

        for t, c, n, s in inverted_residual_setting:
            if cumulative_stride == output_stride:
                # Target stride reached: keep resolution and fold the
                # would-be stride into the dilation rate instead.
                stride = 1
                dilation = rate
                rate *= s
            else:
                stride = s
                dilation = 1
                cumulative_stride *= s
            out_channels = int(c * width_mult)
            for i in range(n):
                # Only the first repeat of a stage may downsample.
                layers.append(InvertedResidual(
                    in_channels, out_channels,
                    stride if i == 0 else 1,
                    dilation, t, BatchNorm))
                in_channels = out_channels

        self.features = nn.Sequential(*layers)
        self._initialize_weights()

        if pretrained:
            self._load_pretrained_model()

        # Split for DeepLab: stages 0-3 give the low-level feature map.
        self.low_level_features = self.features[0:4]
        self.high_level_features = self.features[4:]

    def forward(self, x):
        low_level_feat = self.low_level_features(x)
        high_level_feat = self.high_level_features(low_level_feat)
        return high_level_feat, low_level_feat

    def _load_pretrained_model(self):
        """Load ImageNet weights, keeping only keys this model also has."""
        pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
        state_dict = self.state_dict()
        overlap = {k: v for k, v in pretrain_dict.items() if k in state_dict}
        state_dict.update(overlap)
        self.load_state_dict(state_dict)

    def _initialize_weights(self):
        """Kaiming-init convs; unit-scale/zero-shift all batch norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
|
|
|
if __name__ == "__main__":
    # Smoke test: push a dummy batch through the backbone and print the
    # resulting feature-map sizes. pretrained=False avoids a network
    # download for what is only a shape check, and eval()/no_grad() skip
    # batch-norm statistic updates and gradient tracking.
    # (Renamed the tensor from `input`, which shadowed the builtin.)
    model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
    model.eval()
    dummy_input = torch.rand(1, 3, 512, 512)
    with torch.no_grad():
        output, low_level_feat = model(dummy_input)
    print(output.size())
    print(low_level_feat.size())
|
|