asdasdasdasd committed on
Commit
cdbd3dd
·
1 Parent(s): c98be58

Upload xception.py

Browse files
Files changed (1) hide show
  1. xception.py +460 -0
xception.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code from https://github.com/ondyari/FaceForensics
3
+ Author: Andreas Rössler
4
+ """
5
+ import os
6
+ import argparse
7
+
8
+
9
+ import torch
10
+ # import pretrainedmodels
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ # from lib.nets.xception import xception
14
+ import math
15
+ import torchvision
16
+
17
+ # import math
18
+ # import torch
19
+ # import torch.nn as nn
20
+ # import torch.nn.functional as F
21
+ import torch.utils.model_zoo as model_zoo
22
+ from torch.nn import init
23
+
24
# Metadata for the published Xception checkpoint: where to download it and the
# preprocessing its weights expect.
_XCEPTION_IMAGENET = {
    'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
    'input_space': 'RGB',
    'input_size': [3, 299, 299],
    'input_range': [0, 1],
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'num_classes': 1000,
    # The resize parameter of the validation transform should be 333, and
    # make sure to center crop at 299x299.
    'scale': 0.8975,
}

# Lookup table: architecture name -> pretraining dataset -> settings.
pretrained_settings = {'xception': {'imagenet': _XCEPTION_IMAGENET}}

# Local path of the ImageNet checkpoint that TransferModel loads from disk.
# NOTE: the identifier is spelled "PRETAINED" (sic); it is kept as-is because
# other code in this file refers to it by this name.
PRETAINED_WEIGHT_PATH = './pretrained_models/xception-b5690688.pth'
40
+
41
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel (depthwise) convolution is followed by a 1x1 (pointwise)
    convolution that mixes channels and maps ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the pointwise conv.
        kernel_size: kernel size of the depthwise conv.
        stride: stride of the depthwise conv.
        padding: padding of the depthwise conv.
        dilation: dilation of the depthwise conv.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()

        # groups == in_channels gives every input channel its own filter.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # 1x1 convolution that recombines the per-channel responses.
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise conv, then the pointwise conv."""
        return self.pointwise(self.conv1(x))
