|
'''
|
|
# author: Zhiyuan Yan
|
|
# email: [email protected]
|
|
# date: 2024-01-26
|
|
|
|
The code is designed for the self-blended images method (SBI, CVPR 2022); this script precomputes face-parsing masks for the real images.
|
|
'''
|
|
|
|
import sys
|
|
sys.path.append('.')
|
|
|
|
import os
|
|
import cv2
|
|
import yaml
|
|
import random
|
|
import torch
|
|
import torch.nn as nn
|
|
from PIL import Image
|
|
import numpy as np
|
|
from copy import deepcopy
|
|
import albumentations as A
|
|
from training.dataset.abstract_dataset import DeepfakeAbstractBaseDataset
|
|
from training.dataset.sbi_api import SBI_API
|
|
from training.dataset.utils.bi_online_generation_yzy import random_get_hull
|
|
from training.dataset.SimSwap.test_one_image import self_blend
|
|
|
|
import warnings
|
|
warnings.filterwarnings('ignore')
|
|
|
|
|
|
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# Shared face-parsing model used by face_parsing_mask(); it is loaded once,
# at import time, onto the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Local snapshot of the Hugging Face "jonathandinu/face-parsing" checkpoint
# (per the hub cache directory name). The default is the original
# cluster-specific path; set the FACE_PARSER_PATH environment variable to use
# a different copy without editing this file.
FACE_PARSER_PATH = os.environ.get(
    "FACE_PARSER_PATH",
    "/Youtu_Pangu_Security_Public/youtu-pangu-public/zhiyuanyan/huggingface/hub/models--jonathandinu--face-parsing/snapshots/a2bf62f39dfd8f8856a3c19be8b0707a8d68abdd",
)

image_processor = SegformerImageProcessor.from_pretrained(FACE_PARSER_PATH)
face_parser = SegformerForSemanticSegmentation.from_pretrained(FACE_PARSER_PATH).to(device)
# The model is only used for inference; make evaluation mode explicit.
face_parser.eval()
|
|
|
|
|
|
def create_facial_mask(mask, with_neck=False):
    """Turn a per-pixel face-parsing label map into a binary facial mask.

    Args:
        mask: integer label map (one class id per pixel), e.g. the argmax
            output of the face parser. Any array-like is accepted.
        with_neck: if True, label 17 is also counted as facial.

    Returns:
        ``np.uint8`` array with the same shape as ``mask``: 255 where the
        label is one of the facial classes, 0 elsewhere.
    """
    # NOTE(review): ids presumably follow the CelebAMask-HQ scheme of the
    # face-parsing checkpoint (1=skin, 2=nose, 3=glasses, 4/5=eyes,
    # 6/7=brows, 10=mouth, 11/12=lips, 17=neck) -- verify against
    # face_parser.config.id2label.
    facial_labels = [1, 2, 3, 4, 5, 6, 7, 10, 11, 12]
    if with_neck:
        facial_labels.append(17)
    # np.asarray makes plain lists work too: the previous `mask == label`
    # comparison evaluated to a scalar False for lists, yielding an all-zero
    # mask.
    facial_mask = np.isin(np.asarray(mask), facial_labels)
    return facial_mask.astype(np.uint8) * 255
|
|
|
|
|
|
def face_parsing_mask(img1, with_neck=False):
    """Run the face parser on an image and return a binary facial-region mask.

    Args:
        img1: image array accepted by ``PIL.Image.fromarray`` (presumably an
            H x W x 3 uint8 RGB array as produced by ``load_rgb`` -- confirm
            for other callers).
        with_neck: forwarded to ``create_facial_mask``; also keep the neck label.

    Returns:
        ``np.uint8`` array of shape (H, W): 255 on facial pixels, 0 elsewhere.
    """
    pil_img = Image.fromarray(img1)
    # Pure inference: disable autograd so no computation graph is kept alive.
    with torch.no_grad():
        inputs = image_processor(images=pil_img, return_tensors="pt").to(device)
        logits = face_parser(**inputs).logits
        # Resize the logits to the input image size before the per-pixel
        # argmax. PIL's size is (W, H); interpolate expects (H, W).
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=pil_img.size[::-1],
            mode='bilinear',
            align_corners=False,
        )
        labels = upsampled_logits.argmax(dim=1)[0]
    return create_facial_mask(labels.cpu().numpy(), with_neck)
