# deepfake_gi_fastGAN / models.py
# Author: vlbthambawita (commit 7f49ac7, "First")
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
import torch.nn.functional as F
import random
# Short alias for nn.Sequential. Nothing in this module (as shown) uses it;
# presumably kept for external importers — confirm before removing.
seq = nn.Sequential
def weights_init(m):
    """Initialise ``m``'s parameters in place (suitable for ``nn.Module.apply``).

    Dispatch is by class name:

    * name contains ``'Conv'``: ``weight`` ~ N(0, 0.02). Modules whose
      ``weight`` attribute is missing or ``None`` are skipped.
    * name contains ``'BatchNorm'``: ``weight`` ~ N(1, 0.02) and ``bias`` = 0.
      Non-affine norms (``weight``/``bias`` are ``None``) are skipped instead
      of raising.
    * any other module is left untouched.

    NOTE(review): for convs built with the legacy ``spectral_norm`` wrapper
    (as this module's ``conv2d``/``convTranspose2d`` do), ``m.weight`` is
    presumably only an alias of ``weight_orig`` before the first forward pass,
    so this should be applied right after construction — confirm.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        # Skip only modules without a weight tensor; any other failure is a
        # real bug and should surface rather than be silently swallowed.
        weight = getattr(m, 'weight', None)
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
def conv2d(*args, **kwargs):
    """Create ``nn.Conv2d(*args, **kwargs)`` wrapped in spectral normalisation."""
    layer = nn.Conv2d(*args, **kwargs)
    return spectral_norm(layer)
def convTranspose2d(*args, **kwargs):
    """Create ``nn.ConvTranspose2d(*args, **kwargs)`` wrapped in spectral normalisation."""
    layer = nn.ConvTranspose2d(*args, **kwargs)
    return spectral_norm(layer)
def batchNorm2d(*args, **kwargs):
    """Create a plain ``nn.BatchNorm2d(*args, **kwargs)``; no spectral normalisation is applied."""
    norm_layer = nn.BatchNorm2d(*args, **kwargs)
    return norm_layer
def linear(*args, **kwargs):
    """Create ``nn.Linear(*args, **kwargs)`` wrapped in spectral normalisation."""
    layer = nn.Linear(*args, **kwargs)
    return spectral_norm(layer)
class PixelNorm(nn.Module):
    """Rescale every position so its values along dim 1 have unit root-mean-square."""

    def forward(self, input):
        # Mean of squares over dim 1; the 1e-8 keeps rsqrt finite for all-zero vectors.
        mean_sq = input.pow(2).mean(dim=1, keepdim=True)
        return input * (mean_sq + 1e-8).rsqrt()
class Reshape(nn.Module):
    """View a tensor as ``(batch, *shape)``, keeping its leading dimension.

    Args:
        shape: target shape for everything after the first dimension
            (stored as ``self.target_shape``).
    """

    def __init__(self, shape):
        super().__init__()
        self.target_shape = shape

    def forward(self, feat):
        return feat.view(feat.shape[0], *self.target_shape)
class GLU(nn.Module):
    """Gated linear unit along dim 1: ``first_half * sigmoid(second_half)``.

    The size of dim 1 must be even; the output has half as many channels.
    """

    def forward(self, x):
        channels = x.size(1)
        assert channels % 2 == 0, 'channels dont divide 2!'
        half = channels // 2
        value, gate = x[:, :half], x[:, half:]
        return value * gate.sigmoid()
class NoiseInjection(nn.Module):
    """Add Gaussian noise scaled by a single learnable scalar.

    ``self.weight`` starts at zero, so the layer is initially an identity.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, feat, noise=None):
        """Return ``feat + self.weight * noise``.

        Args:
            feat: 4-D tensor unpacked as ``(batch, channels, height, width)``.
            noise: optional tensor broadcastable to ``feat``. When omitted, a
                ``(batch, 1, height, width)`` standard-normal tensor is drawn,
                so every channel receives the same noise map.
        """
        if noise is None:
            batch, _, height, width = feat.shape
            # Sample directly on feat's device and in feat's dtype. This avoids
            # a host->device copy on every call and keeps half-precision
            # models from being promoted to float32 by the addition.
            noise = torch.randn(batch, 1, height, width,
                                device=feat.device, dtype=feat.dtype)
        return feat + self.weight * noise
