'''
# author: Zhiyuan Yan
# email: [email protected]
# date: 2024-01-26

The code is designed for the self-blending method (SBI, CVPR 2022).
'''
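# SBI ("Detecting Deepfakes with Self-Blended Images") trains a detector on
# pseudo-fakes made by blending a transformed copy of a real face back onto
# itself, so this dataset only ever reads real images.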

import sys
sys.path.append('.')

import cv2
import yaml
import torch
import numpy as np
from copy import deepcopy
import albumentations as A
from training.dataset.albu import IsotropicResize
from training.dataset.abstract_dataset import DeepfakeAbstractBaseDataset
from training.dataset.sbi_api import SBI_API


class SBIDataset(DeepfakeAbstractBaseDataset):
    def __init__(self, config=None, mode='train'):
        super().__init__(config, mode)
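
        # Keep only the real samples (label == 0): SBI synthesizes its
        # pseudo-fakes on the fly, so no manipulated images are needed.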
        self.real_imglist = [(img, label) for img, label in zip(self.image_list, self.label_list) if label == 0]
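
        # SBI_API turns one aligned face plus its landmarks into a
        # (pseudo-fake, real) training pair.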
        self.sbi = SBI_API(phase=mode, image_size=config['resolution'])
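
        # Note: this albumentations pipeline is built here but is not applied
        # in this class's __getitem__; SBI_API handles augmentation itself.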
        self.transform = self.init_data_aug_method()

    def __getitem__(self, index):
        real_image_path, real_label = self.real_imglist[index]
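
        # Landmarks live alongside the frames: swap 'frames' for 'landmarks'
        # in the path and '.png' for '.npy'.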
        real_landmark_path = real_image_path.replace('frames', 'landmarks').replace('.png', '.npy')
        landmark = self.load_landmark(real_landmark_path).astype(np.int32)
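
        # Load the real image as an RGB array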
        real_image = self.load_rgb(real_image_path)
        real_image = np.array(real_image)
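
        # Synthesize the self-blended pseudo-fake; when blending fails,
        # SBI_API returns None for the fake, and we fall back to a copy of
        # the real image labeled as real (0).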
        fake_image, real_image = self.sbi(real_image, landmark)
        if fake_image is None:
            fake_image = deepcopy(real_image)
            fake_label = 0
        else:
            fake_label = 1
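
        # Convert both images to tensors and normalize them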
        fake_image_trans = self.normalize(self.to_tensor(fake_image))
        real_image_trans = self.normalize(self.to_tensor(real_image))

        return {"fake": (fake_image_trans, fake_label),
                "real": (real_image_trans, real_label)}

    def __len__(self):
        return len(self.real_imglist)

    @staticmethod
    def collate_fn(batch):
        """
        Collate a batch of data points.

        Args:
            batch (list): A list of dicts, each holding the 'fake' and 'real'
                (image tensor, label) pairs from __getitem__.

        Returns:
            A dict containing the stacked image tensor, the label tensor,
            the landmark tensor, and the mask tensor.
        """
        fake_images, fake_labels = zip(*[data["fake"] for data in batch])
        real_images, real_labels = zip(*[data["real"] for data in batch])
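
        # Stack per-sample tensors into batch tensors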
        fake_images = torch.stack(fake_images, dim=0)
        fake_labels = torch.LongTensor(fake_labels)
        real_images = torch.stack(real_images, dim=0)
        real_labels = torch.LongTensor(real_labels)
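
        # Concatenate real and fake along the batch dimension (real first),
        # yielding a doubled batch of paired samples.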
        images = torch.cat([real_images, fake_images], dim=0)
        labels = torch.cat([real_labels, fake_labels], dim=0)

        data_dict = {
            'image': images,
            'label': labels,
            'landmark': None,
            'mask': None,
        }
        return data_dict

    def init_data_aug_method(self):
        trans = A.Compose([
            A.HorizontalFlip(p=self.config['data_aug']['flip_prob']),
            A.Rotate(limit=self.config['data_aug']['rotate_limit'], p=self.config['data_aug']['rotate_prob']),
            A.GaussianBlur(blur_limit=self.config['data_aug']['blur_limit'], p=self.config['data_aug']['blur_prob']),
            A.OneOf([
                IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
                IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
                IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
            ], p=0 if self.config['with_landmark'] else 1),
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=self.config['data_aug']['brightness_limit'], contrast_limit=self.config['data_aug']['contrast_limit']),
                A.FancyPCA(),
                A.HueSaturationValue()
            ], p=0.5),
            A.ImageCompression(quality_lower=self.config['data_aug']['quality_lower'], quality_upper=self.config['data_aug']['quality_upper'], p=0.5)
        ],
            # 'image' is the valid albumentations target type here; the original
            # value 'sbi' is not a recognized target and would fail if the
            # pipeline were applied to a 'real' argument.
            additional_targets={'real': 'image'},
        )
        return trans
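
# Quick smoke test: build the dataset from the SBI config and iterate a few
# batches to check that loading, blending, and collation run end to end.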
if __name__ == '__main__':
    with open('/data/home/zhiyuanyan/DeepfakeBench/training/config/detector/sbi.yaml', 'r') as f:
        config = yaml.safe_load(f)
    train_set = SBIDataset(config=config, mode='train')
    train_data_loader = \
        torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=config['train_batchSize'],
            shuffle=True,
            num_workers=0,
            collate_fn=train_set.collate_fn,
        )
    from tqdm import tqdm
    for iteration, batch in enumerate(tqdm(train_data_loader)):
        print(iteration)
        if iteration > 10:
            break