# P-DFD/dataset/celeb_df.py
import numpy as np
from glob import glob
from os import listdir
from os.path import join
from dataset import AbstractDataset
SPLITS = ["train", "test"]
class CelebDF(AbstractDataset):
"""
Celeb-DF v2 Dataset proposed in "Celeb-DF: A Large-scale Challenging Dataset for DeepFake Forensics".
"""
def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
# pre-check
if cfg['split'] not in SPLITS:
raise ValueError(f"split should be one of {SPLITS}, but found {cfg['split']}.")
super(CelebDF, self).__init__(cfg, seed, transforms, transform, target_transform)
print(f"Loading data from 'Celeb-DF' of split '{cfg['split']}'"
f"\nPlease wait patiently...")
self.categories = ['original', 'fake']
self.root = cfg['root']
images_ids = self.__get_images_ids()
test_ids = self.__get_test_ids()
        # the training split is everything that is not listed in the official test split
        train_ids = [images_ids[0] - test_ids[0],
                     images_ids[1] - test_ids[1],
                     images_ids[2] - test_ids[2]]
self.images, self.targets = self.__get_images(
test_ids if cfg['split'] == "test" else train_ids, cfg['balance'])
        assert len(self.images) == len(self.targets), "The numbers of images and targets are not consistent."
print("Data from 'Celeb-DF' loaded.\n")
print(f"Dataset contains {len(self.images)} images.\n")
def __get_images_ids(self):
youtube_real = listdir(join(self.root, 'YouTube-real', 'images'))
celeb_real = listdir(join(self.root, 'Celeb-real', 'images'))
celeb_fake = listdir(join(self.root, 'Celeb-synthesis', 'images'))
return set(youtube_real), set(celeb_real), set(celeb_fake)
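    # Note: these id sets (and the test ids parsed below) are frame-folder names that
    # mirror the original video file stems, e.g. "00170" (YouTube-real), "id0_0000"
    # (Celeb-real), "id0_id16_0000" (Celeb-synthesis); this naming is assumed from
    # Celeb-DF v2 and not enforced by the code.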
def __get_test_ids(self):
youtube_real = set()
celeb_real = set()
celeb_fake = set()
with open(join(self.root, "List_of_testing_videos.txt"), "r", encoding="utf-8") as f:
contents = f.readlines()
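            # Each line of the official test list is assumed to look like
            #   "<label> <subset>/<video_id>.mp4"
            # e.g. "1 YouTube-real/00170.mp4"; only the subset name and the video id
            # (the file stem) are used below, the label is ignored.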
for line in contents:
name = line.split(" ")[-1]
number = name.split("/")[-1].split(".")[0]
if "YouTube-real" in name:
youtube_real.add(number)
elif "Celeb-real" in name:
celeb_real.add(number)
elif "Celeb-synthesis" in name:
celeb_fake.add(number)
else:
raise ValueError("'List_of_testing_videos.txt' file corrupted.")
return youtube_real, celeb_real, celeb_fake
def __get_images(self, ids, balance=False):
real = list()
fake = list()
        # YouTube-real
        for video_id in ids[0]:
            real.extend(glob(join(self.root, 'YouTube-real', 'images', video_id, '*.png')))
        # Celeb-real
        for video_id in ids[1]:
            real.extend(glob(join(self.root, 'Celeb-real', 'images', video_id, '*.png')))
        # Celeb-synthesis
        for video_id in ids[2]:
            fake.extend(glob(join(self.root, 'Celeb-synthesis', 'images', video_id, '*.png')))
print(f"Real: {len(real)}, Fake: {len(fake)}")
        if balance:
            # balance the classes by randomly subsampling the (more numerous) fake frames
            # down to the number of real frames; replace=False assumes len(fake) >= len(real)
            fake = np.random.choice(fake, size=len(real), replace=False).tolist()
            print(f"After Balance | Real: {len(real)}, Fake: {len(fake)}")
        # label convention: 0 = real/original, 1 = fake
        real_tgt = [0] * len(real)
        fake_tgt = [1] * len(fake)
        return [*real, *fake], [*real_tgt, *fake_tgt]
if __name__ == '__main__':
import yaml
config_path = "../config/dataset/celeb_df.yml"
with open(config_path) as config_file:
config = yaml.load(config_file, Loader=yaml.FullLoader)
config = config["train_cfg"]
# config = config["test_cfg"]
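    # A minimal sketch of what ../config/dataset/celeb_df.yml might contain, assuming only
    # the keys read in this file ('root', 'split', 'balance'); the actual config may define
    # additional options consumed by AbstractDataset:
    #
    # train_cfg:
    #   root: /path/to/Celeb-DF-v2
    #   split: train
    #   balance: true
    # test_cfg:
    #   root: /path/to/Celeb-DF-v2
    #   split: test
    #   balance: false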
def run_dataset():
dataset = CelebDF(config)
print(f"dataset: {len(dataset)}")
        for i, (path, target) in enumerate(dataset):
            print(f"path: {path}, target: {target}")
if i >= 9:
break
def run_dataloader(display_samples=False):
from torch.utils import data
import matplotlib.pyplot as plt
dataset = CelebDF(config)
dataloader = data.DataLoader(dataset, batch_size=8, shuffle=True)
print(f"dataset: {len(dataset)}")
        for i, (paths, targets) in enumerate(dataloader):
            # load_item (provided by AbstractDataset) reads the frame images at the given
            # paths and returns them as a tensor batch
            image = dataloader.dataset.load_item(paths)
print(f"image: {image.shape}, target: {targets}")
if display_samples:
plt.figure()
img = image[0].permute([1, 2, 0]).numpy()
plt.imshow(img)
# plt.savefig("./img_" + str(i) + ".png")
plt.show()
if i >= 9:
break
###########################
# run the functions below #
###########################
# run_dataset()
run_dataloader(False)