# Face-forgery-detection / xception.py
# Last commit: 3a96035 "Update xception.py" (committed by asdasdasdasd)
"""
Code from https://github.com/ondyari/FaceForensics
Author: Andreas Rössler
"""
import os
import argparse
import torch
# import pretrainedmodels
import torch.nn as nn
import torch.nn.functional as F
# from lib.nets.xception import xception
import math
import torchvision
# import math
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init
# Metadata describing each downloadable checkpoint, indexed first by
# architecture and then by the dataset it was pretrained on.
pretrained_settings = {
    'xception': {
        'imagenet': dict(
            url='http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
            input_space='RGB',
            input_size=[3, 299, 299],
            input_range=[0, 1],
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            num_classes=1000,
            # For validation, resize to 333 and then center-crop to 299x299.
            scale=0.8975,
        ),
    },
}
# Local filename of the checkpoint that TransferModel loads with torch.load.
# The misspelled name is kept as-is because other code refers to it.
PRETAINED_WEIGHT_PATH = 'xception-b5690688.pth'
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A depthwise conv (one filter per input channel, ``groups=in_channels``)
    is followed by a 1x1 pointwise conv that mixes channels into
    ``out_channels``.

    The submodule names ``conv1`` and ``pointwise`` determine the state-dict
    keys (TransferModel patches checkpoint entries whose key contains
    ``'pointwise'``), so they must not be renamed.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False):
        super().__init__()
        # Depthwise stage: spatial filtering, channel count unchanged.
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # Pointwise stage: 1x1 kernel, stride 1, no padding, ungrouped.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        return self.pointwise(self.conv1(x))
class Block(nn.Module):
    """One Xception unit: a main path of ReLU -> SeparableConv2d -> BatchNorm2d
    triples (optionally ending in a max-pool) added to a skip path.

    Args:
        in_filters: number of input channels.
        out_filters: number of output channels.
        reps: number of SeparableConv2d/BatchNorm2d pairs on the main path.
        strides: if != 1, a MaxPool2d(3, strides, 1) is appended to the main
            path and the skip path becomes a strided 1x1 conv + BatchNorm2d.
        start_with_relu: if False, the leading ReLU of the main path is dropped.
        grow_first: if True, the in_filters -> out_filters change happens in the
            first separable conv; otherwise it happens in the last one.
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()
        # The skip path needs a projection whenever the main path changes the
        # channel count or the spatial resolution.
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters,
                                  1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None
        # One in-place ReLU instance is reused at several positions of ``rep``
        # (it has no parameters, so sharing it is harmless).
        self.relu = nn.ReLU(inplace=True)
        rep = []
        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters
        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))
        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
        # NOTE: nn.Sequential indices become state-dict keys (``rep.<i>.*``),
        # so changing the order or count of entries here would break loading
        # the pretrained checkpoint.
        if not start_with_relu:
            rep = rep[1:]
        else:
            # The first ReLU receives the block input, which forward() reuses
            # for the skip path, so it must not modify its input in place.
            rep[0] = nn.ReLU(inplace=False)
        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        """Return ``rep(inp) + skip(inp)``; the skip is the identity when no
        projection was built."""
        x = self.rep(inp)
        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp
        x += skip
        return x
def add_gaussian_noise(ins, mean=0, stddev=0.2):
    """Return ``ins`` plus i.i.d. Gaussian noise.

    Args:
        ins: input tensor. The noise has the same shape, dtype and device.
        mean: mean of the noise.
        stddev: standard deviation of the noise.

    Returns:
        A new tensor ``ins + noise``. ``ins`` itself is not modified.
    """
    # torch.empty_like replaces the legacy ``ins.data.new(ins.size())`` idiom.
    # It gives the same dtype/device without going through the discouraged
    # ``.data`` attribute, and the result does not require grad.
    noise = torch.empty_like(ins).normal_(mean, stddev)
    return ins + noise
class Xception(nn.Module):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf

    Layout:
        entry flow:  conv1/conv2 + block1-3
        middle flow: block4-11
        exit flow:   block12 + conv3/conv4

    ``fea_part1`` ... ``fea_part5`` expose these stages so callers can tap
    intermediate feature maps. ``forward`` returns ``(logits, pooled_features)``.

    Submodule names must stay as they are, because they are the state-dict
    keys the pretrained checkpoint is loaded against.
    """

    def __init__(self, num_classes=1000, inc=3):
        """ Constructor
        Args:
            num_classes: number of outputs of the final linear layer
            inc: number of input channels of the first convolution
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        # Entry flow
        self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        # do relu here
        self.block1 = Block(
            64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(
            128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(
            256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # Middle flow: eight identical 728-channel blocks
        self.block4 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block8 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block12 = Block(
            728, 1024, 2, 2, start_with_relu=True, grow_first=False)
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)
        # do relu here
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)
        # The xception() factory renames this layer to ``last_linear``.
        self.fc = nn.Linear(2048, num_classes)

    def fea_part1_0(self, x):
        """conv1 -> bn1 -> relu."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x

    def fea_part1_1(self, x):
        """conv2 -> bn2 -> relu."""
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x

    def fea_part1(self, x):
        """Entry-flow stem: ``fea_part1_0`` followed by ``fea_part1_1``."""
        return self.fea_part1_1(self.fea_part1_0(x))

    def fea_part2(self, x):
        """Entry-flow blocks 1-3."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

    def fea_part3(self, x):
        """First half of the middle flow (blocks 4-7)."""
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        return x

    def fea_part4(self, x):
        """Second half of the middle flow (blocks 8-11)."""
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        return x

    def fea_part5(self, x):
        """Exit flow up to bn4. The final ReLU is applied in ``classifier``."""
        x = self.block12(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.bn4(x)
        return x

    def features(self, input):
        """Run the full convolutional trunk (``fea_part1`` ... ``fea_part5``)."""
        x = self.fea_part1(input)
        x = self.fea_part2(x)
        x = self.fea_part3(x)
        x = self.fea_part4(x)
        x = self.fea_part5(x)
        return x

    def _head(self):
        """Return the final classification layer under its current name.

        The factory functions rename ``fc`` to ``last_linear`` after loading
        weights, and TransferModel replaces ``last_linear``. A bare
        ``Xception(...)`` only has ``fc``.
        """
        head = getattr(self, 'last_linear', None)
        if head is None:
            head = self.fc
        return head

    def classifier(self, features):
        """Final ReLU, global average pooling, then the linear head.

        Returns:
            (logits, pooled features of shape (batch, 2048))
        """
        x = self.relu(features)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        out = self._head()(x)
        return out, x

    def forward(self, input):
        """Return ``(logits, pooled_features)`` for an input batch."""
        x = self.features(input)
        out, x = self.classifier(x)
        return out, x
def xception(num_classes=1000, pretrained='imagenet', inc=3):
    """Build an Xception model, optionally with pretrained weights.

    Args:
        num_classes: size of the final linear layer. When ``pretrained`` is
            set, it must equal the checkpoint's ``num_classes``.
        pretrained: one of
            - a key into ``pretrained_settings['xception']`` (e.g.
              ``'imagenet'``);
            - ``True``, as a shorthand for ``'imagenet'``;
            - a falsy value, for random initialisation.
        inc: number of input channels. The pretrained checkpoint is loaded
            into a 3-channel model. For any other value, ``conv1`` is then
            replaced by a freshly initialised ``inc``-channel convolution, so
            the requested channel count is honoured instead of ignored.

    Returns:
        The model, with its final layer exposed as ``model.last_linear``
        (``model.fc`` is removed). Pretrained models also carry the
        ``input_space``/``input_size``/``input_range``/``mean``/``std``
        metadata.

    Raises:
        ValueError: if ``num_classes`` does not match the checkpoint.
        KeyError: if ``pretrained`` names an unknown checkpoint.
    """
    if pretrained is True:
        pretrained = 'imagenet'
    if pretrained:
        settings = pretrained_settings['xception'][pretrained]
        # An explicit check, unlike ``assert``, is not stripped under ``python -O``.
        if num_classes != settings['num_classes']:
            raise ValueError("num_classes should be {}, but is {}".format(
                settings['num_classes'], num_classes))
        # The checkpoint's conv1 expects 3 input channels.
        model = Xception(num_classes=num_classes, inc=3)
        model.load_state_dict(model_zoo.load_url(settings['url']))
        if inc != 3:
            model.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']
        model.mean = settings['mean']
        model.std = settings['std']
    else:
        model = Xception(num_classes=num_classes, inc=inc)
    # Callers (TransferModel, Xception.classifier) expect the head under this name.
    model.last_linear = model.fc
    del model.fc
    return model
class TransferModel(nn.Module):
    """
    Simple transfer learning model that takes an imagenet pretrained model with
    a fc layer as base model and retrains a new fc layer for num_out_classes
    """

    def __init__(self, modelchoice, num_out_classes=2, dropout=0.0,
                 weight_norm=False, return_fea=False, inc=3):
        """
        Args:
            modelchoice: 'xception', 'resnet50' or 'resnet18'.
            num_out_classes: number of outputs of the new head.
            dropout: if truthy, dropout probability applied before the head.
            weight_norm: wrap the new head's Linear in weight normalisation
                (xception only).
            return_fea: make forward() also return the pooled features
                (xception only).
            inc: number of input channels (xception only). For inc != 3 the
                first conv is replaced by a freshly initialised one.
        Raises:
            Exception: for an unsupported ``modelchoice``.
        """
        super(TransferModel, self).__init__()
        self.modelchoice = modelchoice
        self.return_fea = return_fea
        if modelchoice == 'xception':
            self.model = self._load_pytorch04_xception()
            # Replace fc
            num_ftrs = self.model.last_linear.in_features
            if dropout:
                print('Using dropout', dropout)
            if weight_norm:
                print('Using Weight_Norm')
            # Previously the weight-normed head was overwritten by a plain
            # Linear right after being built, so weight_norm had no effect.
            self.model.last_linear = self._make_head(
                num_ftrs, num_out_classes, dropout, weight_norm)
            if inc != 3:
                self.model.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
                nn.init.xavier_normal_(self.model.conv1.weight, gain=0.02)
        elif modelchoice == 'resnet50' or modelchoice == 'resnet18':
            if modelchoice == 'resnet50':
                self.model = torchvision.models.resnet50(pretrained=True)
            else:
                self.model = torchvision.models.resnet18(pretrained=True)
            # Replace fc
            num_ftrs = self.model.fc.in_features
            self.model.fc = self._make_head(
                num_ftrs, num_out_classes, dropout, weight_norm=False)
        else:
            raise Exception('Choose valid model, e.g. resnet50')

    @staticmethod
    def _load_pytorch04_xception(pretrained=True):
        """Build Xception and, if ``pretrained``, load the local checkpoint at
        ``PRETAINED_WEIGHT_PATH``.

        Checkpoint entries whose key contains ``'pointwise'`` get two trailing
        singleton dims so they fit the 1x1 conv weights. Presumably the file
        stores them in an older, lower-rank layout; confirm against the file.
        """
        # Raises warning "src not broadcastable to dst" but thats fine
        model = xception(pretrained=False)
        if pretrained:
            # Load model in torch 0.4+. The checkpoint's head is keyed "fc.*".
            model.fc = model.last_linear
            del model.last_linear
            # map_location keeps loading working on CPU-only machines.
            state_dict = torch.load(PRETAINED_WEIGHT_PATH, map_location='cpu')
            for name, weights in state_dict.items():
                if 'pointwise' in name:
                    state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
            model.load_state_dict(state_dict)
            model.last_linear = model.fc
            del model.fc
        return model

    @staticmethod
    def _make_head(num_ftrs, num_out_classes, dropout, weight_norm):
        """Return a new classification head.

        The head is a Linear layer, optionally weight-normalised and
        optionally preceded by Dropout(p=dropout).
        """
        linear = nn.Linear(num_ftrs, num_out_classes)
        if weight_norm:
            linear = nn.utils.weight_norm(linear, name='weight')
        if not dropout:
            return linear
        return nn.Sequential(nn.Dropout(p=dropout), linear)

    def set_trainable_up_to(self, boolean=False, layername="Conv2d_4a_3x3"):
        """
        Freezes all layers below a specific layer and sets the following layers
        to true if boolean else only the fully connected final layer
        :param boolean: if True, unfreeze every top-level child that comes
            after ``layername``; otherwise unfreeze only the final fc layer
        :param layername: depends on lib, for inception e.g. Conv2d_4a_3x3.
            If None, every parameter is made trainable.
        :return: None
        :raises NotImplementedError: if ``boolean`` and no child follows
            ``layername``
        """
        if layername is None:
            for param in self.model.parameters():
                param.requires_grad = True
            return
        # Stage-1: freeze all the layers
        for param in self.model.parameters():
            param.requires_grad = False
        if boolean:
            # Make all layers following the layername layer trainable
            past_target = False
            found = False
            for name, child in self.model.named_children():
                if past_target:
                    found = True
                    for params in child.parameters():
                        params.requires_grad = True
                if name == layername:
                    past_target = True
            if not found:
                raise NotImplementedError(
                    'Layer {} not found, cant finetune!'.format(layername))
        else:
            # Make fc trainable
            if self.modelchoice == 'xception':
                head = self.model.last_linear
            else:
                head = self.model.fc
            for param in head.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return logits, or ``(logits, pooled_features)`` if ``return_fea``."""
        if self.modelchoice != 'xception':
            # torchvision ResNets return logits only. Unpacking them as
            # (out, features) would split the batch dimension instead.
            if self.return_fea:
                raise NotImplementedError(
                    'return_fea is only supported for xception')
            return self.model(x)
        out, x = self.model(x)
        if self.return_fea:
            return out, x
        return out

    def features(self, x):
        """Convolutional trunk of the wrapped model (xception only)."""
        x = self.model.features(x)
        return x

    def classifier(self, x):
        """Pool + head of the wrapped model; returns ``(logits, pooled)``
        (xception only)."""
        out, x = self.model.classifier(x)
        return out, x
