'''
# author: Zhiyuan Yan
# email: [email protected]
# date: 2024-01-26
The code is designed for the self-blended images method (SBI, CVPR 2022).
'''
import sys
sys.path.append('.')
import cv2
import yaml
import torch
import numpy as np
from copy import deepcopy
import albumentations as A
from training.dataset.albu import IsotropicResize
from training.dataset.abstract_dataset import DeepfakeAbstractBaseDataset
from training.dataset.sbi_api import SBI_API

class SBIDataset(DeepfakeAbstractBaseDataset):
    def __init__(self, config=None, mode='train'):
        super().__init__(config, mode)
        # Keep only the real images (label == 0); SBI synthesizes its fake
        # sample from the real image itself, so only reals are indexed
        self.real_imglist = [(img, label) for img, label in zip(self.image_list, self.label_list) if label == 0]

        # Init SBI
        self.sbi = SBI_API(phase=mode, image_size=config['resolution'])
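        # NOTE (assumption): SBI_API implements the self-blending synthesis step;
        # `phase` presumably switches train/test-time behavior and `image_size`
        # sets the working resolution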

        # Init data augmentation method
        self.transform = self.init_data_aug_method()

    def __getitem__(self, index):
        # Get the real image path and label
        real_image_path, real_label = self.real_imglist[index]

        # Get the landmark path for the real image
        real_landmark_path = real_image_path.replace('frames', 'landmarks').replace('.png', '.npy')
        landmark = self.load_landmark(real_landmark_path).astype(np.int32)

        # Load the real image
        real_image = self.load_rgb(real_image_path)
        real_image = np.array(real_image)  # Convert to numpy array

        # Generate the corresponding SBI sample
        fake_image, real_image = self.sbi(real_image, landmark)
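        # `self.sbi` is expected to return a (fake, real) pair; the fake may come
        # back as None when blending cannot be performed, in which case we fall
        # back to a duplicated real sample (labeled 0) below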
        if fake_image is None:
            fake_image = deepcopy(real_image)
            fake_label = 0
        else:
            fake_label = 1

        # To tensor and normalize for fake and real images
        fake_image_trans = self.normalize(self.to_tensor(fake_image))
        real_image_trans = self.normalize(self.to_tensor(real_image))

        return {"fake": (fake_image_trans, fake_label),
                "real": (real_image_trans, real_label)}

    def __len__(self):
        return len(self.real_imglist)

    @staticmethod
    def collate_fn(batch):
        """
        Collate a batch of data points.

        Args:
            batch (list): A list of dicts, each holding "fake" and "real"
                (image tensor, label) pairs as produced by __getitem__.

        Returns:
            A dict containing the stacked image and label tensors; the landmark
            and mask entries are kept as None for interface compatibility.
        """
        # Separate the image and label tensors for fake and real data
        fake_images, fake_labels = zip(*[data["fake"] for data in batch])
        real_images, real_labels = zip(*[data["real"] for data in batch])

        # Stack the image and label tensors for fake and real data
        fake_images = torch.stack(fake_images, dim=0)
        fake_labels = torch.LongTensor(fake_labels)
        real_images = torch.stack(real_images, dim=0)
        real_labels = torch.LongTensor(real_labels)

        # Combine the fake and real tensors
        images = torch.cat([real_images, fake_images], dim=0)
        labels = torch.cat([real_labels, fake_labels], dim=0)
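        # After concatenation the batch holds 2*B samples: indices [0, B) are the
        # real images (label 0) and [B, 2B) their SBI counterparts (label 1, or 0
        # for the fall-back duplicates created in __getitem__)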
        data_dict = {
            'image': images,
            'label': labels,
            'landmark': None,
            'mask': None,
        }
        return data_dict

    def init_data_aug_method(self):
        trans = A.Compose([
            A.HorizontalFlip(p=self.config['data_aug']['flip_prob']),
            A.Rotate(limit=self.config['data_aug']['rotate_limit'], p=self.config['data_aug']['rotate_prob']),
            A.GaussianBlur(blur_limit=self.config['data_aug']['blur_limit'], p=self.config['data_aug']['blur_prob']),
            A.OneOf([
                IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
                IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
                IsotropicResize(max_side=self.config['resolution'], interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
            ], p=0 if self.config['with_landmark'] else 1),
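            # Presumably skipped when landmarks are used: resizing the image
            # without transforming the landmarks would break their alignment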
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=self.config['data_aug']['brightness_limit'], contrast_limit=self.config['data_aug']['contrast_limit']),
                A.FancyPCA(),
                A.HueSaturationValue()
            ], p=0.5),
            A.ImageCompression(quality_lower=self.config['data_aug']['quality_lower'], quality_upper=self.config['data_aug']['quality_upper'], p=0.5)
        ],
            additional_targets={'real': 'sbi'},
        )
        return trans
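

# A minimal, hypothetical config sketch listing the keys this dataset reads
# directly; the actual values live in sbi.yaml, and DeepfakeAbstractBaseDataset
# likely requires further keys (e.g., dataset paths) not shown here.
EXAMPLE_CONFIG = {
    'resolution': 256,       # working image size for SBI_API and resizing
    'with_landmark': True,   # when True, disables the isotropic-resize branch
    'train_batchSize': 8,    # per-loader batch size (doubled by collate_fn)
    'data_aug': {
        'flip_prob': 0.5,
        'rotate_limit': 10,
        'rotate_prob': 0.5,
        'blur_limit': 7,
        'blur_prob': 0.5,
        'brightness_limit': 0.1,
        'contrast_limit': 0.1,
        'quality_lower': 40,
        'quality_upper': 100,
    },
}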


if __name__ == '__main__':
    with open('/data/home/zhiyuanyan/DeepfakeBench/training/config/detector/sbi.yaml', 'r') as f:
        config = yaml.safe_load(f)
    train_set = SBIDataset(config=config, mode='train')
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=config['train_batchSize'],
        shuffle=True,
        num_workers=0,
        collate_fn=train_set.collate_fn,
    )

    from tqdm import tqdm
    for iteration, batch in enumerate(tqdm(train_data_loader)):
        print(iteration)
        if iteration > 10:
            break
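
    # Sketch of consuming a collated batch in a training step; `model` and
    # `criterion` are hypothetical and not part of this file:
    #   images, labels = batch['image'], batch['label']  # (2*B, 3, H, W), (2*B,)
    #   loss = criterion(model(images), labels)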