'''
# author: Zhiyuan Yan
# email: [email protected]
# date: 2023-03-30
The code is designed for Face X-ray.
'''
import os
import sys
import json
import pickle
import time
import lmdb
import numpy as np
import albumentations as A
import cv2
import random
from PIL import Image
from skimage.util import random_noise
from scipy import linalg
import heapq as hq
import torch
from torch.autograd import Variable
from torch.utils import data
from torchvision import transforms as T
import torchvision
from dataset.utils.face_blend import *
from dataset.utils.face_align import get_align_mat_new
from dataset.utils.color_transfer import color_transfer
from dataset.utils.faceswap_utils import blendImages as alpha_blend_fea
from dataset.utils.faceswap_utils import AlphaBlend as alpha_blend
from dataset.utils.face_aug import aug_one_im, change_res
from dataset.utils.image_ae import get_pretraiend_ae
from dataset.utils.warp import warp_mask
from dataset.utils import faceswap
from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated
class RandomDownScale(A.core.transforms_interface.ImageOnlyTransform):
def apply(self,img,**params):
return self.randomdownscale(img)
    def randomdownscale(self, img):
        # Downscale by a random factor (2 or 4) with nearest-neighbor, then
        # resize back to the input shape with bilinear interpolation.
        keep_input_shape = True
        H, W, C = img.shape
        ratio_list = [2, 4]
        r = ratio_list[np.random.randint(len(ratio_list))]
        img_ds = cv2.resize(img, (int(W / r), int(H / r)), interpolation=cv2.INTER_NEAREST)
        if keep_input_shape:
            img_ds = cv2.resize(img_ds, (W, H), interpolation=cv2.INTER_LINEAR)
        return img_ds
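
# Example (sketch): RandomDownScale drops into an albumentations pipeline like
# any other ImageOnlyTransform; `p` is the usual apply-probability.
#
#   aug = A.Compose([RandomDownScale(p=0.5)])
#   degraded = aug(image=img)['image']
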
class FFBlendDataset(data.Dataset):
def __init__(self, config=None):
self.lmdb = config.get('lmdb', False)
if self.lmdb:
            lmdb_path = os.path.join(config['lmdb_dir'], "FaceForensics++_lmdb")
self.env = lmdb.open(lmdb_path, create=False, subdir=True, readonly=True, lock=False)
# Check if the dictionary has already been created
if os.path.exists('training/lib/nearest_face_info.pkl'):
with open('training/lib/nearest_face_info.pkl', 'rb') as f:
face_info = pickle.load(f)
else:
            raise ValueError("Need to run dataset/generate_xray_nearest.py before training Face X-ray.")
self.face_info = face_info
# Check if the dictionary has already been created
        if os.path.exists('training/lib/landmark_dict_ffall.pkl'):
            with open('training/lib/landmark_dict_ffall.pkl', 'rb') as f:
                landmark_dict = pickle.load(f)
        else:
            raise ValueError("training/lib/landmark_dict_ffall.pkl not found; generate the landmark dictionary before training Face X-ray.")
        self.landmark_dict = landmark_dict
self.imid_list = self.get_training_imglist()
self.transforms = T.Compose([
# T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
# T.ColorJitter(hue=.05, saturation=.05),
# T.RandomHorizontalFlip(),
# T.RandomRotation(20, resample=Image.BILINEAR),
T.ToTensor(),
T.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
])
self.data_dict = {
'imid_list': self.imid_list
}
self.config=config
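    # data_aug perturbs the fg/bg frames *before* blending (color jitter plus
    # random down-scaling or sharpening) so the two faces carry mismatched
    # statistics; blended_aug runs *after* blending (see post_proc) on the
    # composite image.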
def blended_aug(self, im):
transform = A.Compose([
A.RGBShift((-20,20),(-20,20),(-20,20),p=0.3),
A.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=0.3),
A.RandomBrightnessContrast(brightness_limit=(-0.3,0.3), contrast_limit=(-0.3,0.3), p=0.3),
A.ImageCompression(quality_lower=40, quality_upper=100,p=0.5)
])
# Apply transformations
im_aug = transform(image=im)
return im_aug['image']
def data_aug(self, im):
"""
Apply data augmentation on the input image using albumentations.
"""
transform = A.Compose([
A.Compose([
A.RGBShift((-20,20),(-20,20),(-20,20),p=0.3),
A.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=1),
A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1,0.1), p=1),
],p=1),
A.OneOf([
RandomDownScale(p=1),
A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1),
],p=1),
], p=1.)
# Apply transformations
im_aug = transform(image=im)
return im_aug['image']
def get_training_imglist(self):
"""
Get the list of training images.
"""
random.seed(1024) # Fix the random seed for reproducibility
imid_list = list(self.landmark_dict.keys())
# imid_list = [imid.replace('landmarks', 'frames').replace('npy', 'png') for imid in imid_list]
random.shuffle(imid_list)
return imid_list
def load_rgb(self, file_path):
"""
Load an RGB image from a file path and resize it to a specified resolution.
Args:
file_path: A string indicating the path to the image file.
Returns:
            A numpy array (uint8, RGB) containing the loaded and resized image.
Raises:
ValueError: If the loaded image is None.
"""
        size = self.config['resolution']
if not self.lmdb:
            if not file_path[0] == '.':
                file_path = os.path.join('.', self.config['rgb_dir'], file_path)
assert os.path.exists(file_path), f"{file_path} does not exist"
img = cv2.imread(file_path)
if img is None:
raise ValueError('Loaded image is None: {}'.format(file_path))
elif self.lmdb:
with self.env.begin(write=False) as txn:
# transfer the path format from rgb-path to lmdb-key
if file_path[0]=='.':
file_path=file_path.replace('./datasets\\','')
image_bin = txn.get(file_path.encode())
image_buf = np.frombuffer(image_bin, dtype=np.uint8)
img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
return np.array(img, dtype=np.uint8)
def load_mask(self, file_path):
"""
Load a binary mask image from a file path and resize it to a specified resolution.
Args:
file_path: A string indicating the path to the mask file.
Returns:
A numpy array containing the loaded and resized mask.
Raises:
None.
"""
size = self.config['resolution']
        if file_path is None:
            return np.zeros((size, size, 1))
        if not self.lmdb:
            if not file_path[0] == '.':
                file_path = os.path.join('.', self.config['rgb_dir'], file_path)
            if os.path.exists(file_path):
                mask = cv2.imread(file_path, 0)
                if mask is None:
                    mask = np.zeros((size, size))
            else:
                return np.zeros((size, size, 1))
else:
with self.env.begin(write=False) as txn:
# transfer the path format from rgb-path to lmdb-key
if file_path[0]=='.':
file_path=file_path.replace('./datasets\\','')
image_bin = txn.get(file_path.encode())
image_buf = np.frombuffer(image_bin, dtype=np.uint8)
                # decode the mask as grayscale (single channel), matching the non-lmdb branch
                mask = cv2.imdecode(image_buf, cv2.IMREAD_GRAYSCALE)
mask = cv2.resize(mask, (size, size)) / 255
mask = np.expand_dims(mask, axis=2)
return np.float32(mask)
def load_landmark(self, file_path):
"""
Load 2D facial landmarks from a file path.
Args:
file_path: A string indicating the path to the landmark file.
Returns:
A numpy array containing the loaded landmarks.
Raises:
None.
"""
if file_path is None:
return np.zeros((81, 2))
if not self.lmdb:
            if not file_path[0] == '.':
                file_path = os.path.join('.', self.config['rgb_dir'], file_path)
if os.path.exists(file_path):
landmark = np.load(file_path)
else:
return np.zeros((81, 2))
else:
with self.env.begin(write=False) as txn:
# transfer the path format from rgb-path to lmdb-key
if file_path[0]=='.':
file_path=file_path.replace('./datasets\\','')
binary = txn.get(file_path.encode())
landmark = np.frombuffer(binary, dtype=np.uint32).reshape((81, 2))
return np.float32(landmark)
def preprocess_images(self, imid_fg, imid_bg):
"""
Load foreground and background images and face shapes.
"""
fg_im = self.load_rgb(imid_fg.replace('landmarks', 'frames').replace('npy', 'png'))
fg_im = np.array(self.data_aug(fg_im))
fg_shape = self.landmark_dict[imid_fg]
fg_shape = np.array(fg_shape, dtype=np.int32)
bg_im = self.load_rgb(imid_bg.replace('landmarks', 'frames').replace('npy', 'png'))
bg_im = np.array(self.data_aug(bg_im))
bg_shape = self.landmark_dict[imid_bg]
bg_shape = np.array(bg_shape, dtype=np.int32)
if fg_im is None:
return bg_im, bg_shape, bg_im, bg_shape
elif bg_im is None:
return fg_im, fg_shape, fg_im, fg_shape
return fg_im, fg_shape, bg_im, bg_shape
def get_fg_bg(self, one_lmk_path):
"""
Get foreground and background paths
"""
bg_lmk_path = one_lmk_path
# Randomly pick one from the nearest neighbors for the foreground
if bg_lmk_path in self.face_info:
fg_lmk_path = random.choice(self.face_info[bg_lmk_path])
else:
fg_lmk_path = bg_lmk_path
return fg_lmk_path, bg_lmk_path
def generate_masks(self, fg_im, fg_shape, bg_im, bg_shape):
"""
Generate masks for foreground and background images.
"""
fg_mask = get_mask(fg_shape, fg_im, deform=False)
bg_mask = get_mask(bg_shape, bg_im, deform=True)
        # Only post-process the background mask
bg_mask_postprocess = warp_mask(bg_mask, std=20)
return fg_mask, bg_mask_postprocess
def warp_images(self, fg_im, fg_shape, bg_im, bg_shape, fg_mask):
"""
Warp foreground face onto background image using affine or 3D warping.
"""
H, W, C = bg_im.shape
use_3d_warp = np.random.rand() < 0.5
if not use_3d_warp:
aff_param = np.array(get_align_mat_new(fg_shape, bg_shape)).reshape(2, 3)
warped_face = cv2.warpAffine(fg_im, aff_param, (W, H), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT)
fg_mask = cv2.warpAffine(fg_mask, aff_param, (W, H), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT)
fg_mask = fg_mask > 0
else:
warped_face = faceswap.warp_image_3d(fg_im, np.array(fg_shape[:48]), np.array(bg_shape[:48]), (H, W))
fg_mask = np.mean(warped_face, axis=2) > 0
return warped_face, fg_mask
def colorTransfer(self, src, dst, mask):
transferredDst = np.copy(dst)
maskIndices = np.where(mask != 0)
maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.float32)
maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.float32)
# Compute means and standard deviations
meanSrc = np.mean(maskedSrc, axis=0)
stdSrc = np.std(maskedSrc, axis=0)
meanDst = np.mean(maskedDst, axis=0)
stdDst = np.std(maskedDst, axis=0)
# Perform color transfer
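        # i.e. match per-channel statistics (a Reinhard-style transfer):
        # dst' = (dst - mean_dst) * (std_src / std_dst) + mean_src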
maskedDst = (maskedDst - meanDst) * (stdSrc / stdDst) + meanSrc
maskedDst = np.clip(maskedDst, 0, 255)
        # Apply color transfer only to the masked region
        transferredDst[maskIndices[0], maskIndices[1]] = maskedDst.astype(np.uint8)
return transferredDst
def blend_images(self, color_corrected_fg, bg_im, bg_mask, featherAmount=0.2):
"""
Blend foreground and background images together.
"""
# normalize the mask to have values between 0 and 1
b_mask = bg_mask / 255.
# Add an extra dimension and repeat the mask to match the number of channels in color_corrected_fg and bg_im
b_mask = np.repeat(b_mask[:, :, np.newaxis], 3, axis=2)
# Compute the alpha blending
maskIndices = np.where(b_mask != 0)
maskPts = np.hstack((maskIndices[1][:, np.newaxis], maskIndices[0][:, np.newaxis]))
# FIXME: deal with the bugs of empty maskpts
        if maskPts.size == 0:
            print("No non-zero values found in bg_mask for blending. Skipping this image.")
            return color_corrected_fg  # or handle this case differently as needed
faceSize = np.max(maskPts, axis=0) - np.min(maskPts, axis=0)
featherAmount = featherAmount * np.max(faceSize)
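        # Feathering: weight each masked pixel by its signed distance to the
        # mask's convex hull, ramping linearly from 0 at the boundary to 1 at
        # featherAmount pixels inside, so the composite fades out smoothly.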
hull = cv2.convexHull(maskPts)
dists = np.zeros(maskPts.shape[0])
for i in range(maskPts.shape[0]):
dists[i] = cv2.pointPolygonTest(hull, (int(maskPts[i, 0]), int(maskPts[i, 1])), True)
weights = np.clip(dists / featherAmount, 0, 1)
# Perform the blending operation
color_corrected_fg = color_corrected_fg.astype(float)
bg_im = bg_im.astype(float)
blended_image = np.copy(bg_im)
blended_image[maskIndices[0], maskIndices[1]] = weights[:, np.newaxis] * color_corrected_fg[maskIndices[0], maskIndices[1]] + (1 - weights[:, np.newaxis]) * bg_im[maskIndices[0], maskIndices[1]]
# Convert the blended image to 8-bit unsigned integers
blended_image = np.clip(blended_image, 0, 255)
blended_image = blended_image.astype(np.uint8)
return blended_image
def process_images(self, imid_fg, imid_bg, index):
"""
Overview:
Process foreground and background images following the data generation pipeline (BI dataset).
Terminology:
Foreground (fg) image: The image containing the face that will be blended onto the background image.
Background (bg) image: The image onto which the face from the foreground image will be blended.
"""
fg_im, fg_shape, bg_im, bg_shape = self.preprocess_images(imid_fg, imid_bg)
fg_mask, bg_mask = self.generate_masks(fg_im, fg_shape, bg_im, bg_shape)
warped_face, fg_mask = self.warp_images(fg_im, fg_shape, bg_im, bg_shape, fg_mask)
try:
            # make sure bg_mask lies strictly within fg_mask
bg_mask[fg_mask == 0] = 0
color_corrected_fg = self.colorTransfer(bg_im, warped_face, bg_mask)
blended_image = self.blend_images(color_corrected_fg, bg_im, bg_mask)
        # FIXME: ugly workaround for the case where bg_mask is all zeros
        except Exception:
color_corrected_fg = self.colorTransfer(bg_im, warped_face, bg_mask)
blended_image = self.blend_images(color_corrected_fg, bg_im, bg_mask)
boundary = get_boundary(bg_mask)
# # Prepare images and titles for the combined image
# images = [fg_im, np.where(fg_mask>0, 255, 0), bg_im, bg_mask, color_corrected_fg, blended_image, np.where(boundary>0, 255, 0)]
# titles = ["Fg Image", "Fg Mask", "Bg Image",
# "Bg Mask", "Blended Region",
# "Blended Image", "Boundary"]
# # Save the combined image
# os.makedirs('facexray_examples_3', exist_ok=True)
# self.save_combined_image(images, titles, index, f'facexray_examples_3/combined_image_{index}.png')
return blended_image, boundary, bg_im
def post_proc(self, img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        im_aug = self.blended_aug(img)
        im_aug = Image.fromarray(np.uint8(im_aug))
im_aug = self.transforms(im_aug)
return im_aug
@staticmethod
def save_combined_image(images, titles, index, save_path):
"""
Save the combined image with titles for each single image.
Args:
images (List[np.ndarray]): List of images to be combined.
titles (List[str]): List of titles for each image.
index (int): Index of the image.
save_path (str): Path to save the combined image.
"""
# Determine the maximum height and width among the images
max_height = max(image.shape[0] for image in images)
max_width = max(image.shape[1] for image in images)
# Create the canvas
canvas = np.zeros((max_height * len(images), max_width, 3), dtype=np.uint8)
# Place the images and titles on the canvas
current_height = 0
for image, title in zip(images, titles):
height, width = image.shape[:2]
# Check if image has a third dimension (color channels)
if image.ndim == 2:
# If not, add a third dimension
image = np.tile(image[..., None], (1, 1, 3))
canvas[current_height : current_height + height, :width] = image
cv2.putText(
canvas, title, (10, current_height + 30),
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2
)
current_height += height
# Save the combined image
cv2.imwrite(save_path, canvas)
def __getitem__(self, index):
"""
Get an item from the dataset by index.
"""
one_lmk_path = self.imid_list[index]
        # extract the label from the FF++ path (handles both / and \ separators)
        sep = '/' if '/' in one_lmk_path else '\\'
        label = 1 if one_lmk_path.split(sep)[6] == 'manipulated_sequences' else 0
imid_fg, imid_bg = self.get_fg_bg(one_lmk_path)
        manipulate_img, boundary, bg_im = self.process_images(imid_fg, imid_bg, index)
        manipulate_img = self.post_proc(manipulate_img)
        bg_im = self.post_proc(bg_im)
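        # (H, W) -> (1, H, W): add a channel dimension so the boundary map
        # stacks like a single-channel image in collate_fn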
boundary = torch.from_numpy(boundary)
boundary = boundary.unsqueeze(2).permute(2, 0, 1)
# fake data
fake_data_tuple = (manipulate_img, boundary, 1)
# real data
        real_data_tuple = (bg_im, torch.zeros_like(boundary), label)
return fake_data_tuple, real_data_tuple
@staticmethod
def collate_fn(batch):
"""
Collates batches of data and shuffles the images.
"""
# Unzip the batch
fake_data, real_data = zip(*batch)
# Unzip the fake and real data
fake_images, fake_boundaries, fake_labels = zip(*fake_data)
real_images, real_boundaries, real_labels = zip(*real_data)
# Combine fake and real data
images = torch.stack(fake_images + real_images)
boundaries = torch.stack(fake_boundaries + real_boundaries)
labels = torch.tensor(fake_labels + real_labels)
# Combine images, boundaries, and labels into tuples
combined_data = list(zip(images, boundaries, labels))
# Shuffle the combined data
random.shuffle(combined_data)
# Unzip the shuffled data
images, boundaries, labels = zip(*combined_data)
# Create the data dictionary
data_dict = {
'image': torch.stack(images),
'label': torch.tensor(labels),
'mask': torch.stack(boundaries), # Assuming boundaries are your masks
'landmark': None # Add your landmark data if available
}
return data_dict
def __len__(self):
"""
Get the length of the dataset.
"""
return len(self.imid_list)
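
# Example usage (sketch): wiring the dataset into a DataLoader with the custom
# collate_fn that mixes fake and real samples into one shuffled batch. The
# batch_size shown is illustrative.
#
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True,
#                                        collate_fn=FFBlendDataset.collate_fn)
#   batch = next(iter(loader))  # keys: 'image', 'label', 'mask', 'landmark'
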
if __name__ == "__main__":
    # Illustrative config (assumption): the dataset needs at least 'resolution'
    # and 'rgb_dir'; 'lmdb'/'lmdb_dir' switch on the lmdb backend.
    config = {'resolution': 256, 'rgb_dir': 'datasets', 'lmdb': False}
    dataset = FFBlendDataset(config)
    print('dataset length: ', len(dataset))

    def tensor2bgr(im):
        # (C, H, W) in [-1, 1] -> (H, W, C) BGR uint8 for cv2.imwrite
        img = im.squeeze().cpu().numpy().transpose(1, 2, 0)
        img = ((img + 1) / 2 * 255).astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return img

    def tensor2gray(im):
        img = im.squeeze().cpu().numpy()
        img = (img * 255).astype(np.uint8)
        return img

    # __getitem__ returns (fake_data_tuple, real_data_tuple); dump the first
    # few fake samples and their boundary maps as a visual sanity check
    for i, (fake_data, real_data) in enumerate(dataset):
        if i > 20:
            break
        img, boundary, _ = fake_data
        cv2.imwrite('{}_whole.png'.format(i), tensor2bgr(img))
        cv2.imwrite('{}_boundary.png'.format(i), tensor2gray(boundary))