def model_selection(modelname, num_out_classes,
                    dropout=None):
    """
    :param modelname: 'xception', 'resnet18' or 'resnet50'
    :param num_out_classes: number of outputs of the new classification head
    :param dropout: dropout probability before the head (falsy = no dropout).
        It is now forwarded for xception too; previously it was silently
        ignored there.
    :return: model, image size, pretraining<yes/no>, input_list, None
    :raises NotImplementedError: for an unknown ``modelname``
    """
    if modelname == 'xception':
        return TransferModel(modelchoice='xception', dropout=dropout,
                             num_out_classes=num_out_classes), 299, \
            True, ['image'], None
    elif modelname in ('resnet18', 'resnet50'):
        return TransferModel(modelchoice=modelname, dropout=dropout,
                             num_out_classes=num_out_classes), \
            224, True, ['image'], None
    else:
        raise NotImplementedError(modelname)
if __name__ == '__main__':
    # Smoke test: build the model and print output shapes for a random batch.
    model = TransferModel('xception', dropout=0.5)
    print(model)
    # Inference only: eval() disables dropout and uses BatchNorm running
    # stats; no_grad() avoids building an autograd graph.
    model.eval()
    # 256x256 (not the nominal 299) also works thanks to adaptive pooling.
    dummy = torch.rand(10, 3, 256, 256)
    with torch.no_grad():
        out = model(dummy)
        print(out.size())
        x = model.features(dummy)
        out, x = model.classifier(x)
    print(out.size())
    print(x.size())