|
import torch
|
|
import math
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
import yaml
|
|
from PIL import Image
|
|
import cv2
|
|
from torchvision import transforms as T
|
|
from skimage import measure
|
|
from skimage.transform import PiecewiseAffineTransform, warp
|
|
from torch.autograd import Variable
|
|
from scipy.ndimage import binary_erosion, binary_dilation
|
|
|
|
from dataset.pair_dataset import pairDataset
|
|
from dataset.utils.color_transfer import color_transfer
|
|
from dataset.utils.faceswap_utils_sladd import blendImages as alpha_blend_fea
|
|
from dataset.utils import faceswap
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """Xception-style residual block built from separable convolutions.

    Stacks ``reps`` SeparableConv2d/BatchNorm stages with ReLU activations
    and adds a shortcut connection: a 1x1 projection whenever the channel
    count or stride changes, the identity otherwise.
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        # Projection shortcut only when the output shape changes.
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)

        layers = []
        filters = in_filters
        if grow_first:
            # Widen the channel count in the first stage.
            layers += [
                self.relu,
                SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_filters),
            ]
            filters = out_filters

        for _ in range(reps - 1):
            layers += [
                self.relu,
                SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(filters),
            ]

        if not grow_first:
            # Widen the channel count in the last stage instead.
            layers += [
                self.relu,
                SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_filters),
            ]

        if not start_with_relu:
            layers = layers[1:]
        else:
            # The leading activation must not run in place: it would otherwise
            # mutate the block input that the shortcut path still needs.
            layers[0] = nn.ReLU(inplace=False)

        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        out = self.rep(inp)

        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))

        out += shortcut
        return out
|
|
|
|
class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution.

    A per-channel spatial convolution (``groups=c_in`` makes it depthwise)
    followed by a 1x1 pointwise convolution that mixes channels.
    """

    def __init__(self, c_in, c_out, ks, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        self.c = nn.Conv2d(c_in, c_in, ks, stride, padding, dilation, groups=c_in, bias=bias)
        self.pointwise = nn.Conv2d(c_in, c_out, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.c(x))
|
|
|
|
class Xception_SLADDSyn(nn.Module):
    """
    Xception backbone (https://arxiv.org/pdf/1610.02357.pdf) adapted as the
    SLADD synthesizer network: after global average pooling it predicts a
    forgery region index, a blending type and a blending magnitude instead
    of a single classification head.
    """

    def __init__(self, num_classes=2, num_region=7, num_type=2, num_mag=1, inc=6):
        """ Constructor
        Args:
            num_classes: number of classes (unused here; kept for API compatibility)
            num_region: output size of the region-prediction head
            num_type: output size of the blending-type head
            num_mag: output size of the magnitude head
            inc: number of input channels (6 = a real/fake RGB pair stacked)
        """
        super(Xception_SLADDSyn, self).__init__()
        self.num_region = num_region
        self.num_type = num_type
        self.num_mag = num_mag
        dropout = 0.5

        # Entry-flow stem; accepts `inc` channels instead of standard RGB.
        self.iniconv = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)

        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Entry flow: three stride-2 residual blocks.
        self.block1 = Block(
            64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(
            128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(
            256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # Middle flow: eight identity-shaped blocks at 728 channels.
        self.block4 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow: final down-sampling block.
        self.block12 = Block(
            728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)
        # Three prediction heads on the pooled 2048-d feature.
        self.fc_region = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(2048, num_region))
        self.fc_type = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(2048, num_type))
        self.fc_mag = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(2048, num_mag))

    def fea_part1_0(self, x):
        """First half of the stem: iniconv -> BN -> ReLU."""
        x = self.iniconv(x)
        x = self.bn1(x)
        x = self.relu(x)

        return x

    def fea_part1_1(self, x):
        """Second half of the stem: conv2 -> BN -> ReLU."""
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part1(self, x):
        """Full stem (equivalent to fea_part1_0 followed by fea_part1_1)."""
        x = self.iniconv(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x

    def fea_part2(self, x):
        """Entry flow (blocks 1-3)."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        return x

    def fea_part3(self, x):
        """First half of the middle flow (blocks 4-7)."""
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)

        return x

    def fea_part4(self, x):
        """Second half of the middle flow plus exit block (blocks 8-12)."""
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        return x

    def fea_part5(self, x):
        """Final separable convolutions (note: no ReLU after bn4 here)."""
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)

        return x

    def features(self, input):
        """Run the whole convolutional trunk and return the 2048-channel map."""
        x = self.fea_part1(input)

        x = self.fea_part2(x)
        x = self.fea_part3(x)
        x = self.fea_part4(x)

        x = self.fea_part5(x)
        return x

    def classifier(self, features):
        """Pool features and apply a final linear layer.

        NOTE(review): `self.last_linear` is never defined in this class, so
        calling this method as-is would raise AttributeError — presumably the
        attribute is attached externally; confirm against callers.
        """
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        out = self.last_linear(x)
        return out, x

    def forward(self, input):
        """Return (region_logits, type_logits, magnitude) for a batch."""
        x = self.features(input)
        x = self.relu(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)

        region_num = self.fc_region(x)
        type_num = self.fc_type(x)
        mag = self.fc_mag(x)

        return region_num, type_num, mag
|
|
|
|
|
|
def mask_postprocess(mask):
    """Randomly roughen a binary mask to diversify its boundary.

    With probability ~0.3 the mask is eroded, ~0.3 dilated, and ~0.4 left
    untouched; afterwards it is Gaussian-blurred with probability 0.9.
    Structuring-element and blur kernel sizes are random odd integers.
    """
    roll = np.random.rand()
    if roll < 0.6:
        # Random odd kernel size in [3, 19].
        k = 2 * np.random.randint(1, 10) + 1
        kernel = np.ones((k, k), np.uint8)
        morph = cv2.erode if roll < 0.3 else cv2.dilate
        mask = morph(mask, kernel)

    if np.random.rand() < 0.9:
        # Random odd Gaussian kernel size in [1, 17].
        blur_k = 2 * np.random.randint(1, 10) - 1
        mask = cv2.GaussianBlur(mask, (blur_k, blur_k), 0)

    return mask
|
|
|
|
def xception(num_region=7, num_type=2, num_mag=1, pretrained='imagenet', inc=6):
    """Build an Xception_SLADDSyn backbone.

    NOTE: `pretrained` is accepted for API compatibility but ignored here;
    checkpoint loading is handled by the caller (see TransferModel).
    """
    model = Xception_SLADDSyn(num_region=num_region, num_type=num_type, num_mag=num_mag, inc=inc)
    return model
|
|
|
|
|
|
|
|
class TransferModel(nn.Module):
    """
    Simple transfer learning model that takes an imagenet pretrained model with
    a fc layer as base model and retrains a new fc layer for num_out_classes
    """

    def __init__(self, config, num_region=7, num_type=2, num_mag=1, return_fea=False, inc=6):
        """Wrap an Xception_SLADDSyn backbone, loading pretrained weights
        from ``config['pretrained']``.

        Args:
            config: dict holding at least the ``'pretrained'`` checkpoint path.
            num_region/num_type/num_mag: sizes of the three prediction heads.
            return_fea: stored for API compatibility; not used in this class.
            inc: number of input channels (6 = two stacked RGB images).
        """
        super(TransferModel, self).__init__()
        self.return_fea = return_fea

        def return_pytorch04_xception(pretrained=True):
            # Build the backbone; checkpoint loading happens here, not in xception().
            model = xception(num_region=num_region, num_type=num_type, num_mag=num_mag, inc=inc, pretrained=False)
            if pretrained:
                # map_location='cpu' makes loading independent of the device the
                # checkpoint was saved on; the caller moves the model afterwards.
                state_dict = torch.load(config['pretrained'], map_location='cpu')
                print('Loaded pretrained model (ImageNet)....')
                # Legacy checkpoints store pointwise-conv weights as 2-D tensors;
                # modern torch expects (out, in, 1, 1). Iterate over a snapshot
                # because the dict is mutated inside the loop.
                for name, weights in list(state_dict.items()):
                    if 'pointwise' in name:
                        state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
                # strict=False: the checkpoint lacks the new heads / stem.
                model.load_state_dict(state_dict, strict=False)
            return model

        self.model = return_pytorch04_xception()

        if inc != 3:
            # Replace the stem so it accepts `inc` channels, then re-initialize
            # it (xavier_normal_ is the non-deprecated in-place initializer).
            self.model.iniconv = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
            nn.init.xavier_normal_(self.model.iniconv.weight.data, gain=0.02)

    def set_trainable_up_to(self, boolean=False, layername="Conv2d_4a_3x3"):
        """
        Freezes all layers below a specific layer and sets the following layers
        to true if boolean else only the fully connected final layer
        :param boolean: unfreeze everything after `layername` if True,
            otherwise only the final linear layer
        :param layername: depends on lib, for inception e.g. Conv2d_4a_3x3;
            None unfreezes the whole model
        :return: None
        """
        if layername is None:
            # Unfreeze everything and stop.
            for i, param in self.model.named_parameters():
                param.requires_grad = True
            return
        else:
            # Freeze everything first.
            for i, param in self.model.named_parameters():
                param.requires_grad = False
        if boolean:
            # Unfreeze every child module that comes *after* `layername`.
            ct = []
            found = False
            for name, child in self.model.named_children():
                if layername in ct:
                    found = True
                    for params in child.parameters():
                        params.requires_grad = True
                ct.append(name)
            if not found:
                # Fixed: the original message called .format() without a
                # placeholder, so the layer name was silently dropped.
                raise NotImplementedError(
                    'Layer {} not found, cant finetune!'.format(layername))
        else:
            # Only the final classifier stays trainable.
            for param in self.model.last_linear.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Return the backbone's (region_logits, type_logits, magnitude)."""
        region_num, type_num, mag = self.model(x)
        return region_num, type_num, mag

    def features(self, x):
        """Expose the backbone's convolutional feature map."""
        x = self.model.features(x)
        return x

    def classifier(self, x):
        """Delegate to the backbone's classifier head."""
        out, x = self.model.classifier(x)
        return out, x
|
|
|
|
|
|
|
|
def dist(p1, p2):
    """Euclidean distance between 2-D points ``p1`` and ``p2``."""
    # math.hypot is the idiomatic (and overflow-safe) form of
    # sqrt(dx**2 + dy**2).
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])
|
|
|
|
|
|
def generate_random_mask(mask, res=256):
    """Intersect `mask` with a random box centred on the mask's centroid.

    The input is binarized at 0.1; a rectangle with random extents (10-59 px
    in each direction) around the centroid of the foreground region is kept.
    Returns an all-zero mask when the input has no foreground.
    """
    randwl = np.random.randint(10, 60)
    randwr = np.random.randint(10, 60)
    randhu = np.random.randint(10, 60)
    randhd = np.random.randint(10, 60)
    newmask = np.zeros(mask.shape)
    mask = np.where(mask > 0.1, 1, 0)
    props = measure.regionprops(mask)
    if len(props) == 0:
        return newmask
    center_x, center_y = props[0].centroid
    center_x = int(round(center_x))
    center_y = int(round(center_y))
    # BUGFIX: the column upper bound previously used `center_x + randhd`
    # (copy-paste error), producing mis-placed boxes; it must use center_y.
    newmask[max(center_x - randwl, 0):min(center_x + randwr, res - 1),
            max(center_y - randhu, 0):min(center_y + randhd, res - 1)] = 1
    newmask *= mask
    return newmask