class Swish(nn.Module):
    """Swish activation: multiply the input by its own sigmoid."""

    def forward(self, feat):
        gate = feat.sigmoid()
        return feat * gate
class SEBlock(nn.Module):
    """Gate ``feat_big``'s channels with weights computed from ``feat_small``.

    ``feat_small`` (``ch_in`` channels, any spatial size) is average-pooled to
    4x4, reduced to 1x1 by a 4x4 conv, then passed through Swish, a 1x1 conv
    and a sigmoid. The result holds one weight in (0, 1) per channel
    (``ch_out`` of them), and it multiplies ``feat_big``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Same layer order as before, so state_dict keys (main.1, main.3, ...) are unchanged.
        layers = [
            nn.AdaptiveAvgPool2d(4),
            conv2d(ch_in, ch_out, 4, 1, 0, bias=False),  # 4x4 -> 1x1
            Swish(),
            conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, feat_small, feat_big):
        gate = self.main(feat_small)
        return feat_big * gate
class InitLayer(nn.Module):
    """Turn a latent batch into a 4x4 feature map with ``channel`` channels.

    The input is reshaped to ``(batch, -1, 1, 1)``. A 4x4 transposed conv
    (stride 1, no padding) expands it to 4x4 with ``2 * channel`` channels,
    which batch norm + GLU reduce back to ``channel``.
    """

    def __init__(self, nz, channel):
        super().__init__()
        doubled = channel * 2
        self.init = nn.Sequential(
            convTranspose2d(nz, doubled, 4, 1, 0, bias=False),
            batchNorm2d(doubled),
            GLU(),
        )

    def forward(self, noise):
        batch_size = noise.shape[0]
        as_image = noise.view(batch_size, -1, 1, 1)
        return self.init(as_image)
def UpBlock(in_planes, out_planes):
    """Return a block that doubles spatial size and maps ``in_planes`` -> ``out_planes``.

    The block runs nearest-neighbour x2 upsampling, then a 3x3 spectral-norm
    conv to ``2 * out_planes`` channels, then batch norm, then a GLU that
    halves the channels.
    """
    doubled = out_planes * 2
    layers = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, doubled, 3, 1, 1, bias=False),
        batchNorm2d(doubled),
        GLU(),
    ]
    return nn.Sequential(*layers)
def UpBlockComp(in_planes, out_planes):
    """Return a heavier UpBlock variant with two conv stages and noise injection.

    Layout: x2 nearest upsample, then two repetitions of
    (3x3 conv to ``2 * out_planes`` -> NoiseInjection -> batch norm -> GLU).
    The first stage reads ``in_planes`` channels and the second reads
    ``out_planes``. The output has ``out_planes`` channels.
    """
    doubled = out_planes * 2
    layers = [nn.Upsample(scale_factor=2, mode='nearest')]
    # Layers are appended in the original order, so Sequential indices
    # (and state_dict keys) are unchanged.
    for source_planes in (in_planes, out_planes):
        layers += [
            conv2d(source_planes, doubled, 3, 1, 1, bias=False),
            NoiseInjection(),
            batchNorm2d(doubled),
            GLU(),
        ]
    return nn.Sequential(*layers)
class Generator(nn.Module):
    """FastGAN-style generator producing a full-size and a 128x128 image.

    Args:
        ngf: base channel width; per-resolution widths are ``int(mult * ngf)``.
        nz: latent size fed to the initial transposed conv.
        nc: number of output image channels.
        im_size: full output resolution; must be 256, 512 or 1024.

    ``forward`` returns ``[big, small]``: ``big`` is ``im_size`` x ``im_size``
    and ``small`` comes from the 128x128 feature map.

    Raises:
        ValueError: if ``im_size`` is not supported.

    NOTE(review): only the 1024 path applies ``tanh`` to its outputs. The 256
    and 512 paths return raw conv outputs. This is kept as-is because existing
    checkpoints and training code may depend on it — confirm it is intended.
    """

    _SUPPORTED_SIZES = (256, 512, 1024)

    def __init__(self, ngf=64, nz=100, nc=3, im_size=1024):
        super(Generator, self).__init__()
        if im_size not in self._SUPPORTED_SIZES:
            # Fail fast: other sizes either have no width entry below or no
            # matching branch in forward().
            raise ValueError(
                f"im_size must be one of {self._SUPPORTED_SIZES}, got {im_size!r}")

        nfc_multi = {4: 16, 8: 8, 16: 4, 32: 2, 64: 2, 128: 1,
                     256: 0.5, 512: 0.25, 1024: 0.125}
        nfc = {k: int(v * ngf) for k, v in nfc_multi.items()}

        self.im_size = im_size

        # Upsampling trunk 4x4 -> 256x256, alternating UpBlockComp / UpBlock.
        self.init = InitLayer(nz, channel=nfc[4])
        self.feat_8 = UpBlockComp(nfc[4], nfc[8])
        self.feat_16 = UpBlock(nfc[8], nfc[16])
        self.feat_32 = UpBlockComp(nfc[16], nfc[32])
        self.feat_64 = UpBlock(nfc[32], nfc[64])
        self.feat_128 = UpBlockComp(nfc[64], nfc[128])
        self.feat_256 = UpBlock(nfc[128], nfc[256])

        # Low-resolution features gate higher-resolution ones (see forward).
        self.se_64 = SEBlock(nfc[4], nfc[64])
        self.se_128 = SEBlock(nfc[8], nfc[128])
        self.se_256 = SEBlock(nfc[16], nfc[256])

        # Output heads: 1x1 conv for the 128x128 image, 3x3 for the full size.
        self.to_128 = conv2d(nfc[128], nc, 1, 1, 0, bias=False)
        self.to_big = conv2d(nfc[im_size], nc, 3, 1, 1, bias=False)

        if im_size > 256:
            self.feat_512 = UpBlockComp(nfc[256], nfc[512])
            self.se_512 = SEBlock(nfc[32], nfc[512])
        if im_size > 512:
            self.feat_1024 = UpBlock(nfc[512], nfc[1024])

    def forward(self, input):
        """Generate ``[big_image, image_128]`` from a latent batch ``input``."""
        feat_4 = self.init(input)
        feat_8 = self.feat_8(feat_4)
        feat_16 = self.feat_16(feat_8)
        feat_32 = self.feat_32(feat_16)

        feat_64 = self.se_64(feat_4, self.feat_64(feat_32))
        feat_128 = self.se_128(feat_8, self.feat_128(feat_64))
        feat_256 = self.se_256(feat_16, self.feat_256(feat_128))

        if self.im_size == 256:
            return [self.to_big(feat_256), self.to_128(feat_128)]

        feat_512 = self.se_512(feat_32, self.feat_512(feat_256))
        if self.im_size == 512:
            return [self.to_big(feat_512), self.to_128(feat_128)]

        feat_1024 = self.feat_1024(feat_512)
        im_128 = torch.tanh(self.to_128(feat_128))
        im_1024 = torch.tanh(self.to_big(feat_1024))
        return [im_1024, im_128]
class DownBlock(nn.Module):
    """Halve spatial size: 4x4 stride-2 spectral-norm conv, batch norm, LeakyReLU(0.2)."""

    def __init__(self, in_planes, out_planes):
        super().__init__()
        layers = [
            conv2d(in_planes, out_planes, 4, 2, 1, bias=False),
            batchNorm2d(out_planes),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, feat):
        downsampled = self.main(feat)
        return downsampled
class DownBlockComp(nn.Module):
    """Two-path downsampling block whose output is the mean of both paths.

    * ``main``: 4x4 stride-2 conv -> BN -> LeakyReLU -> 3x3 conv -> BN -> LeakyReLU
    * ``direct``: 2x2 average pool -> 1x1 conv -> BN -> LeakyReLU

    Both paths halve the spatial size and map ``in_planes`` -> ``out_planes``.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        main_layers = [
            conv2d(in_planes, out_planes, 4, 2, 1, bias=False),
            batchNorm2d(out_planes),
            nn.LeakyReLU(0.2, inplace=True),
            conv2d(out_planes, out_planes, 3, 1, 1, bias=False),
            batchNorm2d(out_planes),
            nn.LeakyReLU(0.2),
        ]
        self.main = nn.Sequential(*main_layers)
        shortcut_layers = [
            nn.AvgPool2d(2, 2),
            conv2d(in_planes, out_planes, 1, 1, 0, bias=False),
            batchNorm2d(out_planes),
            nn.LeakyReLU(0.2),
        ]
        self.direct = nn.Sequential(*shortcut_layers)

    def forward(self, feat):
        through_main = self.main(feat)
        through_direct = self.direct(feat)
        return (through_main + through_direct) / 2
