# Page header captured together with this file (not part of the module):
#   qubvel-hf's picture
#   qubvel-hf (HF Staff)
#   Commit: "Init project" (c509e76)
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
import torch.utils.model_zoo as model_zoo
def conv_bn(inp, oup, stride, BatchNorm):
    """Build a 3x3 convolution -> normalization -> ReLU6 stem.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the 3x3 convolution.
        BatchNorm: Callable that takes the channel count and returns the
            normalization layer (e.g. ``nn.BatchNorm2d``).

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    # The convolution carries no bias; the normalization layer that follows
    # is constructed right after it, in the same order as the layers run.
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = BatchNorm(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, act)
def fixed_padding(inputs, kernel_size, dilation):
    """Pad the last two dimensions of ``inputs`` for a dilated square kernel.

    The amount of padding depends only on ``kernel_size`` and ``dilation``,
    never on the input size. When the total padding is odd, the extra
    element goes to the trailing (right / bottom) side.

    Args:
        inputs: Tensor to pad; ``F.pad`` with a 4-element spec pads its last
            two dimensions.
        kernel_size: Size of the (square) kernel.
        dilation: Dilation rate of the kernel.

    Returns:
        The padded tensor (``F.pad`` default mode, i.e. constant zeros).
    """
    # A dilated kernel covers (k - 1) * d + 1 positions along each axis.
    effective_extent = (kernel_size - 1) * dilation + 1
    # Split the required padding (extent - 1) into a leading half and a
    # trailing half that absorbs the remainder.
    lead, remainder = divmod(effective_extent - 1, 2)
    trail = lead + remainder
    return F.pad(inputs, (lead, trail, lead, trail))
class InvertedResidual(nn.Module):
    """Inverted residual block: optional 1x1 expansion, 3x3 depthwise, 1x1 projection.

    All convolutions are created with padding 0; ``forward`` pads the block
    input explicitly with ``fixed_padding`` before running the layer stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the depthwise convolution; must be 1 or 2.
        dilation: Dilation of the depthwise convolution.
        expand_ratio: Channel expansion factor. When it equals 1 the 1x1
            expansion stage is omitted.
        BatchNorm: Callable that takes a channel count and returns a
            normalization layer.
    """

    def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.kernel_size = 3
        self.dilation = dilation
        # The skip connection is only valid when the block preserves both
        # the spatial resolution and the channel count.
        self.use_res_connect = self.stride == 1 and inp == oup

        hidden_dim = round(inp * expand_ratio)

        # Layers are appended in execution order so that positions inside
        # the Sequential match the two hand-written variants this replaces.
        stages = []
        if expand_ratio != 1:
            # pw: 1x1 expansion to hidden_dim channels
            stages.append(nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False))
            stages.append(BatchNorm(hidden_dim))
            stages.append(nn.ReLU6(inplace=True))

        # dw: one 3x3 filter per channel (groups == channels), unpadded
        stages.append(
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation,
                      groups=hidden_dim, bias=False)
        )
        stages.append(BatchNorm(hidden_dim))
        stages.append(nn.ReLU6(inplace=True))

        # pw-linear: 1x1 projection, no activation after the normalization
        stages.append(nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False))
        stages.append(BatchNorm(oup))

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the block on ``x`` and add the skip connection when enabled."""
        # Padding is applied to the block input, i.e. before the 1x1
        # expansion when that stage is present.
        branch = self.conv(fixed_padding(x, self.kernel_size, dilation=self.dilation))
        return x + branch if self.use_res_connect else branch
class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor returning high- and low-level feature maps.

    The network is a stem convolution followed by stacks of
    ``InvertedResidual`` blocks. Once the accumulated stride reaches
    ``output_stride``, later stages keep stride 1 and use dilation instead.

    Args:
        output_stride: Accumulated stride at which the network stops
            downsampling and switches to dilated convolutions.
        BatchNorm: Callable that takes a channel count and returns a
            normalization layer. ``None`` falls back to ``nn.BatchNorm2d``.
        width_mult: Multiplier applied to every stage's channel count.
        pretrained: When true, download and load pretrained weights during
            construction (requires network access).
    """

    def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
        super(MobileNetV2, self).__init__()
        if BatchNorm is None:
            # The declared default is not callable; without this fallback the
            # first layer construction fails with a TypeError.
            BatchNorm = nn.BatchNorm2d

        block = InvertedResidual
        input_channel = 32
        current_stride = 1
        rate = 1
        inverted_residual_setting = [
            # t (expand ratio), c (channels), n (repeats), s (stride)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        input_channel = int(input_channel * width_mult)
        layers = [conv_bn(3, input_channel, 2, BatchNorm)]
        current_stride *= 2

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            if current_stride == output_stride:
                # Target stride reached: stop downsampling and fold the
                # stage's nominal stride into the dilation rate instead.
                stride = 1
                dilation = rate
                rate *= s
            else:
                stride = s
                dilation = 1
                current_stride *= s
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of a stage may change the resolution.
                block_stride = stride if i == 0 else 1
                layers.append(
                    block(input_channel, output_channel, block_stride, dilation, t, BatchNorm)
                )
                input_channel = output_channel

        self.features = nn.Sequential(*layers)
        self._initialize_weights()

        if pretrained:
            self._load_pretrained_model()

        # The slices are taken after the weights are loaded, so the state
        # dict used for loading only contains ``features.*`` entries. Slicing
        # an nn.Sequential reuses the same layer objects, so these are views
        # onto ``self.features`` rather than copies.
        self.low_level_features = self.features[0:4]
        self.high_level_features = self.features[4:]

    def forward(self, x):
        """Return ``(high_level, low_level)`` feature maps for input ``x``.

        ``low_level`` is the output of the first four entries of
        ``self.features``; ``high_level`` is the result of feeding it through
        the remaining entries.
        """
        low_level_feat = self.low_level_features(x)
        x = self.high_level_features(low_level_feat)
        return x, low_level_feat

    def _load_pretrained_model(self):
        """Download the pretrained checkpoint and load the entries this model has."""
        # NOTE(review): the checkpoint is fetched over plain HTTP and
        # deserialized by torch (pickle-based); only use with a trusted
        # source. The URL is kept unchanged.
        # map_location='cpu' keeps loading independent of the device the
        # checkpoint was saved from; load_state_dict copies the values into
        # the model's existing parameters afterwards.
        pretrain_dict = model_zoo.load_url(
            'http://jeff95.me/models/mobilenet_v2-6a65762b.pth',
            map_location='cpu',
        )
        state_dict = self.state_dict()
        # Keep only checkpoint entries whose key exists in this model.
        # NOTE(review): entries are matched by name only; presumably a
        # non-default width_mult leads to a size-mismatch error below.
        model_dict = {k: v for k, v in pretrain_dict.items() if k in state_dict}
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    def _initialize_weights(self):
        """Initialize conv weights (Kaiming normal) and norm layers (weight 1, bias 0)."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
if __name__ == "__main__":
    # Smoke test: push one random 512x512 RGB batch through the backbone and
    # print the sizes of both returned feature maps.
    # `pretrained` is left at its default (True), so this downloads weights.
    sample = torch.rand(1, 3, 512, 512)  # not named `input`: avoids shadowing the builtin
    model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
    output, low_level_feat = model(sample)
    print(output.size())
    print(low_level_feat.size())