import json
from glob import glob
from os.path import join

from dataset import AbstractDataset

SPLIT = ["train", "val", "test"]
LABEL_MAP = {"REAL": 0, "FAKE": 1}


class DFDC(AbstractDataset):
    """ Deepfake Detection Challenge organized by Facebook """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        """Collect image paths and integer targets for the requested split.

        Args:
            cfg: mapping with at least 'split' (one of SPLIT) and 'root'
                (dataset root directory).
            seed: RNG seed forwarded to AbstractDataset.
            transforms, transform, target_transform: forwarded to
                AbstractDataset unchanged.

        Raises:
            ValueError: if cfg['split'] is not in SPLIT.
            NotImplementedError: for the 'val' split, which has no loader.
        """
        # pre-check
        if cfg['split'] not in SPLIT:
            raise ValueError(f"split should be one of {SPLIT}, but found {cfg['split']}.")
        super().__init__(cfg, seed, transforms, transform, target_transform)
        print(f"Loading data from 'DFDC' of split '{cfg['split']}'"
              f"\nPlease wait patiently...")
        self.categories = ['original', 'fake']
        self.root = cfg['root']
        # Per-class counters, updated by the split loaders below.
        self.num_real = 0
        self.num_fake = 0
        if self.split == "test":
            self.__load_test_data()
        elif self.split == "train":
            self.__load_train_data()
        else:
            # BUG FIX: 'val' passes the pre-check above but previously fell
            # through both branches, silently producing an empty (yet
            # assertion-passing) dataset. Fail loudly instead.
            raise NotImplementedError(f"No loader implemented for split '{self.split}'.")
        assert len(self.images) == len(self.targets), "Length of images and targets not the same!"
        print(f"Data from 'DFDC' loaded.")
        print(f"Real: {self.num_real}, Fake: {self.num_fake}.")
        print(f"Dataset contains {len(self.images)} images\n")

    def __load_test_data(self):
        """Parse test/labels.csv and collect the frames of every test video.

        Each relevant CSV row looks like '<video>.mp4,<label>' where label is
        0 (real) or 1 (fake); frames live under test/images/<video>/*.png.
        Rows without '.mp4' (e.g. the header) are skipped.
        """
        label_path = join(self.root, "test", "labels.csv")
        with open(label_path, encoding="utf-8") as file:
            content = file.readlines()
        for line in content:
            if ".mp4" not in line:
                continue  # header or malformed row
            key = line.split(".")[0]
            label = int(line.split(",")[1].strip())
            imgs = glob(join(self.root, "test", "images", key, "*.png"))
            num = len(imgs)
            self.images.extend(imgs)
            self.targets.extend([label] * num)
            if label == 0:
                self.num_real += num
            elif label == 1:
                self.num_fake += num

    def __load_train_data(self):
        """Scan every dfdc_train_part_* fold and collect labeled frames.

        Each fold carries a metadata.json mapping '<video>.mp4' to a record
        whose 'label' is 'REAL' or 'FAKE' (mapped via LABEL_MAP); frames live
        under <fold>/images/<video>/*.png. Folds whose metadata.json is
        missing are skipped.
        """
        train_folds = glob(join(self.root, "dfdc_train_part_*"))
        for fold in train_folds:
            fold_imgs = []
            fold_tgts = []
            metadata_path = join(fold, "metadata.json")
            # Narrow try body: only the open/parse can raise FileNotFoundError.
            try:
                with open(metadata_path, "r", encoding="utf-8") as file:
                    # BUG FIX: json.loads(file.readline()) broke on
                    # pretty-printed (multi-line) metadata.json; json.load
                    # reads the whole document in either layout.
                    metadata = json.load(file)
            except FileNotFoundError:
                continue
            for k, v in metadata.items():
                index = k.split(".")[0]
                label = LABEL_MAP[v["label"]]
                imgs = glob(join(fold, "images", index, "*.png"))
                fold_imgs.extend(imgs)
                fold_tgts.extend([label] * len(imgs))
                if label == 0:
                    self.num_real += len(imgs)
                elif label == 1:
                    self.num_fake += len(imgs)
            self.images.extend(fold_imgs)
            self.targets.extend(fold_tgts)


if __name__ == '__main__':
    import yaml

    config_path = "../config/dataset/dfdc.yml"
    with open(config_path) as config_file:
        # NOTE(review): yaml.load with FullLoader is kept for behavior parity;
        # prefer yaml.safe_load if the config is ever untrusted.
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = config["train_cfg"]
    # config = config["test_cfg"]

    def run_dataset():
        """Instantiate the dataset and print the first ten (path, target) pairs."""
        dataset = DFDC(config)
        print(f"dataset: {len(dataset)}")
        for i, item in enumerate(dataset):
            path, target = item
            print(f"path: {path}, target: {target}")
            if i >= 9:
                break

    def run_dataloader(display_samples=False):
        """Iterate a DataLoader over the dataset, optionally plotting samples."""
        from torch.utils import data
        import matplotlib.pyplot as plt
        dataset = DFDC(config)
        dataloader = data.DataLoader(dataset, batch_size=8, shuffle=True)
        print(f"dataset: {len(dataset)}")
        for i, batch in enumerate(dataloader):
            path, targets = batch
            image = dataloader.dataset.load_item(path)
            print(f"image: {image.shape}, target: {targets}")
            if display_samples:
                plt.figure()
                img = image[0].permute([1, 2, 0]).numpy()
                plt.imshow(img)
                # plt.savefig("./img_" + str(i) + ".png")
                plt.show()
            if i >= 9:
                break

    ###########################
    # run the functions below #
    ###########################

    # run_dataset()
    run_dataloader(False)