""" Code from https://github.com/ondyari/FaceForensics Author: Andreas Rössler """ import os import argparse import torch # import pretrainedmodels import torch.nn as nn import torch.nn.functional as F # from lib.nets.xception import xception import math import torchvision # import math # import torch # import torch.nn as nn # import torch.nn.functional as F import torch.utils.model_zoo as model_zoo from torch.nn import init pretrained_settings = { 'xception': { 'imagenet': { 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth', 'input_space': 'RGB', 'input_size': [3, 299, 299], 'input_range': [0, 1], 'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'num_classes': 1000, 'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 } } } PRETAINED_WEIGHT_PATH = 'xception-b5690688.pth' class SeparableConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): super(SeparableConv2d, self).__init__() self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias) self.pointwise = nn.Conv2d( in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) def forward(self, x): x = self.conv1(x) x = self.pointwise(x) return x class Block(nn.Module): def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): super(Block, self).__init__() if out_filters != in_filters or strides != 1: self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False) self.skipbn = nn.BatchNorm2d(out_filters) else: self.skip = None self.relu = nn.ReLU(inplace=True) rep = [] filters = in_filters if grow_first: rep.append(self.relu) rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) rep.append(nn.BatchNorm2d(out_filters)) filters = out_filters for i in range(reps-1): rep.append(self.relu) rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False)) rep.append(nn.BatchNorm2d(filters)) if not grow_first: rep.append(self.relu) rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) rep.append(nn.BatchNorm2d(out_filters)) if not start_with_relu: rep = rep[1:] else: rep[0] = nn.ReLU(inplace=False) if strides != 1: rep.append(nn.MaxPool2d(3, strides, 1)) self.rep = nn.Sequential(*rep) def forward(self, inp): x = self.rep(inp) if self.skip is not None: skip = self.skip(inp) skip = self.skipbn(skip) else: skip = inp x += skip return x def add_gaussian_noise(ins, mean=0, stddev=0.2): noise = ins.data.new(ins.size()).normal_(mean, stddev) return ins + noise class Xception(nn.Module): """ Xception optimized for the ImageNet dataset, as specified in https://arxiv.org/pdf/1610.02357.pdf """ def __init__(self, num_classes=1000, inc=3): """ Constructor Args: num_classes: number of classes """ super(Xception, self).__init__() self.num_classes = num_classes # Entry flow self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False) self.bn1 = nn.BatchNorm2d(32) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(32, 64, 3, bias=False) self.bn2 = nn.BatchNorm2d(64) # do relu here self.block1 = Block( 64, 128, 2, 2, start_with_relu=False, grow_first=True) self.block2 = Block( 128, 256, 2, 2, start_with_relu=True, grow_first=True) self.block3 = Block( 256, 728, 2, 2, start_with_relu=True, grow_first=True) # middle flow self.block4 = Block( 728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block5 = Block( 728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block6 = Block( 728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block7 = Block( 728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block8 = Block( 728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block9 = Block( 728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block10 = Block( 728, 728, 3, 1, start_with_relu=True, grow_first=True) self.block11 = Block( 728, 728, 3, 1, start_with_relu=True, grow_first=True) # Exit flow self.block12 = Block( 728, 1024, 2, 2, start_with_relu=True, grow_first=False) self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) self.bn3 = nn.BatchNorm2d(1536) # do relu here self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) self.bn4 = nn.BatchNorm2d(2048) self.fc = nn.Linear(2048, num_classes) # #------- init weights -------- # for m in self.modules(): # if isinstance(m, nn.Conv2d): # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # m.weight.data.normal_(0, math.sqrt(2. / n)) # elif isinstance(m, nn.BatchNorm2d): # m.weight.data.fill_(1) # m.bias.data.zero_() # #----------------------------- def fea_part1_0(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) return x def fea_part1_1(self, x): x = self.conv2(x) x = self.bn2(x) x = self.relu(x) return x def fea_part1(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) return x def fea_part2(self, x): x = self.block1(x) x = self.block2(x) x = self.block3(x) return x def fea_part3(self, x): x = self.block4(x) x = self.block5(x) x = self.block6(x) x = self.block7(x) return x def fea_part4(self, x): x = self.block8(x) x = self.block9(x) x = self.block10(x) x = self.block11(x) return x def fea_part5(self, x): x = self.block12(x) x = self.conv3(x) x = self.bn3(x) x = self.relu(x) x = self.conv4(x) x = self.bn4(x) return x def features(self, input): x = self.fea_part1(input) x = self.fea_part2(x) x = self.fea_part3(x) x = self.fea_part4(x) x = self.fea_part5(x) return x def classifier(self, features): x = self.relu(features) x = F.adaptive_avg_pool2d(x, (1, 1)) x = x.view(x.size(0), -1) out = self.last_linear(x) return out, x def forward(self, input): x = self.features(input) out, x = self.classifier(x) return out, x def xception(num_classes=1000, pretrained='imagenet', inc=3): model = Xception(num_classes=num_classes, inc=inc) if pretrained: settings = pretrained_settings['xception'][pretrained] assert num_classes == settings['num_classes'], \ "num_classes should be {}, but is {}".format( settings['num_classes'], num_classes) model = Xception(num_classes=num_classes) model.load_state_dict(model_zoo.load_url(settings['url'])) model.input_space = settings['input_space'] model.input_size = settings['input_size'] model.input_range = settings['input_range'] model.mean = settings['mean'] model.std = settings['std'] # TODO: ugly model.last_linear = model.fc del model.fc return model class TransferModel(nn.Module): """ Simple transfer learning model that takes an imagenet pretrained model with a fc layer as base model and retrains a new fc layer for num_out_classes """ def __init__(self, modelchoice, num_out_classes=2, dropout=0.0, weight_norm=False, return_fea=False, inc=3): super(TransferModel, self).__init__() self.modelchoice = modelchoice self.return_fea = return_fea if modelchoice == 'xception': def return_pytorch04_xception(pretrained=True): # Raises warning "src not broadcastable to dst" but thats fine model = xception(pretrained=False) if pretrained: # Load model in torch 0.4+ model.fc = model.last_linear del model.last_linear state_dict = torch.load( PRETAINED_WEIGHT_PATH) for name, weights in state_dict.items(): if 'pointwise' in name: state_dict[name] = weights.unsqueeze( -1).unsqueeze(-1) model.load_state_dict(state_dict) model.last_linear = model.fc del model.fc return model self.model = return_pytorch04_xception() # Replace fc num_ftrs = self.model.last_linear.in_features if not dropout: if weight_norm: print('Using Weight_Norm') self.model.last_linear = nn.utils.weight_norm( nn.Linear(num_ftrs, num_out_classes), name='weight') self.model.last_linear = nn.Linear(num_ftrs, num_out_classes) else: print('Using dropout', dropout) if weight_norm: print('Using Weight_Norm') self.model.last_linear = nn.Sequential( nn.Dropout(p=dropout), nn.utils.weight_norm( nn.Linear(num_ftrs, num_out_classes), name='weight') ) self.model.last_linear = nn.Sequential( nn.Dropout(p=dropout), nn.Linear(num_ftrs, num_out_classes) ) if inc != 3: self.model.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False) nn.init.xavier_normal(self.model.conv1.weight.data, gain=0.02) elif modelchoice == 'resnet50' or modelchoice == 'resnet18': if modelchoice == 'resnet50': self.model = torchvision.models.resnet50(pretrained=True) if modelchoice == 'resnet18': self.model = torchvision.models.resnet18(pretrained=True) # Replace fc num_ftrs = self.model.fc.in_features if not dropout: self.model.fc = nn.Linear(num_ftrs, num_out_classes) else: self.model.fc = nn.Sequential( nn.Dropout(p=dropout), nn.Linear(num_ftrs, num_out_classes) ) else: raise Exception('Choose valid model, e.g. resnet50') def set_trainable_up_to(self, boolean=False, layername="Conv2d_4a_3x3"): """ Freezes all layers below a specific layer and sets the following layers to true if boolean else only the fully connected final layer :param boolean: :param layername: depends on lib, for inception e.g. Conv2d_4a_3x3 :return: """ # Stage-1: freeze all the layers if layername is None: for i, param in self.model.named_parameters(): param.requires_grad = True return else: for i, param in self.model.named_parameters(): param.requires_grad = False if boolean: # Make all layers following the layername layer trainable ct = [] found = False for name, child in self.model.named_children(): if layername in ct: found = True for params in child.parameters(): params.requires_grad = True ct.append(name) if not found: raise NotImplementedError('Layer not found, cant finetune!'.format( layername)) else: if self.modelchoice == 'xception': # Make fc trainable for param in self.model.last_linear.parameters(): param.requires_grad = True else: # Make fc trainable for param in self.model.fc.parameters(): param.requires_grad = True def forward(self, x): out, x = self.model(x) if self.return_fea: return out, x else: return out def features(self, x): x = self.model.features(x) return x def classifier(self, x): out, x = self.model.classifier(x) return out, x def model_selection(modelname, num_out_classes, dropout=None): """ :param modelname: :return: model, image size, pretraining, input_list """ if modelname == 'xception': return TransferModel(modelchoice='xception', num_out_classes=num_out_classes), 299, \ True, ['image'], None elif modelname == 'resnet18': return TransferModel(modelchoice='resnet18', dropout=dropout, num_out_classes=num_out_classes), \ 224, True, ['image'], None else: raise NotImplementedError(modelname) if __name__ == '__main__': model = TransferModel('xception', dropout=0.5) print(model) # model = model.cuda() # from torchsummary import summary # input_s = (3, image_size, image_size) # print(summary(model, input_s)) dummy = torch.rand(10, 3, 256, 256) out = model(dummy) print(out.size()) x = model.features(dummy) out, x = model.classifier(x) print(out.size()) print(x.size())