Spaces:
Sleeping
Sleeping
import numpy as np | |
from glob import glob | |
from os import listdir | |
from os.path import join | |
from dataset import AbstractDataset | |
SPLITS = ["train", "test"] | |
class CelebDF(AbstractDataset): | |
""" | |
Celeb-DF v2 Dataset proposed in "Celeb-DF: A Large-scale Challenging Dataset for DeepFake Forensics". | |
""" | |
def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None): | |
# pre-check | |
if cfg['split'] not in SPLITS: | |
raise ValueError(f"split should be one of {SPLITS}, but found {cfg['split']}.") | |
super(CelebDF, self).__init__(cfg, seed, transforms, transform, target_transform) | |
print(f"Loading data from 'Celeb-DF' of split '{cfg['split']}'" | |
f"\nPlease wait patiently...") | |
self.categories = ['original', 'fake'] | |
self.root = cfg['root'] | |
images_ids = self.__get_images_ids() | |
test_ids = self.__get_test_ids() | |
train_ids = [images_ids[0] - test_ids[0], | |
images_ids[1] - test_ids[1], | |
images_ids[2] - test_ids[2]] | |
self.images, self.targets = self.__get_images( | |
test_ids if cfg['split'] == "test" else train_ids, cfg['balance']) | |
assert len(self.images) == len(self.targets), "The number of images and targets not consistent." | |
print("Data from 'Celeb-DF' loaded.\n") | |
print(f"Dataset contains {len(self.images)} images.\n") | |
def __get_images_ids(self): | |
youtube_real = listdir(join(self.root, 'YouTube-real', 'images')) | |
celeb_real = listdir(join(self.root, 'Celeb-real', 'images')) | |
celeb_fake = listdir(join(self.root, 'Celeb-synthesis', 'images')) | |
return set(youtube_real), set(celeb_real), set(celeb_fake) | |
def __get_test_ids(self): | |
youtube_real = set() | |
celeb_real = set() | |
celeb_fake = set() | |
with open(join(self.root, "List_of_testing_videos.txt"), "r", encoding="utf-8") as f: | |
contents = f.readlines() | |
for line in contents: | |
name = line.split(" ")[-1] | |
number = name.split("/")[-1].split(".")[0] | |
if "YouTube-real" in name: | |
youtube_real.add(number) | |
elif "Celeb-real" in name: | |
celeb_real.add(number) | |
elif "Celeb-synthesis" in name: | |
celeb_fake.add(number) | |
else: | |
raise ValueError("'List_of_testing_videos.txt' file corrupted.") | |
return youtube_real, celeb_real, celeb_fake | |
def __get_images(self, ids, balance=False): | |
real = list() | |
fake = list() | |
# YouTube-real | |
for _ in ids[0]: | |
real.extend(glob(join(self.root, 'YouTube-real', 'images', _, '*.png'))) | |
# Celeb-real | |
for _ in ids[1]: | |
real.extend(glob(join(self.root, 'Celeb-real', 'images', _, '*.png'))) | |
# Celeb-synthesis | |
for _ in ids[2]: | |
fake.extend(glob(join(self.root, 'Celeb-synthesis', 'images', _, '*.png'))) | |
print(f"Real: {len(real)}, Fake: {len(fake)}") | |
if balance: | |
fake = np.random.choice(fake, size=len(real), replace=False) | |
print(f"After Balance | Real: {len(real)}, Fake: {len(fake)}") | |
real_tgt = [0] * len(real) | |
fake_tgt = [1] * len(fake) | |
return [*real, *fake], [*real_tgt, *fake_tgt] | |
if __name__ == '__main__': | |
import yaml | |
config_path = "../config/dataset/celeb_df.yml" | |
with open(config_path) as config_file: | |
config = yaml.load(config_file, Loader=yaml.FullLoader) | |
config = config["train_cfg"] | |
# config = config["test_cfg"] | |
def run_dataset(): | |
dataset = CelebDF(config) | |
print(f"dataset: {len(dataset)}") | |
for i, _ in enumerate(dataset): | |
path, target = _ | |
print(f"path: {path}, target: {target}") | |
if i >= 9: | |
break | |
def run_dataloader(display_samples=False): | |
from torch.utils import data | |
import matplotlib.pyplot as plt | |
dataset = CelebDF(config) | |
dataloader = data.DataLoader(dataset, batch_size=8, shuffle=True) | |
print(f"dataset: {len(dataset)}") | |
for i, _ in enumerate(dataloader): | |
path, targets = _ | |
image = dataloader.dataset.load_item(path) | |
print(f"image: {image.shape}, target: {targets}") | |
if display_samples: | |
plt.figure() | |
img = image[0].permute([1, 2, 0]).numpy() | |
plt.imshow(img) | |
# plt.savefig("./img_" + str(i) + ".png") | |
plt.show() | |
if i >= 9: | |
break | |
########################### | |
# run the functions below # | |
########################### | |
# run_dataset() | |
run_dataloader(False) | |