leozindev15 commited on
Commit
fcba1e9
Β·
verified Β·
1 Parent(s): 5261c16

Delete assets/model.py

Browse files
Files changed (1) hide show
  1. assets/model.py +0 -283
assets/model.py DELETED
@@ -1,283 +0,0 @@
1
- #!/usr/bin/python
2
- # -*- encoding: utf-8 -*-
3
-
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- import torchvision
9
-
10
- from .resnet import Resnet18
11
- # from modules.bn import InPlaceABNSync as BatchNorm2d
12
-
13
-
14
- class ConvBNReLU(nn.Module):
15
- def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
- super(ConvBNReLU, self).__init__()
17
- self.conv = nn.Conv2d(in_chan,
18
- out_chan,
19
- kernel_size = ks,
20
- stride = stride,
21
- padding = padding,
22
- bias = False)
23
- self.bn = nn.BatchNorm2d(out_chan)
24
- self.init_weight()
25
-
26
- def forward(self, x):
27
- x = self.conv(x)
28
- x = F.relu(self.bn(x))
29
- return x
30
-
31
- def init_weight(self):
32
- for ly in self.children():
33
- if isinstance(ly, nn.Conv2d):
34
- nn.init.kaiming_normal_(ly.weight, a=1)
35
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
-
37
- class BiSeNetOutput(nn.Module):
38
- def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
- super(BiSeNetOutput, self).__init__()
40
- self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
- self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
- self.init_weight()
43
-
44
- def forward(self, x):
45
- x = self.conv(x)
46
- x = self.conv_out(x)
47
- return x
48
-
49
- def init_weight(self):
50
- for ly in self.children():
51
- if isinstance(ly, nn.Conv2d):
52
- nn.init.kaiming_normal_(ly.weight, a=1)
53
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
-
55
- def get_params(self):
56
- wd_params, nowd_params = [], []
57
- for name, module in self.named_modules():
58
- if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
- wd_params.append(module.weight)
60
- if not module.bias is None:
61
- nowd_params.append(module.bias)
62
- elif isinstance(module, nn.BatchNorm2d):
63
- nowd_params += list(module.parameters())
64
- return wd_params, nowd_params
65
-
66
-
67
- class AttentionRefinementModule(nn.Module):
68
- def __init__(self, in_chan, out_chan, *args, **kwargs):
69
- super(AttentionRefinementModule, self).__init__()
70
- self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
- self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
- self.bn_atten = nn.BatchNorm2d(out_chan)
73
- self.sigmoid_atten = nn.Sigmoid()
74
- self.init_weight()
75
-
76
- def forward(self, x):
77
- feat = self.conv(x)
78
- atten = F.avg_pool2d(feat, feat.size()[2:])
79
- atten = self.conv_atten(atten)
80
- atten = self.bn_atten(atten)
81
- atten = self.sigmoid_atten(atten)
82
- out = torch.mul(feat, atten)
83
- return out
84
-
85
- def init_weight(self):
86
- for ly in self.children():
87
- if isinstance(ly, nn.Conv2d):
88
- nn.init.kaiming_normal_(ly.weight, a=1)
89
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
-
91
-
92
- class ContextPath(nn.Module):
93
- def __init__(self, *args, **kwargs):
94
- super(ContextPath, self).__init__()
95
- self.resnet = Resnet18()
96
- self.arm16 = AttentionRefinementModule(256, 128)
97
- self.arm32 = AttentionRefinementModule(512, 128)
98
- self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
- self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
- self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101
-
102
- self.init_weight()
103
-
104
- def forward(self, x):
105
- H0, W0 = x.size()[2:]
106
- feat8, feat16, feat32 = self.resnet(x)
107
- H8, W8 = feat8.size()[2:]
108
- H16, W16 = feat16.size()[2:]
109
- H32, W32 = feat32.size()[2:]
110
-
111
- avg = F.avg_pool2d(feat32, feat32.size()[2:])
112
- avg = self.conv_avg(avg)
113
- avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114
-
115
- feat32_arm = self.arm32(feat32)
116
- feat32_sum = feat32_arm + avg_up
117
- feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118
- feat32_up = self.conv_head32(feat32_up)
119
-
120
- feat16_arm = self.arm16(feat16)
121
- feat16_sum = feat16_arm + feat32_up
122
- feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123
- feat16_up = self.conv_head16(feat16_up)
124
-
125
- return feat8, feat16_up, feat32_up # x8, x8, x16
126
-
127
- def init_weight(self):
128
- for ly in self.children():
129
- if isinstance(ly, nn.Conv2d):
130
- nn.init.kaiming_normal_(ly.weight, a=1)
131
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132
-
133
- def get_params(self):
134
- wd_params, nowd_params = [], []
135
- for name, module in self.named_modules():
136
- if isinstance(module, (nn.Linear, nn.Conv2d)):
137
- wd_params.append(module.weight)
138
- if not module.bias is None:
139
- nowd_params.append(module.bias)
140
- elif isinstance(module, nn.BatchNorm2d):
141
- nowd_params += list(module.parameters())
142
- return wd_params, nowd_params
143
-
144
-
145
- ### This is not used, since I replace this with the resnet feature with the same size
146
- class SpatialPath(nn.Module):
147
- def __init__(self, *args, **kwargs):
148
- super(SpatialPath, self).__init__()
149
- self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150
- self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
- self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
- self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153
- self.init_weight()
154
-
155
- def forward(self, x):
156
- feat = self.conv1(x)
157
- feat = self.conv2(feat)
158
- feat = self.conv3(feat)
159
- feat = self.conv_out(feat)
160
- return feat
161
-
162
- def init_weight(self):
163
- for ly in self.children():
164
- if isinstance(ly, nn.Conv2d):
165
- nn.init.kaiming_normal_(ly.weight, a=1)
166
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167
-
168
- def get_params(self):
169
- wd_params, nowd_params = [], []
170
- for name, module in self.named_modules():
171
- if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172
- wd_params.append(module.weight)
173
- if not module.bias is None:
174
- nowd_params.append(module.bias)
175
- elif isinstance(module, nn.BatchNorm2d):
176
- nowd_params += list(module.parameters())
177
- return wd_params, nowd_params
178
-
179
-
180
- class FeatureFusionModule(nn.Module):
181
- def __init__(self, in_chan, out_chan, *args, **kwargs):
182
- super(FeatureFusionModule, self).__init__()
183
- self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184
- self.conv1 = nn.Conv2d(out_chan,
185
- out_chan//4,
186
- kernel_size = 1,
187
- stride = 1,
188
- padding = 0,
189
- bias = False)
190
- self.conv2 = nn.Conv2d(out_chan//4,
191
- out_chan,
192
- kernel_size = 1,
193
- stride = 1,
194
- padding = 0,
195
- bias = False)
196
- self.relu = nn.ReLU(inplace=True)
197
- self.sigmoid = nn.Sigmoid()
198
- self.init_weight()
199
-
200
- def forward(self, fsp, fcp):
201
- fcat = torch.cat([fsp, fcp], dim=1)
202
- feat = self.convblk(fcat)
203
- atten = F.avg_pool2d(feat, feat.size()[2:])
204
- atten = self.conv1(atten)
205
- atten = self.relu(atten)
206
- atten = self.conv2(atten)
207
- atten = self.sigmoid(atten)
208
- feat_atten = torch.mul(feat, atten)
209
- feat_out = feat_atten + feat
210
- return feat_out
211
-
212
- def init_weight(self):
213
- for ly in self.children():
214
- if isinstance(ly, nn.Conv2d):
215
- nn.init.kaiming_normal_(ly.weight, a=1)
216
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217
-
218
- def get_params(self):
219
- wd_params, nowd_params = [], []
220
- for name, module in self.named_modules():
221
- if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222
- wd_params.append(module.weight)
223
- if not module.bias is None:
224
- nowd_params.append(module.bias)
225
- elif isinstance(module, nn.BatchNorm2d):
226
- nowd_params += list(module.parameters())
227
- return wd_params, nowd_params
228
-
229
-
230
- class BiSeNet(nn.Module):
231
- def __init__(self, n_classes, *args, **kwargs):
232
- super(BiSeNet, self).__init__()
233
- self.cp = ContextPath()
234
- ## here self.sp is deleted
235
- self.ffm = FeatureFusionModule(256, 256)
236
- self.conv_out = BiSeNetOutput(256, 256, n_classes)
237
- self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238
- self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239
- self.init_weight()
240
-
241
- def forward(self, x):
242
- H, W = x.size()[2:]
243
- feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244
- feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245
- feat_fuse = self.ffm(feat_sp, feat_cp8)
246
-
247
- feat_out = self.conv_out(feat_fuse)
248
- feat_out16 = self.conv_out16(feat_cp8)
249
- feat_out32 = self.conv_out32(feat_cp16)
250
-
251
- feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252
- feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253
- feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254
- return feat_out, feat_out16, feat_out32
255
-
256
- def init_weight(self):
257
- for ly in self.children():
258
- if isinstance(ly, nn.Conv2d):
259
- nn.init.kaiming_normal_(ly.weight, a=1)
260
- if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261
-
262
- def get_params(self):
263
- wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264
- for name, child in self.named_children():
265
- child_wd_params, child_nowd_params = child.get_params()
266
- if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267
- lr_mul_wd_params += child_wd_params
268
- lr_mul_nowd_params += child_nowd_params
269
- else:
270
- wd_params += child_wd_params
271
- nowd_params += child_nowd_params
272
- return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273
-
274
-
275
- if __name__ == "__main__":
276
- net = BiSeNet(19)
277
- net.cuda()
278
- net.eval()
279
- in_ten = torch.randn(16, 3, 640, 480).cuda()
280
- out, out16, out32 = net(in_ten)
281
- print(out.shape)
282
-
283
- net.get_params()