class Discriminator(nn.Module):
    """Two-branch discriminator with auxiliary reconstruction decoders.

    Args:
        ndf: base channel width; per-resolution widths are ``int(mult * ndf)``.
        nc: number of image channels.
        im_size: resolution of the "big" branch input; must be 256, 512 or 1024.

    The big branch brings its input to a 256x256 map and then runs five
    DownBlockComp stages. The small branch takes a 128x128 input.

    Raises:
        ValueError: if ``im_size`` is not supported.
    """

    _SUPPORTED_SIZES = (256, 512, 1024)

    def __init__(self, ndf=64, nc=3, im_size=512):
        super(Discriminator, self).__init__()
        if im_size not in self._SUPPORTED_SIZES:
            # Otherwise down_from_big is never defined and forward() fails later.
            raise ValueError(
                f"im_size must be one of {self._SUPPORTED_SIZES}, got {im_size!r}")

        self.ndf = ndf
        self.im_size = im_size

        nfc_multi = {4: 16, 8: 16, 16: 8, 32: 4, 64: 2, 128: 1,
                     256: 0.5, 512: 0.25, 1024: 0.125}
        nfc = {k: int(v * ndf) for k, v in nfc_multi.items()}

        # Stem: bring the big input to 256x256 with nfc[512] channels.
        if im_size == 1024:
            self.down_from_big = nn.Sequential(
                conv2d(nc, nfc[1024], 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                conv2d(nfc[1024], nfc[512], 4, 2, 1, bias=False),
                batchNorm2d(nfc[512]),
                nn.LeakyReLU(0.2, inplace=True))
        elif im_size == 512:
            self.down_from_big = nn.Sequential(
                conv2d(nc, nfc[512], 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True))
        else:  # 256: keep the resolution, only change the channel count
            self.down_from_big = nn.Sequential(
                conv2d(nc, nfc[512], 3, 1, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True))

        # Spatial sizes: 256 -> 128 -> 64 -> 32 -> 16 -> 8.
        self.down_4 = DownBlockComp(nfc[512], nfc[256])
        self.down_8 = DownBlockComp(nfc[256], nfc[128])
        self.down_16 = DownBlockComp(nfc[128], nfc[64])
        self.down_32 = DownBlockComp(nfc[64], nfc[32])
        self.down_64 = DownBlockComp(nfc[32], nfc[16])

        # Logit head for the big branch.
        self.rf_big = nn.Sequential(
            conv2d(nfc[16], nfc[8], 1, 1, 0, bias=False),
            batchNorm2d(nfc[8]), nn.LeakyReLU(0.2, inplace=True),
            conv2d(nfc[8], 1, 4, 1, 0, bias=False))

        # Earlier (higher-resolution) features gate later ones.
        self.se_2_16 = SEBlock(nfc[512], nfc[64])
        self.se_4_32 = SEBlock(nfc[256], nfc[32])
        self.se_8_64 = SEBlock(nfc[128], nfc[16])

        # Small branch: 128x128 -> 8x8 with nfc[32] channels.
        self.down_from_small = nn.Sequential(
            conv2d(nc, nfc[256], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            DownBlock(nfc[256], nfc[128]),
            DownBlock(nfc[128], nfc[64]),
            DownBlock(nfc[64], nfc[32]), )
        self.rf_small = conv2d(nfc[32], 1, 4, 1, 0, bias=False)

        self.decoder_big = SimpleDecoder(nfc[16], nc)
        self.decoder_part = SimpleDecoder(nfc[32], nc)
        self.decoder_small = SimpleDecoder(nfc[32], nc)

    @staticmethod
    def _crop_quadrant(feat, part):
        """Return quadrant ``part`` of ``feat`` (0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right)."""
        half_h = feat.shape[2] // 2
        half_w = feat.shape[3] // 2
        rows = slice(None, half_h) if part in (0, 1) else slice(half_h, None)
        cols = slice(None, half_w) if part in (0, 2) else slice(half_w, None)
        return feat[:, :, rows, cols]

    def forward(self, imgs, label, part=None):
        """Return discriminator logits, plus reconstructions for real images.

        Args:
            imgs: either a single image batch, which is resized with
                ``F.interpolate`` (default mode) to ``im_size`` and to 128, or
                a list/tuple ``[big, small]`` of batches already at those sizes.
            label: when equal to ``'real'``, reconstructions are also returned.
            part: required when ``label == 'real'``. It selects which quadrant
                of the 16x16 ``feat_32`` map ``decoder_part`` reconstructs
                (0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right).

        Returns:
            A 1-D ``logits`` tensor (big-branch logits followed by small-branch
            logits). For real images, returns ``(logits, [rec_big, rec_small, rec_part])``.

        Raises:
            ValueError: if ``label == 'real'`` and ``part`` is not 0, 1, 2 or 3.
        """
        if not isinstance(imgs, (list, tuple)):
            imgs = [F.interpolate(imgs, size=self.im_size),
                    F.interpolate(imgs, size=128)]

        # Big branch with cross-scale gating.
        feat_2 = self.down_from_big(imgs[0])
        feat_4 = self.down_4(feat_2)
        feat_8 = self.down_8(feat_4)
        feat_16 = self.se_2_16(feat_2, self.down_16(feat_8))
        feat_32 = self.se_4_32(feat_4, self.down_32(feat_16))
        feat_last = self.se_8_64(feat_8, self.down_64(feat_32))
        rf_0 = self.rf_big(feat_last).view(-1)

        # Small branch.
        feat_small = self.down_from_small(imgs[1])
        rf_1 = self.rf_small(feat_small).view(-1)

        logits = torch.cat([rf_0, rf_1])
        if label != 'real':
            return logits

        if part not in (0, 1, 2, 3):
            # A real exception (not ``assert``) so the check survives ``python -O``
            # and an out-of-range part cannot silently yield ``None``.
            raise ValueError(
                f"part must be 0, 1, 2 or 3 when label == 'real', got {part!r}")

        rec_img_big = self.decoder_big(feat_last)
        rec_img_small = self.decoder_small(feat_small)
        rec_img_part = self.decoder_part(self._crop_quadrant(feat_32, part))
        return logits, [rec_img_big, rec_img_small, rec_img_part]
class SimpleDecoder(nn.Module):
    """Small decoder producing a 128x128 image from a feature map.

    The input (``nfc_in`` channels, any spatial size) is average-pooled to
    8x8. It is then upsampled four times (8 -> 16 -> 32 -> 64 -> 128), each
    time via nearest x2 -> 3x3 conv -> batch norm -> GLU. Finally a 3x3 conv
    and ``tanh`` project it to ``nc`` channels.
    """

    def __init__(self, nfc_in=64, nc=3):
        super().__init__()
        # Channel widths per resolution, from a fixed base of 32.
        widths = {res: int(mult * 32)
                  for res, mult in ((16, 4), (32, 2), (64, 2), (128, 1))}

        def up_stage(src, dst):
            # nearest x2 -> conv to 2*dst -> BN -> GLU (halves back to dst)
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                conv2d(src, dst * 2, 3, 1, 1, bias=False),
                batchNorm2d(dst * 2),
                GLU(),
            )

        channel_plan = [nfc_in, widths[16], widths[32], widths[64], widths[128]]
        stages = [up_stage(a, b) for a, b in zip(channel_plan, channel_plan[1:])]
        # Index layout (main.0 .. main.6) matches the previous definition.
        self.main = nn.Sequential(
            nn.AdaptiveAvgPool2d(8),
            *stages,
            conv2d(widths[128], nc, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        return self.main(input)
from random import randint
def random_crop(image, size):
    """Return a random ``size`` x ``size`` crop of the last two dims of ``image``.

    ``image`` must be 4-D; its last two dims (height, width) are cropped. The
    same window is used for the whole batch. Every valid offset, including
    the last one, can be chosen.

    Raises:
        ValueError: if ``size`` exceeds the image height or width.
    """
    h, w = image.shape[2:]
    if size > h or size > w:
        raise ValueError(f"crop size {size} exceeds image size {h}x{w}")
    # random.randint is inclusive on both ends, so h - size is the last valid
    # top offset. The previous bound (h - size - 1) skipped it and failed
    # when size == h.
    ch = randint(0, h - size)
    cw = randint(0, w - size)
    return image[:, :, ch:ch + size, cw:cw + size]
class TextureDiscriminator(nn.Module):
    """Discriminator that scores a random 128x128 crop of the input.

    Args:
        ndf: base channel width; per-resolution widths are ``int(mult * ndf)``.
        nc: number of image channels.
        im_size: stored as ``self.im_size``; not otherwise used by this class.
    """

    def __init__(self, ndf=64, nc=3, im_size=512):
        super(TextureDiscriminator, self).__init__()
        self.ndf = ndf
        self.im_size = im_size
        nfc_multi = {4: 16, 8: 8, 16: 8, 32: 4, 64: 2, 128: 1,
                     256: 0.5, 512: 0.25, 1024: 0.125}
        nfc = {k: int(v * ndf) for k, v in nfc_multi.items()}

        # 128x128 crop -> four stride-2 downsamples -> 8x8 map with nfc[32] channels.
        self.down_from_small = nn.Sequential(
            conv2d(nc, nfc[256], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            DownBlock(nfc[256], nfc[128]),
            DownBlock(nfc[128], nfc[64]),
            DownBlock(nfc[64], nfc[32]), )
        # The logit head must consume the nfc[32] channels produced above. An
        # nfc[16]-input head (twice as many channels) fails in forward() with a
        # channel-mismatch error.
        self.rf_small = nn.Sequential(
            conv2d(nfc[32], 1, 4, 1, 0, bias=False))
        self.decoder_small = SimpleDecoder(nfc[32], nc)

    def forward(self, img, label):
        """Score a random 128x128 crop of ``img``.

        Returns:
            Flattened logits. When ``label == 'real'``, returns the tuple
            ``(logits, reconstruction_of_crop, crop)``.
        """
        img = random_crop(img, size=128)
        feat_small = self.down_from_small(img)
        rf = self.rf_small(feat_small).view(-1)
        if label == 'real':
            rec_img_small = self.decoder_small(feat_small)
            return rf, rec_img_small, img
        return rf