File size: 6,642 Bytes
caa56d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
'''

# author: Zhiyuan Yan

# email: [email protected]

# date: 2023-03-30



This dataset is designed for scenarios, such as disentanglement-based methods, that require an equal number of positive (fake) and negative (real) samples in every batch.

'''

import torch
import random
import numpy as np
from dataset.abstract_dataset import DeepfakeAbstractBaseDataset


class pairDataset(DeepfakeAbstractBaseDataset):
    """Dataset that pairs every fake image with a randomly drawn real image.

    Indexing iterates over the fake images only; for each fake sample a real
    sample is drawn uniformly at random (with replacement), so every batch
    produced by ``collate_fn`` holds the same number of real and fake items.
    """

    def __init__(self, config=None, mode='train'):
        super().__init__(config, mode)

        # Split the base dataset's samples into fake and real lists.
        # Each entry is (image_path, original_label, binary_label): the binary
        # label is fixed to 1 for fake images and 0 for real images, while the
        # original label is kept as the "specific" label.
        self.fake_imglist = [(img, label, 1) for img, label in zip(self.image_list, self.label_list) if label != 0]
        self.real_imglist = [(img, label, 0) for img, label in zip(self.image_list, self.label_list) if label == 0]

    def _load_and_augment(self, image_path):
        """Load one image (plus optional mask / landmarks) and run ``data_aug`` on it.

        Mask and landmark paths are derived from ``image_path`` by replacing
        ``'frames'`` with ``'masks'`` / ``'landmarks'`` (and ``'.png'`` with
        ``'.npy'`` for landmarks). They are only loaded when ``with_mask`` /
        ``with_landmark`` are enabled in the config; otherwise ``None`` is
        passed to ``data_aug``.

        Returns:
            The ``(image, landmarks, mask)`` triple produced by ``self.data_aug``.
        """
        mask_path = image_path.replace('frames', 'masks')
        landmark_path = image_path.replace('frames', 'landmarks').replace('.png', '.npy')

        # Convert to a numpy array so the augmentation pipeline can consume it.
        image = np.array(self.load_rgb(image_path))
        mask = self.load_mask(mask_path) if self.config['with_mask'] else None
        landmarks = self.load_landmark(landmark_path) if self.config['with_landmark'] else None
        return self.data_aug(image, landmarks, mask)

    def _to_tensors(self, image, landmarks, mask):
        """Convert augmented outputs to tensors.

        The image goes through ``self.to_tensor`` and ``self.normalize``;
        landmarks / mask are converted with ``torch.from_numpy`` only when the
        corresponding config flag is enabled (otherwise they stay ``None``).
        """
        image = self.normalize(self.to_tensor(image))
        if self.config['with_landmark']:
            landmarks = torch.from_numpy(landmarks)
        if self.config['with_mask']:
            mask = torch.from_numpy(mask)
        return image, landmarks, mask

    def __getitem__(self, index, norm=True):
        """Return the fake sample at ``index`` paired with a random real sample.

        Args:
            index (int): Index into ``self.fake_imglist``.
            norm (bool): When False, skip tensor conversion/normalization and
                return only ``(image, label)`` pairs. That reduced form is
                NOT accepted by ``collate_fn``.

        Returns:
            dict: ``{"fake": (...), "real": (...)}``. With ``norm=True`` each
            value is ``(image, label, spe_label, landmarks, mask)`` where
            landmarks / mask are ``None`` when disabled in the config.

        Raises:
            IndexError: If ``index`` is out of range for the fake images.
            ValueError: If the dataset contains no real images to pair with.
        """
        fake_image_path, fake_spe_label, fake_label = self.fake_imglist[index]

        # random.randint would also raise ValueError here, but with an opaque
        # "empty range" message; make the actual cause explicit.
        if not self.real_imglist:
            raise ValueError(
                'pairDataset needs at least one real image (label 0) to pair '
                'with fake images, but none were found.'
            )
        real_index = random.randint(0, len(self.real_imglist) - 1)  # Randomly select a real image
        real_image_path, real_spe_label, real_label = self.real_imglist[real_index]

        fake_image, fake_landmarks, fake_mask = self._load_and_augment(fake_image_path)
        real_image, real_landmarks, real_mask = self._load_and_augment(real_image_path)

        if not norm:
            return {"fake": (fake_image, fake_label),
                    "real": (real_image, real_label)}

        fake_image, fake_landmarks, fake_mask = self._to_tensors(fake_image, fake_landmarks, fake_mask)
        real_image, real_landmarks, real_mask = self._to_tensors(real_image, real_landmarks, real_mask)

        return {"fake": (fake_image, fake_label, fake_spe_label, fake_landmarks, fake_mask),
                "real": (real_image, real_label, real_spe_label, real_landmarks, real_mask)}

    def __len__(self):
        """Number of fake images; one epoch visits every fake image once."""
        return len(self.fake_imglist)

    @staticmethod
    def collate_fn(batch):
        """Collate a list of paired samples into a single batch dictionary.

        Args:
            batch (list): Samples returned by ``__getitem__`` with ``norm=True``.
                Each is a dict with ``"fake"`` and ``"real"`` entries holding
                ``(image, label, spe_label, landmarks, mask)`` tuples.

        Returns:
            dict: Keys ``'image'``, ``'label'``, ``'label_spe'``, ``'landmark'``
            and ``'mask'``. Along dim 0 all real samples come first, followed
            by all fake samples, so each tensor holds ``2 * len(batch)`` items.
            ``'landmark'`` / ``'mask'`` are ``None`` unless both the real and
            fake halves provide them.
        """
        fake_images, fake_labels, fake_spe_labels, fake_landmarks, fake_masks = zip(*[data["fake"] for data in batch])
        real_images, real_labels, real_spe_labels, real_landmarks, real_masks = zip(*[data["real"] for data in batch])

        def stack_optional(real_items, fake_items):
            # Landmarks/masks are either present for every sample or None for
            # all of them, so checking the first element of each half suffices.
            if real_items[0] is None or fake_items[0] is None:
                return None
            return torch.stack(real_items + fake_items, dim=0)

        # Stacking real + fake in one call is equivalent to stacking each half
        # and concatenating the results along dim 0.
        data_dict = {
            'image': torch.stack(real_images + fake_images, dim=0),
            'label': torch.LongTensor(real_labels + fake_labels),
            'label_spe': torch.LongTensor(real_spe_labels + fake_spe_labels),
            'landmark': stack_optional(real_landmarks, fake_landmarks),
            'mask': stack_optional(real_masks, fake_masks),
        }
        return data_dict