|
|
|
|
|
|
def random_deform(mask, nrows, ncols, mean=0, std=10):
    """Randomly warp `mask` with a piecewise-affine transform and blur it.

    A nrows x ncols grid of anchor points is jittered by Gaussian noise,
    the mask is warped between the original and jittered grids, intersected
    with itself, and Gaussian-blurred.
    """
    h, w = mask.shape[:2]
    rows = np.linspace(0, h - 1, nrows).astype(np.int32)
    cols = np.linspace(0, w - 1, ncols).astype(np.int32)
    rows += np.random.normal(mean, std, size=rows.shape).astype(np.int32)
    # BUGFIX: the second jitter was previously added to `rows` again
    # (copy-paste error), leaving the column anchors unperturbed.
    cols += np.random.normal(mean, std, size=cols.shape).astype(np.int32)
    rows, cols = np.meshgrid(rows, cols)
    anchors = np.vstack([rows.flat, cols.flat]).T
    assert anchors.shape[1] == 2 and anchors.shape[0] == ncols * nrows
    deformed = anchors + np.random.normal(mean, std, size=anchors.shape)
    # Clamp the deformed anchors to the image bounds (in-place via out=).
    np.clip(deformed[:, 0], 0, h - 1, deformed[:, 0])
    np.clip(deformed[:, 1], 0, w - 1, deformed[:, 1])

    trans = PiecewiseAffineTransform()
    trans.estimate(anchors, deformed.astype(np.int32))
    warped = warp(mask, trans)
    # Keep only the overlap with the original mask so the warp never grows it.
    warped *= mask
    blured = cv2.GaussianBlur(warped.astype(float), (5, 5), 3)
    return blured