54
+
55
+
56
class Block(nn.Module):
    """Residual stack of ReLU -> SeparableConv2d -> BatchNorm2d units.

    Args:
        in_filters: channels of the incoming tensor.
        out_filters: channels of the outgoing tensor.
        reps: number of separable-conv units in the main branch.
        strides: when not 1, a max-pool with this stride closes the main
            branch and the shortcut uses a strided 1x1 convolution.
        start_with_relu: if False the leading ReLU is dropped; if True it is
            replaced by a non-in-place ReLU so the block input (which also
            feeds the shortcut) is not overwritten.
        grow_first: if True the channel change happens in the first unit,
            otherwise in the last one.
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super().__init__()

        # A projection shortcut is only needed when the shape changes.
        needs_projection = out_filters != in_filters or strides != 1
        if needs_projection:
            self.skip = nn.Conv2d(in_filters, out_filters,
                                  1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        # One shared in-place ReLU instance is reused between the units.
        self.relu = nn.ReLU(inplace=True)

        def unit(n_in, n_out):
            return [
                self.relu,
                SeparableConv2d(n_in, n_out, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(n_out),
            ]

        layers = []
        width = in_filters
        if grow_first:
            layers.extend(unit(in_filters, out_filters))
            width = out_filters

        for _ in range(reps - 1):
            layers.extend(unit(width, width))

        if not grow_first:
            layers.extend(unit(in_filters, out_filters))

        if start_with_relu:
            layers[0] = nn.ReLU(inplace=False)
        else:
            layers = layers[1:]

        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        """Return main-branch output plus the (possibly projected) shortcut."""
        x = self.rep(inp)

        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))

        x += shortcut
        return x
110
+
111
+
112
def add_gaussian_noise(ins, mean=0, stddev=0.2):
    """Return ``ins`` plus element-wise Gaussian noise.

    Args:
        ins: input tensor; it is not modified.
        mean: mean of the noise distribution.
        stddev: standard deviation of the noise distribution.

    Returns:
        A new tensor ``ins + noise`` where ``noise`` has the same shape,
        dtype and device as ``ins``.
    """
    # torch.empty_like replaces the legacy ``ins.data.new(ins.size())`` idiom;
    # the buffer is filled in place and never takes part in autograd.
    noise = torch.empty_like(ins).normal_(mean, stddev)
    return ins + noise
115
+
116
+
117
class Xception(nn.Module):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf

    The forward pass is exposed in stages (``fea_part1`` ... ``fea_part5``) so
    callers can tap intermediate feature maps. ``features`` chains all stages
    and ``classifier`` pools the result and applies the linear head.
    """

    def __init__(self, num_classes=1000, inc=3):
        """ Constructor
        Args:
            num_classes: number of classes
            inc: number of channels of the input image
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        # Entry flow
        self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # do relu here

        self.block1 = Block(
            64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(
            128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(
            256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # Middle flow: eight identical 728-channel residual blocks.
        self.block4 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block12 = Block(
            728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        # do relu here
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        # Linear head. The ``xception`` factory in this file renames it to
        # ``last_linear``; ``classifier`` accepts either name.
        self.fc = nn.Linear(2048, num_classes)

        # No custom weight initialisation is applied here; layers keep the
        # initial values PyTorch assigns on construction.

    def fea_part1_0(self, x):
        """First stem stage: conv1 -> bn1 -> ReLU."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return x

    def fea_part1_1(self, x):
        """Second stem stage: conv2 -> bn2 -> ReLU."""
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part1(self, x):
        """Full stem: ``fea_part1_0`` followed by ``fea_part1_1``."""
        return self.fea_part1_1(self.fea_part1_0(x))

    def fea_part2(self, x):
        """Entry-flow residual blocks (block1 - block3)."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        return x

    def fea_part3(self, x):
        """First half of the middle flow (block4 - block7)."""
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)

        return x

    def fea_part4(self, x):
        """Second half of the middle flow (block8 - block11)."""
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)

        return x

    def fea_part5(self, x):
        """Exit flow: block12, conv3/bn3/ReLU, conv4/bn4 (no final ReLU)."""
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)

        return x

    def features(self, input):
        """Run all five stages and return the final feature map."""
        x = self.fea_part1(input)

        x = self.fea_part2(x)
        x = self.fea_part3(x)
        x = self.fea_part4(x)

        x = self.fea_part5(x)
        return x

    def classifier(self, features):
        """Pool a feature map and apply the linear head.

        Args:
            features: feature map as returned by ``features``. Note that the
                ReLU applied here is in-place, so this tensor is overwritten.

        Returns:
            ``(out, x)``: the head output and the flattened pooled features
            that were fed to the head.
        """
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)

        # The head is called ``last_linear`` once the factory has renamed it,
        # but a directly constructed Xception only has ``fc``.
        head = getattr(self, 'last_linear', None)
        if head is None:
            head = self.fc
        out = head(x)
        return out, x

    def forward(self, input):
        """Return ``(out, pooled_features)`` for a batch of images."""
        x = self.features(input)
        out, x = self.classifier(x)
        return out, x
271
+
272
+
273
def xception(num_classes=1000, pretrained='imagenet', inc=3):
    """Build an Xception network, optionally with downloaded weights.

    Args:
        num_classes: size of the classification head. Must equal the class
            count recorded in ``pretrained_settings`` when ``pretrained`` is
            set.
        pretrained: key into ``pretrained_settings['xception']`` naming the
            weights to download, or a falsy value for a randomly initialised
            network.
        inc: number of input channels. Only honoured when ``pretrained`` is
            falsy: the pretrained path builds the default 3-channel network.

    Returns:
        The model, with its head exposed as ``last_linear`` (the ``fc``
        attribute is removed). On the pretrained path the preprocessing
        settings are attached as attributes.

    Raises:
        KeyError: ``pretrained`` is truthy but not a known settings key.
        AssertionError: ``num_classes`` does not match the checkpoint.
    """
    if pretrained:
        settings = pretrained_settings['xception'][pretrained]
        assert num_classes == settings['num_classes'], \
            "num_classes should be {}, but is {}".format(
                settings['num_classes'], num_classes)

        # Build the network once (it used to be constructed twice, the first
        # instance being thrown away). ``inc`` is deliberately not forwarded.
        model = Xception(num_classes=num_classes)
        model.load_state_dict(model_zoo.load_url(settings['url']))

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']
        model.mean = settings['mean']
        model.std = settings['std']
    else:
        model = Xception(num_classes=num_classes, inc=inc)

    # TODO: ugly
    # The checkpoint stores the head as ``fc``; callers expect ``last_linear``.
    model.last_linear = model.fc
    del model.fc
    return model
294
+
295
+
296
class TransferModel(nn.Module):
    """
    Simple transfer learning model that takes an imagenet pretrained model with
    a fc layer as base model and retrains a new fc layer for num_out_classes

    Args:
        modelchoice: ``'xception'``, ``'resnet50'`` or ``'resnet18'``.
        num_out_classes: output size of the new head.
        dropout: dropout probability placed in front of the head; a falsy
            value disables dropout.
        weight_norm: wrap the new linear head in weight normalisation
            (xception only).
        return_fea: if True ``forward`` returns ``(out, features)`` instead of
            ``out`` alone.
        inc: number of input channels (xception only); when not 3 the first
            convolution is replaced by a freshly initialised one.

    Raises:
        Exception: ``modelchoice`` is not one of the supported names.
    """

    def __init__(self, modelchoice, num_out_classes=2, dropout=0.0,
                 weight_norm=False, return_fea=False, inc=3):
        super(TransferModel, self).__init__()
        self.modelchoice = modelchoice
        self.return_fea = return_fea

        if modelchoice == 'xception':
            self.model = self._load_xception()

            # Replace fc
            num_ftrs = self.model.last_linear.in_features
            if dropout:
                print('Using dropout', dropout)
            head = nn.Linear(num_ftrs, num_out_classes)
            if weight_norm:
                print('Using Weight_Norm')
                # Previously this head was immediately overwritten by a plain
                # Linear layer, so ``weight_norm=True`` had no effect.
                head = nn.utils.weight_norm(head, name='weight')
            if dropout:
                head = nn.Sequential(nn.Dropout(p=dropout), head)
            self.model.last_linear = head

            if inc != 3:
                # The pretrained stem expects 3 channels; swap in a new one.
                self.model.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
                nn.init.xavier_normal_(self.model.conv1.weight, gain=0.02)

        elif modelchoice == 'resnet50' or modelchoice == 'resnet18':
            if modelchoice == 'resnet50':
                self.model = torchvision.models.resnet50(pretrained=True)
            else:
                self.model = torchvision.models.resnet18(pretrained=True)
            # Replace fc
            num_ftrs = self.model.fc.in_features
            if not dropout:
                self.model.fc = nn.Linear(num_ftrs, num_out_classes)
            else:
                self.model.fc = nn.Sequential(
                    nn.Dropout(p=dropout),
                    nn.Linear(num_ftrs, num_out_classes)
                )
        else:
            raise Exception('Choose valid model, e.g. resnet50')

    @staticmethod
    def _load_xception(pretrained=True):
        """Build Xception and load the local ImageNet checkpoint into it."""
        # Raises warning "src not broadcastable to dst" but thats fine
        model = xception(pretrained=False)
        if pretrained:
            # Load model in torch 0.4+
            # The checkpoint names the head ``fc``: rename, load, rename back.
            model.fc = model.last_linear
            del model.last_linear
            # map_location keeps loading independent of the saving device.
            state_dict = torch.load(
                PRETAINED_WEIGHT_PATH, map_location='cpu')
            for name, weights in state_dict.items():
                # Pointwise kernels stored as 2-D matrices need two trailing
                # singleton dims to become 1x1 conv kernels; weights that are
                # already 4-D are left untouched.
                if 'pointwise' in name and weights.dim() == 2:
                    state_dict[name] = weights.unsqueeze(
                        -1).unsqueeze(-1)
            model.load_state_dict(state_dict)
            model.last_linear = model.fc
            del model.fc
        return model

    def set_trainable_up_to(self, boolean=False, layername="Conv2d_4a_3x3"):
        """
        Freezes all layers below a specific layer and sets the following layers
        to true if boolean else only the fully connected final layer
        :param boolean: if True, every child module that comes after
            ``layername`` becomes trainable; if False only the final head does
        :param layername: depends on lib, for inception e.g. Conv2d_4a_3x3.
            ``None`` makes every parameter trainable and returns
        :return: None
        :raises NotImplementedError: ``boolean`` is True and no child module
            follows one named ``layername``
        """
        if layername is None:
            for param in self.model.parameters():
                param.requires_grad = True
            return

        # Stage-1: freeze all the layers
        for param in self.model.parameters():
            param.requires_grad = False

        if boolean:
            # Make all layers following the layername layer trainable
            seen = False
            found = False
            for name, child in self.model.named_children():
                if seen:
                    found = True
                    for param in child.parameters():
                        param.requires_grad = True
                if name == layername:
                    seen = True
            if not found:
                raise NotImplementedError(
                    'Layer {} not found, cant finetune!'.format(layername))
        else:
            if self.modelchoice == 'xception':
                head = self.model.last_linear
            else:
                head = self.model.fc
            # Make fc trainable
            for param in head.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return the head output, plus the head's input if ``return_fea``."""
        if self.modelchoice == 'xception':
            # The Xception backbone returns (out, pooled_features).
            out, fea = self.model(x)
            if self.return_fea:
                return out, fea
            return out

        # Other backbones return a single tensor; do not unpack it.
        if not self.return_fea:
            return self.model(x)

        # Capture the tensor entering the ``fc`` head as the features.
        captured = []
        handle = self.model.fc.register_forward_pre_hook(
            lambda _module, args: captured.append(args[0]))
        try:
            out = self.model(x)
        finally:
            handle.remove()
        return out, captured[0]

    def features(self, x):
        """Delegate to the backbone's ``features`` (the Xception backbone
        defined in this file provides it)."""
        x = self.model.features(x)
        return x

    def classifier(self, x):
        """Delegate to the backbone's ``classifier``; returns ``(out, x)``."""
        out, x = self.model.classifier(x)
        return out, x
