File size: 4,020 Bytes
caa56d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
'''
# author: Zhiyuan Yan
# email: [email protected]
# date: 2023-03-30
The code is designed for scenarios such as disentanglement-based methods where it is necessary to ensure an equal number of positive and negative samples.
'''
import os.path
from copy import deepcopy
import cv2
import math
import torch
import random
import yaml
from PIL import Image, ImageDraw
import numpy as np
from torch.utils.data import DataLoader
from dataset.abstract_dataset import DeepfakeAbstractBaseDataset
class IIDDataset(DeepfakeAbstractBaseDataset):
    """Dataset that yields, per sample, the real-video identity index in
    addition to the image tensor and binary label.

    Intended for methods (e.g. disentanglement-based) that need an equal
    treatment of identity information per sample. The identity index is
    parsed from the image's parent directory name, which is assumed to end
    in '_<id>' (e.g. '.../person_12/frame.png' -> 12).
    """

    def __init__(self, config=None, mode='train'):
        super().__init__(config, mode)

    def __getitem__(self, index):
        """Return (id_index, image_tensor, label) for the sample at `index`.

        Raises:
            ValueError: if the parent directory name does not end in an
                integer after the last underscore.
        """
        image_path = self.data_dict['image'][index]
        # Parent directory name carries the real video id; handle both
        # Windows- and POSIX-style separators in stored paths.
        sep = '\\' if '\\' in image_path else '/'
        per = image_path.split(sep)[-2]
        id_index = int(per.split('_')[-1])  # real video id
        label = self.data_dict['label'][index]

        # Load the image; on failure skip to the next sample. (Falling back
        # to a fixed index 0, as before, recurses forever if sample 0 itself
        # is unloadable.)
        try:
            image = self.load_rgb(image_path)
        except Exception as e:
            print(f"Error loading image at index {index}: {e}")
            return self.__getitem__((index + 1) % len(self.data_dict['image']))

        image = np.array(image)  # numpy array for data augmentation
        # Data augmentation, then tensor conversion and normalization.
        image_trans, _, _ = self.data_aug(image)
        image_trans = self.normalize(self.to_tensor(image_trans))
        return id_index, image_trans, label

    @staticmethod
    def collate_fn(batch):
        """
        Collate a batch of (id_index, image_tensor, label) tuples.

        Args:
            batch (list): tuples produced by __getitem__.

        Returns:
            dict with keys:
              'image'    - images stacked into one tensor along dim 0,
              'label'    - LongTensor of labels,
              'id_index' - LongTensor of identity indices,
              'mask', 'landmark' - always None; placeholders kept so the
                  dict layout matches the other datasets in this project.
        """
        id_indexes, image_trans, label = zip(*batch)
        images = torch.stack(image_trans, dim=0)
        labels = torch.LongTensor(label)
        ids = torch.LongTensor(id_indexes)
        return {
            'image': images,
            'label': labels,
            'id_index': ids,
            'mask': None,
            'landmark': None,
        }
def draw_landmark(img, landmark):
    """Annotate a PIL image with facial landmark points.

    Each point is marked with a small red dot and labeled with its index
    in white. The image is modified in place and also returned.
    """
    canvas = ImageDraw.Draw(img)
    for idx, pt in enumerate(landmark):
        # 2x2 red ellipse centered on the landmark.
        canvas.ellipse(
            (pt[0] - 1, pt[1] - 1, pt[0] + 1, pt[1] + 1),
            fill=(255, 0, 0),
        )
        # Index label next to the point.
        canvas.text((pt[0], pt[1]), str(idx), fill=(255, 255, 255))
    return img
if __name__ == '__main__':
    # Smoke test: build the dataset from the detector + training configs
    # and iterate a couple of batches.
    detector_path = r"./training/config/detector/xception.yaml"
    with open(detector_path, 'r') as f:
        config = yaml.safe_load(f)
    with open('./training/config/train_config.yaml', 'r') as f:
        config2 = yaml.safe_load(f)
    config2['data_manner'] = 'lmdb'
    config['dataset_json_folder'] = 'preprocessing/dataset_json_v3'
    # Training config values take precedence over the detector config.
    config.update(config2)

    dataset = IIDDataset(config=config)
    batch_size = 2
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            collate_fn=dataset.collate_fn)
    for i, batch in enumerate(dataloader):
        print(f"Batch {i}: {batch}")
        # collate_fn returns a dict; the stacked image tensor lives under
        # 'image' (the original read batch['img'], which raises KeyError).
        img = batch['image']
|