# MobileNetV2 backbone (DeepLab variant with output-stride control).
import torch | |
import torch.nn.functional as F | |
import torch.nn as nn | |
import math | |
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d | |
import torch.utils.model_zoo as model_zoo | |
def conv_bn(inp, oup, stride, BatchNorm):
    """Stem block: 3x3 conv (padding 1, no bias) -> BatchNorm -> ReLU6.

    `BatchNorm` is the norm-layer factory to use (e.g. ``nn.BatchNorm2d``
    or ``SynchronizedBatchNorm2d``); bias is omitted because the norm
    layer has its own affine shift.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        BatchNorm(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
def fixed_padding(inputs, kernel_size, dilation):
    """Zero-pad `inputs` for 'SAME'-style convolution.

    Pads the last two (spatial) dimensions so that a stride-1 conv with
    the given `kernel_size` and `dilation` preserves spatial size; for an
    odd total pad, the extra pixel goes on the right/bottom.
    """
    effective_k = kernel_size + (kernel_size - 1) * (dilation - 1)
    total = effective_k - 1
    before = total // 2
    after = total - before
    return F.pad(inputs, (before, after, before, after))
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block: expand -> depthwise -> project.

    The 1x1 expansion is skipped when `expand_ratio == 1`. The depthwise
    conv uses padding 0 here; 'SAME' padding is applied explicitly in
    `forward` via `fixed_padding` so that dilation is handled correctly.
    An identity shortcut is added only when stride is 1 and the channel
    count is unchanged.
    """

    def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]
        self.stride = stride
        self.kernel_size = 3
        self.dilation = dilation
        # Residual shortcut requires matching spatial size and channels.
        self.use_res_connect = self.stride == 1 and inp == oup

        hidden_dim = round(inp * expand_ratio)
        layers = []
        if expand_ratio != 1:
            # pw: 1x1 pointwise expansion
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
                BatchNorm(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        # dw: 3x3 depthwise conv (padding handled in forward)
        layers += [
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation,
                      groups=hidden_dim, bias=False),
            BatchNorm(hidden_dim),
            nn.ReLU6(inplace=True),
        ]
        # pw-linear: 1x1 projection with no activation
        layers += [
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
            BatchNorm(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        # Apply 'SAME' padding externally so dilated kernels line up.
        padded = fixed_padding(x, self.kernel_size, dilation=self.dilation)
        out = self.conv(padded)
        return x + out if self.use_res_connect else out
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone for DeepLab-style segmentation.

    `forward` returns `(high_level_features, low_level_features)`:
    low-level features are taken after the first 4 layers (stride 4,
    used for decoder skip connections) and high-level features after
    the rest. Once the cumulative stride reaches `output_stride`,
    further downsampling is replaced by growing dilation, capping the
    output resolution at input_size / output_stride.
    """

    def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        current_stride = 1  # cumulative stride accumulated so far
        rate = 1            # dilation rate, grows once stride is capped
        # Stage table: t = expansion ratio, c = output channels,
        # n = number of blocks in the stage, s = stride of its first block.
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        # building first layer
        input_channel = int(input_channel * width_mult)
        self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
        current_stride *= 2
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            if current_stride == output_stride:
                # Stride budget exhausted: keep resolution, dilate instead.
                stride = 1
                dilation = rate
                rate *= s  # the skipped stride compounds into future dilation
            else:
                stride = s
                dilation = 1
                current_stride *= s
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of a stage downsamples / changes channels.
                if i == 0:
                    self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
                else:
                    self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
                input_channel = output_channel
        self.features = nn.Sequential(*self.features)
        self._initialize_weights()

        if pretrained:
            # NOTE: overwrites the fresh initialization above with ImageNet weights.
            self._load_pretrained_model()

        # Split points for DeepLab: layers 0-3 end at stride 4.
        self.low_level_features = self.features[0:4]
        self.high_level_features = self.features[4:]

    def forward(self, x):
        low_level_feat = self.low_level_features(x)
        x = self.high_level_features(low_level_feat)
        return x, low_level_feat

    def _load_pretrained_model(self):
        # NOTE(review): plain-http URL on a personal host — unverified download,
        # and likely to go stale; consider mirroring the checkpoint. TODO confirm.
        pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
        model_dict = {}
        state_dict = self.state_dict()
        # Copy only keys that exist in this model (tolerates width_mult/arch drift).
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    def _initialize_weights(self):
        """Kaiming-init convs; unit-gain, zero-shift init for norm layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
if __name__ == "__main__":
    # Smoke test: for a 512x512 RGB input with output_stride=16 we expect
    # 512/16 = 32 spatial high-level features and 512/4 = 128 low-level.
    # (Uses the default pretrained=True, so this downloads weights.)
    dummy_input = torch.rand(1, 3, 512, 512)  # renamed: don't shadow builtin input()
    model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
    output, low_level_feat = model(dummy_input)
    print(output.size())
    print(low_level_feat.size())