File size: 5,402 Bytes
c509e76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
import torch.utils.model_zoo as model_zoo
def conv_bn(inp, oup, stride, BatchNorm):
    """Build the stem: 3x3 convolution -> normalization -> ReLU6.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the 3x3 convolution (padding is fixed at 1).
        BatchNorm: Normalization layer class, called as ``BatchNorm(oup)``.

    Returns:
        An ``nn.Sequential`` whose children are, by index, 0 = conv,
        1 = normalization, 2 = ReLU6.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        BatchNorm(oup),
        nn.ReLU6(inplace=True),
    ]
    # Unnamed children keep the numeric state_dict keys ("0.weight", "1.weight", ...).
    return nn.Sequential(*layers)
def fixed_padding(inputs, kernel_size, dilation):
    """Zero-pad the last two dimensions for a padding-free, possibly dilated conv.

    With this padding applied, a stride-1 convolution of size ``kernel_size``
    and dilation ``dilation`` (and ``padding=0``) preserves the spatial size.

    Args:
        inputs: Tensor whose last two dimensions are padded.
        kernel_size: Convolution kernel size.
        dilation: Convolution dilation.

    Returns:
        A new, padded tensor.
    """
    # A dilated kernel spans dilation * (kernel_size - 1) + 1 positions, so
    # this many extra positions are needed in total along each axis.
    total = dilation * (kernel_size - 1)
    before = total // 2
    after = total - before
    # F.pad order is (left, right, top, bottom); an odd remainder goes to
    # the right/bottom side.
    return F.pad(inputs, (before, after, before, after))
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block with explicit input padding.

    Layout of ``self.conv`` (the indices determine the state_dict keys):
      * only when ``expand_ratio != 1``: 1x1 expansion conv, norm, ReLU6;
      * 3x3 depthwise conv (``groups=hidden_dim``, optionally dilated),
        norm, ReLU6;
      * 1x1 linear projection conv and norm, with no activation.

    The depthwise conv is built with ``padding=0``; ``forward`` pads the
    input with ``fixed_padding`` instead, so the padding accounts for
    dilation. A skip connection is added when ``stride == 1`` and
    ``inp == oup``.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Depthwise conv stride; must be 1 or 2.
        dilation: Depthwise conv dilation.
        expand_ratio: Hidden width multiplier;
            ``hidden_dim = round(inp * expand_ratio)``.
        BatchNorm: Normalization layer class, called as ``BatchNorm(channels)``.
    """

    def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
        super().__init__()
        self.stride = stride
        assert stride in (1, 2)

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        self.kernel_size = 3
        self.dilation = dilation

        layers = []
        if expand_ratio != 1:
            # pw: 1x1 expansion from inp to hidden_dim channels.
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
                BatchNorm(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # dw: padding=0 because forward() pads the input explicitly.
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation,
                      groups=hidden_dim, bias=False),
            BatchNorm(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear: 1x1 projection, deliberately without activation.
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
            BatchNorm(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        padded = fixed_padding(x, self.kernel_size, dilation=self.dilation)
        out = self.conv(padded)
        return x + out if self.use_res_connect else out
class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor used as a DeepLab backbone.

    ``forward`` returns ``(high_level_features, low_level_features)``. The
    low-level branch is ``self.features[0:4]`` (the stride-2 stem plus the
    first three inverted-residual blocks, ending at 24 * width_mult
    channels); the high-level branch is the remaining blocks.
    """

    def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
        """Build the network.

        Args:
            output_stride: Once the accumulated stride equals this value,
                every later stage uses stride 1 and an increasing dilation
                rate instead of downsampling further.
            BatchNorm: Normalization layer class, called as
                ``BatchNorm(channels)``. ``None`` (the default) selects
                ``nn.BatchNorm2d``.
            width_mult: Multiplier applied to every layer's channel count.
            pretrained: If True, load weights via ``_load_pretrained_model``
                after random initialization.
        """
        super(MobileNetV2, self).__init__()
        if BatchNorm is None:
            # Passing None through would fail with TypeError in conv_bn.
            BatchNorm = nn.BatchNorm2d
        block = InvertedResidual
        current_stride = 1
        rate = 1
        inverted_residual_setting = [
            # t (expansion), c (channels), n (repeats), s (first-repeat stride)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        # building first layer (stride 2)
        input_channel = int(32 * width_mult)
        features = [conv_bn(3, input_channel, 2, BatchNorm)]
        current_stride *= 2
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            if current_stride == output_stride:
                # Target output stride reached: keep resolution and dilate.
                stride = 1
                dilation = rate
                rate *= s
            else:
                stride = s
                dilation = 1
                current_stride *= s
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first repeat of a stage may downsample.
                features.append(block(input_channel, output_channel,
                                      stride if i == 0 else 1, dilation, t, BatchNorm))
                input_channel = output_channel
        self.features = nn.Sequential(*features)
        self._initialize_weights()
        if pretrained:
            self._load_pretrained_model()
        # Slices of a Sequential share the same submodules (and parameters)
        # with self.features. They are created after weight loading, so the
        # state_dict seen by _load_pretrained_model holds only "features.*".
        self.low_level_features = self.features[0:4]
        self.high_level_features = self.features[4:]

    def forward(self, x):
        """Return ``(high_level_features, low_level_features)`` for input ``x``."""
        low_level_feat = self.low_level_features(x)
        x = self.high_level_features(low_level_feat)
        return x, low_level_feat

    def _load_pretrained_model(self):
        """Overwrite parameters/buffers with same-named entries from the checkpoint.

        Checkpoint keys that are not in this model's state_dict are ignored.
        Shapes are not checked here; a mismatch (e.g. width_mult != 1)
        surfaces as an error from ``load_state_dict``.
        """
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on
        # CPU-only hosts; load_state_dict then copies onto the model's device.
        # NOTE(review): downloaded over plain HTTP without hash verification.
        pretrain_dict = model_zoo.load_url(
            'http://jeff95.me/models/mobilenet_v2-6a65762b.pth', map_location='cpu')
        state_dict = self.state_dict()
        state_dict.update({k: v for k, v in pretrain_dict.items() if k in state_dict})
        self.load_state_dict(state_dict)

    def _initialize_weights(self):
        """Kaiming-normal init for convs; unit scale and zero shift for norm layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                # Both are listed because SynchronizedBatchNorm2d is checked
                # separately from nn.BatchNorm2d in this file.
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
if __name__ == "__main__":
    # Smoke test: print the shapes of the high- and low-level feature maps.
    # With output_stride=16 and a 512x512 input these should be
    # (1, 320, 32, 32) and (1, 24, 128, 128).
    dummy_input = torch.rand(1, 3, 512, 512)
    model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
    model.eval()
    with torch.no_grad():
        output, low_level_feat = model(dummy_input)
    print(output.size())
    print(low_level_feat.size())
|