|
|
|
|
|
|
class YZYDataset(DeepfakeAbstractBaseDataset):
    """Preprocessing dataset that writes a face-parsing mask per real image.

    Only the real samples (label 0) of the underlying dataset are indexed.
    ``__getitem__`` does not produce a training sample. As a side effect it
    runs the face parser on the image and writes the resulting mask to a path
    derived from the image path by replacing ``'frames'`` with
    ``'parse_mask'``.
    """

    def __init__(self, config=None, mode='train'):
        super().__init__(config, mode)
        # (image_path, label) pairs of the real images only.
        self.real_imglist = [
            (img, label)
            for img, label in zip(self.image_list, self.label_list)
            if label == 0
        ]

    def __getitem__(self, index):
        """Compute and save the parse mask of the ``index``-th real image.

        Returns:
            None. The mask is written to disk, and ``collate_fn`` ignores the
            batch.

        Raises:
            ValueError: if the image path contains no ``'frames'``, so the
                derived mask path would be the image path itself.
            OSError: if OpenCV reports that the mask could not be written.
        """
        real_image_path, _ = self.real_imglist[index]

        parse_mask_path = real_image_path.replace('frames', 'parse_mask')
        # Without this guard, a path lacking 'frames' maps onto itself and the
        # source image would be overwritten by its mask.
        if parse_mask_path == real_image_path:
            raise ValueError(
                f"Cannot derive a parse-mask path from {real_image_path!r}: "
                "it does not contain 'frames'."
            )

        real_image = np.array(self.load_rgb(real_image_path))
        mask = face_parsing_mask(real_image, with_neck=False)

        out_dir = os.path.dirname(parse_mask_path)
        if out_dir:  # os.makedirs('') would raise for a bare filename
            os.makedirs(out_dir, exist_ok=True)
        # cv2.imwrite can signal failure by returning False instead of raising.
        if not cv2.imwrite(parse_mask_path, mask):
            raise OSError(f"Failed to write parse mask to {parse_mask_path!r}")

    @staticmethod
    def collate_fn(batch):
        """Ignore ``batch`` and return a placeholder dict of ``None`` values.

        ``__getitem__`` only writes masks to disk, so there is no data to
        collate; the keys mirror the usual sample dict layout.
        """
        return {
            'image': None,
            'label': None,
            'landmark': None,
            'mask': None,
        }

    def __len__(self):
        """Return the number of real images to process."""
        return len(self.real_imglist)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Build the dataset from the SBI detector config merged with the training
    # config, then iterate it once: each __getitem__ writes one parse mask.

    def _load_yaml(path):
        """Read and parse a YAML file with the safe loader."""
        with open(path, 'r') as handle:
            return yaml.safe_load(handle)

    config = _load_yaml('./training/config/detector/sbi.yaml')
    train_config = _load_yaml('./training/config/train_config.yaml')
    train_config['data_manner'] = 'lmdb'
    config['dataset_json_folder'] = '/Youtu_Pangu_Security_Public/youtu-pangu-public/zhiyuanyan/DeepfakeBenchv2/preprocessing/dataset_json'
    # Keys from the training config take precedence over those already set
    # above (including dataset_json_folder, if the training config has one).
    config.update(train_config)

    train_set = YZYDataset(config=config, mode='train')
    loader_kwargs = dict(
        dataset=train_set,
        batch_size=config['train_batchSize'],
        shuffle=True,
        num_workers=0,
        collate_fn=train_set.collate_fn,
    )
    train_data_loader = torch.utils.data.DataLoader(**loader_kwargs)

    from tqdm import tqdm

    for iteration, _batch in enumerate(tqdm(train_data_loader)):
        print(iteration)