# P-DFD — model/network/Recce.py (initial commit 982865f)
from functools import partial
from timm.models import xception
from model.common import SeparableConv2d, Block
from model.common import GuidedAttention, GraphReasoning
import torch
import torch.nn as nn
import torch.nn.functional as F
# Backbone registry: maps an encoder name to its pooled feature width and a
# zero-argument constructor. Only the ImageNet-pretrained Xception is wired up.
encoder_params = {
    "xception": {
        "features": 2048,
        "init_op": partial(xception, pretrained=True),
    },
}
class Recce(nn.Module):
    """End-to-End Reconstruction-Classification Learning for Face Forgery Detection.

    An Xception encoder is split open: after ``block4`` a reconstruction
    decoder rebuilds the input image, intermediate decoder features are fused
    back into the encoder's middle flow via graph reasoning, and a guided
    attention module conditions the late encoder blocks on the reconstruction
    error. Classification happens on the final pooled embedding.

    Args:
        num_classes: output size of the final linear classifier.
        drop_rate: dropout probability used throughout the model.
    """

    def __init__(self, num_classes, drop_rate=0.2):
        super(Recce, self).__init__()
        self.name = "xception"
        # Refilled on every forward pass with the tensors the auxiliary
        # reconstruction / contrastive losses consume.
        self.loss_inputs = dict()
        self.encoder = encoder_params[self.name]["init_op"]()
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(drop_rate)
        self.fc = nn.Linear(encoder_params[self.name]["features"], num_classes)

        self.attention = GuidedAttention(depth=728, drop_rate=drop_rate)
        self.reasoning = GraphReasoning(728, 256, 256, 256, 128, 256, [2, 4], drop_rate)

        # Reconstruction decoder: 728 -> 256 -> 128 -> 64 -> 3 channels,
        # upsampling 2x three times; output squashed to [-1, 1] by Tanh.
        self.decoder1 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            SeparableConv2d(728, 256, 3, 1, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.decoder2 = Block(256, 256, 3, 1)
        self.decoder3 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            SeparableConv2d(256, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.decoder4 = Block(128, 128, 3, 1)
        self.decoder5 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            SeparableConv2d(128, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.decoder6 = nn.Sequential(
            nn.Conv2d(64, 3, 1, 1, bias=False),
            nn.Tanh()
        )

    def norm_n_corr(self, x):
        """Return the L2-normalized pooled embedding of ``x`` and the batch
        self-similarity matrix, with cosine values mapped from [-1, 1] to [0, 1].

        Returns:
            (norm_embed, corr): ``norm_embed`` has shape (B, C, 1, 1);
            ``corr`` has shape (B, B).
        """
        norm_embed = F.normalize(self.global_pool(x), p=2, dim=1)
        # Flatten to (B, C) explicitly; a bare squeeze() would also drop the
        # batch dim when B == 1 and degrade corr to a scalar.
        flat = norm_embed.flatten(1)
        corr = (torch.matmul(flat, flat.T) + 1.) / 2.
        return norm_embed, corr

    @staticmethod
    def add_white_noise(tensor, mean=0., std=1e-6):
        """Add tiny Gaussian noise to a random ~half of the samples in a batch.

        Each sample is gated by an independent Bernoulli(0.5) draw; the result
        is clipped back to the [-1, 1] input range.
        """
        # Per-sample gate, created directly on the input's device.
        gate = torch.rand([tensor.shape[0], 1, 1, 1], device=tensor.device)
        gate = torch.where(gate > 0.5, 1., 0.)
        white_noise = torch.normal(mean, std, size=tensor.shape, device=tensor.device)
        noise_t = tensor + white_noise * gate
        return torch.clip(noise_t, -1., 1.)

    def forward(self, x):
        """Classify ``x``; also populates ``self.loss_inputs`` with the
        reconstruction and per-stage correlation tensors for auxiliary losses.
        """
        # Clear the loss inputs from the previous pass.
        self.loss_inputs = dict(recons=[], contra=[])

        # Light input perturbation during training only.
        noise_x = self.add_white_noise(x) if self.training else x

        # --- Xception stem + entry flow, up to block4 ---
        out = self.encoder.conv1(noise_x)
        out = self.encoder.bn1(out)
        out = self.encoder.act1(out)
        out = self.encoder.conv2(out)
        out = self.encoder.bn2(out)
        out = self.encoder.act2(out)
        out = self.encoder.block1(out)
        out = self.encoder.block2(out)
        out = self.encoder.block3(out)
        embedding = self.encoder.block4(out)

        _, corr = self.norm_n_corr(embedding)
        self.loss_inputs['contra'].append(corr)

        # --- Reconstruction branch ---
        out = self.dropout(embedding)
        out = self.decoder1(out)
        out_d2 = self.decoder2(out)
        _, corr = self.norm_n_corr(out_d2)
        self.loss_inputs['contra'].append(corr)

        out = self.decoder3(out_d2)
        out_d4 = self.decoder4(out)
        _, corr = self.norm_n_corr(out_d4)
        self.loss_inputs['contra'].append(corr)

        out = self.decoder5(out_d4)
        pred = self.decoder6(out)

        # Upsample the reconstruction to the exact input resolution.
        recons_x = F.interpolate(pred, size=x.shape[-2:], mode='bilinear', align_corners=True)
        self.loss_inputs['recons'].append(recons_x)

        # --- Middle flow, fused with decoder features via graph reasoning ---
        embedding = self.encoder.block5(embedding)
        embedding = self.encoder.block6(embedding)
        embedding = self.encoder.block7(embedding)

        fusion = self.reasoning(embedding, out_d2, out_d4) + embedding

        embedding = self.encoder.block8(fusion)
        # Guided attention conditions on the reconstruction error of x.
        img_att = self.attention(x, recons_x, embedding)

        # --- Exit flow + classifier head ---
        embedding = self.encoder.block9(img_att)
        embedding = self.encoder.block10(embedding)
        embedding = self.encoder.block11(embedding)
        embedding = self.encoder.block12(embedding)

        embedding = self.encoder.conv3(embedding)
        embedding = self.encoder.bn3(embedding)
        embedding = self.encoder.act3(embedding)
        embedding = self.encoder.conv4(embedding)
        embedding = self.encoder.bn4(embedding)
        embedding = self.encoder.act4(embedding)

        # flatten(1) keeps the batch dim even for batch size 1; a plain
        # squeeze() would collapse it and feed the linear layer a 1-D tensor.
        embedding = self.global_pool(embedding).flatten(1)

        out = self.dropout(embedding)
        return self.fc(out)