import torch
import numpy as np
from os.path import join
from dataset import AbstractDataset

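# Dataset splits supported by this loader.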
SPLITS = ["train", "test"]


class WildDeepfake(AbstractDataset):
    """
    Wild Deepfake Dataset proposed in "WildDeepfake: A Challenging Real-World Dataset for Deepfake Detection"
    """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        # pre-check
        if cfg['split'] not in SPLITS:
            raise ValueError(f"split should be one of {SPLITS}, but found {cfg['split']}.")
        super(WildDeepfake, self).__init__(cfg, seed, transforms, transform, target_transform)
        print(f"Loading data from 'WildDeepfake' of split '{cfg['split']}'"
              f"\nPlease wait patiently...")
        self.categories = ['original', 'fake']
        self.root = cfg['root']
        self.num_train = cfg.get('num_image_train', None)
        self.num_test = cfg.get('num_image_test', None)
        self.images, self.targets = self.__get_images()
        print(f"Data from 'WildDeepfake' loaded.")
        print(f"Dataset contains {len(self.images)} images.\n")

    def __get_images(self):
        if self.split == 'train':
            num = self.num_train
        elif self.split == 'test':
            num = self.num_test
        else:
            num = None
        real_images = torch.load(join(self.root, self.split, "real.pickle"))
        if num is not None:
            # np.random.choice returns an ndarray; convert back to a list so the
            # list concatenation in the return statement keeps working.
            real_images = np.random.choice(real_images, num // 3, replace=False).tolist()
        real_tgts = [torch.tensor(0)] * len(real_images)
        print(f"real: {len(real_tgts)}")
        fake_images = torch.load(join(self.root, self.split, "fake.pickle"))
        if num is not None:
            # Keep roughly a 1:2 real-to-fake ratio when subsampling.
            fake_images = np.random.choice(fake_images, num - num // 3, replace=False).tolist()
        fake_tgts = [torch.tensor(1)] * len(fake_images)
        print(f"fake: {len(fake_tgts)}")
        return real_images + fake_images, real_tgts + fake_tgts

    def __getitem__(self, index):
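        # Return only the image path and label; actual image loading is deferred
        # to load_item() (used in run_dataloader below), which is assumed to be
        # provided by AbstractDataset.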
        path = join(self.root, self.split, self.images[index])
        tgt = self.targets[index]
        return path, tgt


if __name__ == '__main__':
    import yaml

    config_path = "../config/dataset/wilddeepfake.yml"
    with open(config_path) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = config["train_cfg"]
    # config = config["test_cfg"]
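    # A minimal sketch of the expected YAML layout, inferred from the keys this
    # module reads ('split', 'root', plus the optional 'num_image_train' /
    # 'num_image_test' caps); the paths and counts below are placeholders.
    #
    # train_cfg:
    #   split: train
    #   root: /path/to/wilddeepfake
    #   num_image_train: 9000
    # test_cfg:
    #   split: test
    #   root: /path/to/wilddeepfake
    #   num_image_test: 3000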


    def run_dataset():
        dataset = WildDeepfake(config)
        print(f"dataset: {len(dataset)}")
        for i, (path, target) in enumerate(dataset):
            print(f"path: {path}, target: {target}")
            if i >= 9:
                break


    def run_dataloader(display_samples=False):
        from torch.utils import data
        import matplotlib.pyplot as plt

        dataset = WildDeepfake(config)
        dataloader = data.DataLoader(dataset, batch_size=8, shuffle=True)
        print(f"dataset: {len(dataset)}")
        for i, (path, targets) in enumerate(dataloader):
            image = dataloader.dataset.load_item(path)
            print(f"image: {image.shape}, target: {targets}")
            if display_samples:
                plt.figure()
                img = image[0].permute([1, 2, 0]).numpy()
                plt.imshow(img)
                # plt.savefig("./img_" + str(i) + ".png")
                plt.show()
            if i >= 9:
                break


    ###########################
    # run the functions below #
    ###########################

    # run_dataset()
    run_dataloader(False)