# P-DFD / dataset / dfdc.py
# Author: mrneuralnet — "Initial commit" (982865f)
import json
from glob import glob
from os.path import join
from dataset import AbstractDataset
# Dataset splits accepted in cfg['split']; note only "train" and "test"
# currently have loaders in DFDC below.
SPLIT = ["train", "val", "test"]
# Maps the label strings found in metadata.json to integer class ids
# (0 = real, 1 = fake), matching the order of DFDC.categories.
LABEL_MAP = {"REAL": 0, "FAKE": 1}
class DFDC(AbstractDataset):
    """
    Deepfake Detection Challenge organized by Facebook.

    Expected on-disk layout under ``cfg['root']``:
      - test:  ``test/labels.csv`` plus frames in ``test/images/<video_id>/*.png``
      - train: ``dfdc_train_part_*/metadata.json`` plus frames in
               ``dfdc_train_part_*/images/<video_id>/*.png``

    Populates ``self.images`` (frame paths) and ``self.targets`` (int labels,
    0 = real / 1 = fake) provided by the AbstractDataset base class.
    """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        # pre-check: fail fast on an unknown split name
        if cfg['split'] not in SPLIT:
            raise ValueError(f"split should be one of {SPLIT}, but found {cfg['split']}.")
        super(DFDC, self).__init__(cfg, seed, transforms, transform, target_transform)
        print(f"Loading data from 'DFDC' of split '{cfg['split']}'"
              f"\nPlease wait patiently...")
        self.categories = ['original', 'fake']
        self.root = cfg['root']
        # per-class frame counters, filled by the loaders below
        self.num_real = 0
        self.num_fake = 0
        if self.split == "test":
            self.__load_test_data()
        elif self.split == "train":
            self.__load_train_data()
        else:
            # "val" passes the pre-check but has no loader; previously this
            # produced a silently empty dataset — make that state explicit.
            print(f"WARNING: no loader implemented for split '{self.split}'; "
                  f"dataset will be empty.")
        assert len(self.images) == len(self.targets), "Length of images and targets not the same!"
        print(f"Data from 'DFDC' loaded.")
        print(f"Real: {self.num_real}, Fake: {self.num_fake}.")
        print(f"Dataset contains {len(self.images)} images\n")

    def __load_test_data(self):
        """Read test/labels.csv and collect the frames of every listed video.

        Each data row is expected to look like ``<video_id>.mp4,<label>`` where
        label is 0 (real) or 1 (fake); other rows (e.g. the CSV header) are
        skipped.
        """
        label_path = join(self.root, "test", "labels.csv")
        with open(label_path, encoding="utf-8") as file:
            content = file.readlines()
        for line in content:
            # skip anything that is not a video entry (e.g. the header row)
            if ".mp4" not in line:
                continue
            key = line.split(".")[0]
            label = int(line.split(",")[1].strip())
            imgs = glob(join(self.root, "test", "images", key, "*.png"))
            num = len(imgs)
            self.images.extend(imgs)
            self.targets.extend([label] * num)
            if label == 0:
                self.num_real += num
            elif label == 1:
                self.num_fake += num

    def __load_train_data(self):
        """Collect frames from every ``dfdc_train_part_*`` fold.

        Folds without a ``metadata.json`` are skipped; labels come from the
        metadata's ``label`` field via LABEL_MAP.
        """
        train_folds = glob(join(self.root, "dfdc_train_part_*"))
        for fold in train_folds:
            metadata_path = join(fold, "metadata.json")
            # narrow try: only the open/parse can legitimately raise
            # FileNotFoundError we want to tolerate
            try:
                with open(metadata_path, "r", encoding="utf-8") as file:
                    # json.load reads the whole file, so both single-line and
                    # pretty-printed metadata work (readline() assumed one line)
                    metadata = json.load(file)
            except FileNotFoundError:
                continue
            fold_imgs = list()
            fold_tgts = list()
            for k, v in metadata.items():
                index = k.split(".")[0]
                label = LABEL_MAP[v["label"]]
                imgs = glob(join(fold, "images", index, "*.png"))
                fold_imgs.extend(imgs)
                fold_tgts.extend([label] * len(imgs))
                if label == 0:
                    self.num_real += len(imgs)
                elif label == 1:
                    self.num_fake += len(imgs)
            self.images.extend(fold_imgs)
            self.targets.extend(fold_tgts)
if __name__ == '__main__':
    import yaml

    # Load the training config; swap to "test_cfg" to exercise the test split.
    config_path = "../config/dataset/dfdc.yml"
    with open(config_path) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = config["train_cfg"]
    # config = config["test_cfg"]

    def run_dataset():
        # Smoke-test: build the dataset and print the first ten samples.
        ds = DFDC(config)
        print(f"dataset: {len(ds)}")
        for idx, (path, target) in enumerate(ds):
            print(f"path: {path}, target: {target}")
            if idx >= 9:
                break

    def run_dataloader(display_samples=False):
        # Smoke-test: iterate ten shuffled batches through a DataLoader,
        # optionally showing the first image of each batch.
        from torch.utils import data
        import matplotlib.pyplot as plt
        ds = DFDC(config)
        loader = data.DataLoader(ds, batch_size=8, shuffle=True)
        print(f"dataset: {len(ds)}")
        for idx, (path, targets) in enumerate(loader):
            image = loader.dataset.load_item(path)
            print(f"image: {image.shape}, target: {targets}")
            if display_samples:
                plt.figure()
                img = image[0].permute([1, 2, 0]).numpy()
                plt.imshow(img)
                # plt.savefig("./img_" + str(idx) + ".png")
                plt.show()
            if idx >= 9:
                break

    ###########################
    # run the functions below #
    ###########################
    # run_dataset()
    run_dataloader(False)