peterkros commited on
Commit
72e4293
·
verified ·
1 Parent(s): f759370

Create mobilenetv2.py

Browse files
Files changed (1) hide show
  1. src/models/backbones/mobilenetv2.py +199 -0
src/models/backbones/mobilenetv2.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch"""
2
+
3
+ import math
4
+ import json
5
+ from functools import reduce
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+
11
+ #------------------------------------------------------------------------------
12
+ # Useful functions
13
+ #------------------------------------------------------------------------------
14
+
15
+ def _make_divisible(v, divisor, min_value=None):
16
+ if min_value is None:
17
+ min_value = divisor
18
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
19
+ # Make sure that round down does not go down by more than 10%.
20
+ if new_v < 0.9 * v:
21
+ new_v += divisor
22
+ return new_v
23
+
24
+
25
+ def conv_bn(inp, oup, stride):
26
+ return nn.Sequential(
27
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
28
+ nn.BatchNorm2d(oup),
29
+ nn.ReLU6(inplace=True)
30
+ )
31
+
32
+
33
+ def conv_1x1_bn(inp, oup):
34
+ return nn.Sequential(
35
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
36
+ nn.BatchNorm2d(oup),
37
+ nn.ReLU6(inplace=True)
38
+ )
39
+
40
+
41
+ #------------------------------------------------------------------------------
42
+ # Class of Inverted Residual block
43
+ #------------------------------------------------------------------------------
44
+
45
+ class InvertedResidual(nn.Module):
46
+ def __init__(self, inp, oup, stride, expansion, dilation=1):
47
+ super(InvertedResidual, self).__init__()
48
+ self.stride = stride
49
+ assert stride in [1, 2]
50
+
51
+ hidden_dim = round(inp * expansion)
52
+ self.use_res_connect = self.stride == 1 and inp == oup
53
+
54
+ if expansion == 1:
55
+ self.conv = nn.Sequential(
56
+ # dw
57
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
58
+ nn.BatchNorm2d(hidden_dim),
59
+ nn.ReLU6(inplace=True),
60
+ # pw-linear
61
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
62
+ nn.BatchNorm2d(oup),
63
+ )
64
+ else:
65
+ self.conv = nn.Sequential(
66
+ # pw
67
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
68
+ nn.BatchNorm2d(hidden_dim),
69
+ nn.ReLU6(inplace=True),
70
+ # dw
71
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
72
+ nn.BatchNorm2d(hidden_dim),
73
+ nn.ReLU6(inplace=True),
74
+ # pw-linear
75
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
76
+ nn.BatchNorm2d(oup),
77
+ )
78
+
79
+ def forward(self, x):
80
+ if self.use_res_connect:
81
+ return x + self.conv(x)
82
+ else:
83
+ return self.conv(x)
84
+
85
+
86
+ #------------------------------------------------------------------------------
87
+ # Class of MobileNetV2
88
+ #------------------------------------------------------------------------------
89
+
90
+ class MobileNetV2(nn.Module):
91
+ def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
92
+ super(MobileNetV2, self).__init__()
93
+ self.in_channels = in_channels
94
+ self.num_classes = num_classes
95
+ input_channel = 32
96
+ last_channel = 1280
97
+ interverted_residual_setting = [
98
+ # t, c, n, s
99
+ [1 , 16, 1, 1],
100
+ [expansion, 24, 2, 2],
101
+ [expansion, 32, 3, 2],
102
+ [expansion, 64, 4, 2],
103
+ [expansion, 96, 3, 1],
104
+ [expansion, 160, 3, 2],
105
+ [expansion, 320, 1, 1],
106
+ ]
107
+
108
+ # building first layer
109
+ input_channel = _make_divisible(input_channel*alpha, 8)
110
+ self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel
111
+ self.features = [conv_bn(self.in_channels, input_channel, 2)]
112
+
113
+ # building inverted residual blocks
114
+ for t, c, n, s in interverted_residual_setting:
115
+ output_channel = _make_divisible(int(c*alpha), 8)
116
+ for i in range(n):
117
+ if i == 0:
118
+ self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
119
+ else:
120
+ self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
121
+ input_channel = output_channel
122
+
123
+ # building last several layers
124
+ self.features.append(conv_1x1_bn(input_channel, self.last_channel))
125
+
126
+ # make it nn.Sequential
127
+ self.features = nn.Sequential(*self.features)
128
+
129
+ # building classifier
130
+ if self.num_classes is not None:
131
+ self.classifier = nn.Sequential(
132
+ nn.Dropout(0.2),
133
+ nn.Linear(self.last_channel, num_classes),
134
+ )
135
+
136
+ # Initialize weights
137
+ self._init_weights()
138
+
139
+ def forward(self, x):
140
+ # Stage1
141
+ x = self.features[0](x)
142
+ x = self.features[1](x)
143
+ # Stage2
144
+ x = self.features[2](x)
145
+ x = self.features[3](x)
146
+ # Stage3
147
+ x = self.features[4](x)
148
+ x = self.features[5](x)
149
+ x = self.features[6](x)
150
+ # Stage4
151
+ x = self.features[7](x)
152
+ x = self.features[8](x)
153
+ x = self.features[9](x)
154
+ x = self.features[10](x)
155
+ x = self.features[11](x)
156
+ x = self.features[12](x)
157
+ x = self.features[13](x)
158
+ # Stage5
159
+ x = self.features[14](x)
160
+ x = self.features[15](x)
161
+ x = self.features[16](x)
162
+ x = self.features[17](x)
163
+ x = self.features[18](x)
164
+
165
+ # Classification
166
+ if self.num_classes is not None:
167
+ x = x.mean(dim=(2,3))
168
+ x = self.classifier(x)
169
+
170
+ # Output
171
+ return x
172
+
173
+ def _load_pretrained_model(self, pretrained_file):
174
+ pretrain_dict = torch.load(pretrained_file, map_location='cpu')
175
+ model_dict = {}
176
+ state_dict = self.state_dict()
177
+ print("[MobileNetV2] Loading pretrained model...")
178
+ for k, v in pretrain_dict.items():
179
+ if k in state_dict:
180
+ model_dict[k] = v
181
+ else:
182
+ print(k, "is ignored")
183
+ state_dict.update(model_dict)
184
+ self.load_state_dict(state_dict)
185
+
186
+ def _init_weights(self):
187
+ for m in self.modules():
188
+ if isinstance(m, nn.Conv2d):
189
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
190
+ m.weight.data.normal_(0, math.sqrt(2. / n))
191
+ if m.bias is not None:
192
+ m.bias.data.zero_()
193
+ elif isinstance(m, nn.BatchNorm2d):
194
+ m.weight.data.fill_(1)
195
+ m.bias.data.zero_()
196
+ elif isinstance(m, nn.Linear):
197
+ n = m.weight.size(1)
198
+ m.weight.data.normal_(0, 0.01)
199
+ m.bias.data.zero_()