|
|
|
|
|
|
def get_five_key(landmarks_68):
    """Reduce 68-point facial landmarks to nine key locations.

    Returns a list of nine int32 (x, y) tuples:
    [left-eye centre, right-eye centre, nose tip, left mouth corner,
     right mouth corner, left-eye outer, left-eye inner,
     right-eye inner, right-eye outer].
    """
    keypoints = [
        (landmarks_68[36] + landmarks_68[39]) * 0.5,  # left-eye centre
        (landmarks_68[42] + landmarks_68[45]) * 0.5,  # right-eye centre
        landmarks_68[33],                             # nose tip
        landmarks_68[48],                             # left mouth corner
        landmarks_68[54],                             # right mouth corner
        landmarks_68[36],                             # left-eye outer corner
        landmarks_68[39],                             # left-eye inner corner
        landmarks_68[42],                             # right-eye inner corner
        landmarks_68[45],                             # right-eye outer corner
    ]
    return [tuple(p.astype('int32')) for p in keypoints]
|
|
|
|
|
|
def remove_eyes(image, landmarks, opt):
    """Mask the eye region: a thick dilated line across the chosen eye(s).

    Args:
        image: H x W x C image; only its spatial shape is used.
        landmarks: nine key points from get_five_key().
        opt: 'l' (left eye), 'r' (right eye) or 'b' (both, eye-centre to
            eye-centre).

    Raises:
        ValueError: if `opt` is not one of 'l', 'r', 'b'.
    """
    if opt == 'l':
        (x1, y1), (x2, y2) = landmarks[5:7]
    elif opt == 'r':
        (x1, y1), (x2, y2) = landmarks[7:9]
    elif opt == 'b':
        (x1, y1), (x2, y2) = landmarks[:2]
    else:
        # BUGFIX: the original only printed a warning and then crashed with a
        # NameError on the undefined coordinates; fail fast instead.
        raise ValueError('wrong region: {!r}'.format(opt))
    mask = np.zeros_like(image[..., 0])
    line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
    w = dist((x1, y1), (x2, y2))
    dilation = int(w // 4)
    if opt != 'b':
        # Single-eye lines are short; dilate them more aggressively.
        dilation *= 4
    line = binary_dilation(line, iterations=dilation)
    return line
|
|
|
|
|
|
def remove_nose(image, landmarks):
    """Mask the nose region: a thick dilated line from the nose tip to the
    midpoint between the two eye centres."""
    (x1, y1), (x2, y2) = landmarks[:2]
    x3, y3 = landmarks[2]
    canvas = np.zeros_like(image[..., 0])
    # Midpoint between the eye centres.
    x4, y4 = int((x1 + x2) / 2), int((y1 + y2) / 2)
    line = cv2.line(canvas, (x3, y3), (x4, y4), color=(1), thickness=2)
    eye_span = dist((x1, y1), (x2, y2))
    # Thickness scales with the inter-eye distance.
    return binary_dilation(line, iterations=int(eye_span // 4))
|
|
|
|
|
|
def remove_mouth(image, landmarks):
    """Mask the mouth region: a thick dilated line between the two mouth
    corners."""
    (x1, y1), (x2, y2) = landmarks[3:5]
    canvas = np.zeros_like(image[..., 0])
    line = cv2.line(canvas, (x1, y1), (x2, y2), color=(1), thickness=2)
    mouth_span = dist((x1, y1), (x2, y2))
    # Thickness scales with the mouth width.
    return binary_dilation(line, iterations=int(mouth_span // 3))
|
|
|
|
|
|
def blend_fake_to_real(realimg, real_lmk, fakeimg, fakemask, fake_lmk, deformed_fakemask, type, mag):
    """Blend the fake image onto the real one inside a deformed mask.

    Inputs `realimg`/`fakeimg` are expected in [-1, 1]; both are converted to
    uint8 [0, 255] first. A random colour-transfer mode is applied to the fake
    crop, then one of three blending strategies is chosen by `type`:
    0 = feathered alpha blend, 1 = Poisson (seamless) clone,
    otherwise = magnitude-weighted copy.

    Returns (blended uint8 image, final uint8 mask).

    NOTE(review): `fakemask` and `fake_lmk` are accepted but never used in
    this function; `type` shadows the builtin — kept for interface
    compatibility with callers.
    """
    realimg = ((realimg + 1) / 2 * 255).astype(np.uint8)
    fakeimg = ((fakeimg + 1) / 2 * 255).astype(np.uint8)
    H, W, C = realimg.shape

    aligned_src = fakeimg
    src_mask = deformed_fakemask
    src_mask = src_mask > 0

    # Binarize (0/1 uint8) and randomly roughen the mask boundary.
    tgt_mask = np.asarray(src_mask, dtype=np.uint8)
    tgt_mask = mask_postprocess(tgt_mask)

    # Pick a random colour-transfer mode to harmonize source and target.
    ct_modes = ['rct-m', 'rct-fs', 'avg-align', 'faceswap']
    mode_idx = np.random.randint(len(ct_modes))
    mode = ct_modes[mode_idx]

    if mode != 'faceswap':
        # NOTE(review): tgt_mask holds 0/1 values here, so dividing by 255
        # yields tiny non-zero values; the following reset to 1 makes it work,
        # but confirm the /255 was intended.
        c_mask = tgt_mask / 255.
        c_mask[c_mask > 0] = 1
        if len(c_mask.shape) < 3:
            c_mask = np.expand_dims(c_mask, 2)
        src_crop = color_transfer(mode, aligned_src, realimg, c_mask)
    else:
        # faceswap-style colour correction expects a 0/255 mask.
        c_mask = tgt_mask.copy()
        c_mask[c_mask > 0] = 255
        masked_tgt = faceswap.apply_mask(realimg, c_mask)
        masked_src = faceswap.apply_mask(aligned_src, c_mask)
        src_crop = faceswap.correct_colours(masked_tgt, masked_src, np.array(real_lmk))

    # Skip blending entirely when the mask is nearly empty or the colour
    # transfer produced a black crop; return the untouched real image.
    if tgt_mask.mean() < 0.005 or src_crop.max() == 0:
        out_blend = realimg
    else:
        if type == 0:
            # Feathered alpha blend with a random feather amount.
            out_blend, a_mask = alpha_blend_fea(src_crop, realimg, tgt_mask,
                                                featherAmount=0.2 * np.random.rand())
        elif type == 1:
            # Poisson blending: clone the crop at the mask's bounding-box centre.
            b_mask = (tgt_mask * 255).astype(np.uint8)
            l, t, w, h = cv2.boundingRect(b_mask)
            center = (int(l + w / 2), int(t + h / 2))
            out_blend = cv2.seamlessClone(src_crop, realimg, b_mask, center, cv2.NORMAL_CLONE)
        else:
            # Magnitude-weighted direct copy (type 2).
            out_blend = copy_fake_to_real(realimg, src_crop, tgt_mask, mag)

    return out_blend, tgt_mask
|
|
|
|
|
|
def copy_fake_to_real(realimg, fakeimg, mask, mag):
    """Alpha-blend `fakeimg` onto `realimg` inside a 2-D `mask`, with
    blending strength `mag` (0 = keep real, 1 = full fake)."""
    m = np.expand_dims(mask, 2)
    # Inside the mask: mag-weighted mix of fake and real; outside: untouched.
    blended = fakeimg * m * mag + realimg * (1 - m) + realimg * m * (1 - mag)
    return blended
|
|
|
|
|
|
class synthesizer(nn.Module):
    """SLADD forgery synthesizer.

    Predicts (region, blending type, magnitude) actions with a TransferModel
    policy network, then uses them to synthesize new fake samples by blending
    fake faces onto real ones. Returns REINFORCE-style log-probabilities and
    entropies alongside the synthesized batch.
    """

    def __init__(self,config):
        super(synthesizer, self).__init__()
        # Policy network: 10 regions, 4 blending types, 1 magnitude; input is
        # a 6-channel real/fake image pair.
        self.netG = TransferModel(config=config,num_region=10, num_type=4, num_mag=1, inc=6)
        # Maps uint8 PIL images back to [-1, 1] tensors.
        normalize = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.transforms = T.Compose([T.ToTensor(), normalize])

    def parse(self, img, reg, real_lmk, fakemask):
        """Build the manipulation mask for region index `reg` (0-9 =
        landmark-based regions/combinations, anything else = random box from
        `fakemask`), then randomly deform it."""
        five_key = get_five_key(real_lmk)
        if reg == 0:
            mask = remove_eyes(img, five_key, 'l')
        elif reg == 1:
            mask = remove_eyes(img, five_key, 'r')
        elif reg == 2:
            mask = remove_eyes(img, five_key, 'b')
        elif reg == 3:
            mask = remove_nose(img, five_key)
        elif reg == 4:
            mask = remove_mouth(img, five_key)
        elif reg == 5:
            mask = remove_nose(img, five_key) + remove_eyes(img, five_key, 'l')
        elif reg == 6:
            mask = remove_nose(img, five_key) + remove_eyes(img, five_key, 'r')
        elif reg == 7:
            mask = remove_nose(img, five_key) + remove_eyes(img, five_key, 'b')
        elif reg == 8:
            mask = remove_nose(img, five_key) + remove_mouth(img, five_key)
        elif reg == 9:
            mask = remove_eyes(img, five_key, 'b') + remove_nose(img, five_key) + remove_mouth(img, five_key)
        else:
            mask = generate_random_mask(fakemask)
        mask = random_deform(mask, 5, 5)
        # * 1.0 casts boolean/int masks to float.
        return mask * 1.0

    def get_variable(self, inputs, cuda=False, **kwargs):
        """Wrap `inputs` in an autograd Variable (legacy API), converting
        lists/ndarrays to tensors and optionally moving to GPU.

        NOTE(review): torch.autograd.Variable is deprecated; tensors are
        Variables since torch 0.4.
        """
        if type(inputs) in [list, np.ndarray]:
            inputs = torch.Tensor(inputs)
        if cuda:
            out = Variable(inputs.cuda(), **kwargs)
        else:
            out = Variable(inputs, **kwargs)
        return out

    def calculate(self, logits):
        """Sample an action from `logits`.

        Multi-class heads (width > 1): categorical sampling; returns the
        per-sample entropy, the log-probability of the sampled action, and
        the action index. Width-1 heads (magnitude): sigmoid activation used
        directly as a continuous action.
        """
        if logits.shape[1] != 1:
            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            entropy = -(log_prob * probs).sum(1, keepdim=False)
            action = probs.multinomial(num_samples=1).data
            selected_log_prob = log_prob.gather(1, self.get_variable(action, requires_grad=False))
        else:
            probs = torch.sigmoid(logits)
            log_prob = torch.log(torch.sigmoid(logits))
            entropy = -(log_prob * probs).sum(1, keepdim=False)
            action = probs
            selected_log_prob = log_prob
        return entropy, selected_log_prob[:, 0], action[:, 0]

    def forward(self, img, fake_img, real_lmk, fake_lmk, real_mask, fake_mask, label=None):
        """Sample synthesis actions and apply them to the batch.

        For each sample that is real (label 0) and whose sampled type is not
        3 ("no manipulation"), a forged image is synthesized; otherwise the
        original image is kept. Returns log_prob, entropy, the new image
        batch, 16x16 masks, and the new binary/type/magnitude labels.

        NOTE(review): `type` here shadows the builtin, and
        `blend_fake_to_real` receives the whole-batch `fake_mask` (not
        `fake_mask[i]`) — harmless since that parameter is unused there, but
        confirm intent.
        """
        region_num, type_num, mag = self.netG(torch.cat((img, fake_img), 1))
        reg_etp, reg_log_prob, reg = self.calculate(region_num)
        type_etp, type_log_prob, type = self.calculate(type_num)
        mag_etp, mag_log_prob, mag = self.calculate(mag)
        entropy = reg_etp + type_etp + mag_etp
        log_prob = reg_log_prob + type_log_prob + mag_log_prob
        newlabel = []
        typelabel = []
        maglabel = []
        magmask = []

        alt_img = torch.ones(img.shape)
        alt_mask = np.zeros((img.shape[0], 16, 16))
        if label is None:
            label=np.zeros(img.shape[0])
        for i in range(img.shape[0]):
            # CHW tensors -> HWC numpy copies for the OpenCV/PIL pipeline.
            imgcp = np.transpose(img[i].cpu().numpy(), (1, 2, 0)).copy()
            fake_imgcp = np.transpose(fake_img[i].cpu().numpy(), (1, 2, 0)).copy()

            if label[i] == 0 and type[i] != 3:
                # Synthesize a forgery: build the region mask, then blend.
                mask = self.parse(fake_imgcp, reg[i], fake_lmk[i].cpu().numpy(),
                                  fake_mask[i].cpu().numpy())
                newimg, newmask = blend_fake_to_real(imgcp, real_lmk[i].cpu().numpy(),
                                                     fake_imgcp, fake_mask.cpu().numpy(),
                                                     fake_lmk[i].cpu().numpy(), mask, type[i],
                                                     mag[i].detach().cpu().numpy())
                newimg = self.transforms(Image.fromarray(np.array(newimg, dtype=np.uint8)))
                newlabel.append(int(1))
                typelabel.append(int(type[i].cpu().numpy()))
                # Magnitude supervision only applies to type-2 (copy) blends.
                if type[i] == 2:
                    magmask.append(int(1))
                else:
                    magmask.append(int(0))
            else:
                # Keep the original image; rescale [-1, 1] -> uint8 for PIL.
                newimg = self.transforms(Image.fromarray(np.array((imgcp + 1) / 2 * 255, dtype=np.uint8)))
                newmask =real_mask[i].squeeze(2)[:,:,0].cpu().numpy()
                newlabel.append(int(label[i]))
                # Type 3 = untouched real, type 4 = pre-existing fake.
                if label[i] == 0:
                    typelabel.append(int(3))
                else:
                    typelabel.append(int(4))
                magmask.append(int(0))
            if newmask is None:
                newmask = np.zeros((16, 16))
            newmask = cv2.resize(newmask, (16, 16), interpolation=cv2.INTER_CUBIC)
            alt_img[i] = newimg
            alt_mask[i] = newmask

        alt_mask = torch.from_numpy(alt_mask.astype(np.float32)).unsqueeze(1)
        newlabel = torch.tensor(newlabel)
        typelabel = torch.tensor(typelabel)
        maglabel = mag
        magmask = torch.tensor(magmask)
        return log_prob, entropy, alt_img.detach(), alt_mask.detach(), \
               newlabel.detach(), typelabel.detach(), maglabel.detach(), magmask.detach()
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the synthesizer from the detector config and push a
    # few batches of the paired dataset through it on the GPU.
    with open(r'H:\code\DeepfakeBench\training\config\detector\sladd_xception.yaml', 'r') as f:
        config = yaml.safe_load(f)
    syn=synthesizer(config=config).cuda()
    # Dataset options required by pairDataset, overriding the YAML defaults.
    config['data_manner'] = 'lmdb'
    config['dataset_json_folder'] = 'preprocessing/dataset_json_v3'
    config['sample_size']=256
    config['with_mask']=True
    config['with_landmark']=True
    config['use_data_augmentation']=True
    config['data_aug']['rotate_prob']=1
    train_set = pairDataset(config=config, mode='train')
    train_data_loader = \
        torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=config['train_batchSize'],
            shuffle=True,
            num_workers=0,
            collate_fn=train_set.collate_fn,
        )
    from tqdm import tqdm
    for iteration, batch in enumerate(tqdm(train_data_loader)):
        print(iteration)
        imgs,lmks,msks=batch['image'].cuda(),batch['landmark'].cuda(),batch['mask'].cuda()
        # The collate function stacks real samples first, then fakes:
        # split the batch down the middle into real/fake halves.
        half = len(imgs) // 2
        img, fake_img, real_lmk, fake_lmk, real_mask, fake_mask = imgs[:half],imgs[half:],lmks[:half],lmks[half:],msks[:half],msks[half:]
        log_prob, entropy, new_img, alt_mask, label, type_label, mag_label, mag_mask = \
            syn(img, fake_img, real_lmk, fake_lmk, real_mask, fake_mask)

        # Only run a handful of iterations for the smoke test.
        if iteration > 10:
            break
|
|
... |