427
+
428
+
429
def model_selection(modelname, num_out_classes,
                    dropout=None):
    """
    Build a TransferModel by name together with its input metadata.

    :param modelname: 'xception' or 'resnet18'
    :param num_out_classes: output size of the new classification head
    :param dropout: dropout value; only forwarded for 'resnet18'
    :return: 5-tuple (model, image size, True, ['image'], None)
    :raises NotImplementedError: for any other modelname
    """
    if modelname not in ('xception', 'resnet18'):
        raise NotImplementedError(modelname)

    if modelname == 'xception':
        net = TransferModel(modelchoice='xception',
                            num_out_classes=num_out_classes)
        image_size = 299
    else:
        net = TransferModel(modelchoice='resnet18', dropout=dropout,
                            num_out_classes=num_out_classes)
        image_size = 224

    return net, image_size, True, ['image'], None
445
+
446
+
447
if __name__ == '__main__':
    # Smoke test: build the Xception transfer model, print it, and push a
    # random batch through both the one-shot and the two-step code paths.
    net = TransferModel('xception', dropout=0.5)
    print(net)

    batch = torch.rand(10, 3, 256, 256)

    # One-shot forward pass.
    logits = net(batch)
    print(logits.size())

    # Same computation split into feature extraction and classification.
    feature_map = net.features(batch)
    logits, pooled = net.classifier(feature_map)
    print(logits.size())
    print(pooled.size())