'''
# author: Zhiyuan Yan
# email: [email protected]
# date: 2024-01-26
The code is designed for the self-blended images (SBI, CVPR 2022) pipeline: it precomputes face-parsing masks for real frames.
'''
import sys
sys.path.append('.')
import os
import cv2
import yaml
import random
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from copy import deepcopy
import albumentations as A
from training.dataset.abstract_dataset import DeepfakeAbstractBaseDataset
from training.dataset.sbi_api import SBI_API
from training.dataset.utils.bi_online_generation_yzy import random_get_hull
from training.dataset.SimSwap.test_one_image import self_blend
import warnings
warnings.filterwarnings('ignore')
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_processor = SegformerImageProcessor.from_pretrained("/Youtu_Pangu_Security_Public/youtu-pangu-public/zhiyuanyan/huggingface/hub/models--jonathandinu--face-parsing/snapshots/a2bf62f39dfd8f8856a3c19be8b0707a8d68abdd")
face_parser = SegformerForSemanticSegmentation.from_pretrained("/Youtu_Pangu_Security_Public/youtu-pangu-public/zhiyuanyan/huggingface/hub/models--jonathandinu--face-parsing/snapshots/a2bf62f39dfd8f8856a3c19be8b0707a8d68abdd").to(device)
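# The snapshot paths above point at a local Hugging Face cache of the
# "jonathandinu/face-parsing" SegFormer checkpoint (the same weights should load
# via that hub id if the local cache is not available).
face_parser.eval()  # inference only: keep dropout/normalization layers deterministic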

def create_facial_mask(mask, with_neck=False):
    # Face-parsing label ids treated as facial regions; label 17 (the neck) is optional.
    facial_labels = [1, 2, 3, 4, 5, 6, 7, 10, 11, 12]
    if with_neck:
        facial_labels += [17]
    facial_mask = np.zeros_like(mask, dtype=bool)
    for label in facial_labels:
        facial_mask |= (mask == label)
    return facial_mask.astype(np.uint8) * 255

def face_parsing_mask(img1, with_neck=False):
    # Run the SegFormer face parser on a single RGB image (H x W x 3 uint8 array).
    img1 = Image.fromarray(img1)
    inputs = image_processor(images=img1, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = face_parser(**inputs)
    logits = outputs.logits  # shape (batch_size, num_labels, ~height/4, ~width/4)

    # Resize the logits to match the input image dimensions.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=img1.size[::-1],  # PIL size is (W, H); reversed gives (H, W)
        mode='bilinear',
        align_corners=False,
    )
    labels = upsampled_logits.argmax(dim=1)[0]
    mask = labels.cpu().numpy()
    mask = create_facial_mask(mask, with_neck)
    return mask
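
# Minimal standalone usage sketch (the file names here are only illustrative):
#   img = np.array(Image.open('frame_0000.png').convert('RGB'))
#   cv2.imwrite('frame_0000_mask.png', face_parsing_mask(img, with_neck=False))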

class YZYDataset(DeepfakeAbstractBaseDataset):
    def __init__(self, config=None, mode='train'):
        super().__init__(config, mode)
        # Keep only the real images (label == 0); this dataset exists solely to
        # generate and cache face-parsing masks for them.
        self.real_imglist = [(img, label) for img, label in zip(self.image_list, self.label_list) if label == 0]

    def __getitem__(self, index):
        # Get the real image path and label
        real_image_path, real_label = self.real_imglist[index]
        # real_image_path = real_image_path.replace('/Youtu_Pangu_Security_Public/', '/Youtu_Pangu_Security/public/')

        # Load the real image
        real_image = self.load_rgb(real_image_path)
        real_image = np.array(real_image)  # Convert to numpy array

        # Face parsing: compute the facial mask and cache it alongside the frames
        mask = face_parsing_mask(real_image, with_neck=False)
        parse_mask_path = real_image_path.replace('frames', 'parse_mask')
        os.makedirs(os.path.dirname(parse_mask_path), exist_ok=True)
        cv2.imwrite(parse_mask_path, mask)

        # # SRI generation
        # sri_image = self_blend(real_image)
        # sri_path = real_image_path.replace('frames', 'sri_frames')
        # os.makedirs(os.path.dirname(sri_path), exist_ok=True)
        # cv2.imwrite(sri_path, sri_image)

    @staticmethod
    def collate_fn(batch):
        # Mask generation happens as a side effect in __getitem__, so the per-sample
        # outputs are intentionally discarded here.
        data_dict = {
            'image': None,
            'label': None,
            'landmark': None,
            'mask': None,
        }
        return data_dict

    def __len__(self):
        return len(self.real_imglist)

if __name__ == '__main__':
    with open('./training/config/detector/sbi.yaml', 'r') as f:
        config = yaml.safe_load(f)
    with open('./training/config/train_config.yaml', 'r') as f:
        config2 = yaml.safe_load(f)
    config2['data_manner'] = 'lmdb'
    config['dataset_json_folder'] = '/Youtu_Pangu_Security_Public/youtu-pangu-public/zhiyuanyan/DeepfakeBenchv2/preprocessing/dataset_json'
    config.update(config2)

    train_set = YZYDataset(config=config, mode='train')
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=config['train_batchSize'],
        shuffle=True,
        num_workers=0,
        collate_fn=train_set.collate_fn,
    )

    # Iterate once over the real images; each __getitem__ call writes a parse mask to disk.
    from tqdm import tqdm
    for iteration, batch in enumerate(tqdm(train_data_loader)):
        print(iteration)