import torch import os from PIL import Image import random import numpy as np import pickle import torchvision.transforms as transforms class BaseDataset(torch.utils.data.Dataset): """docstring for BaseDataset""" def __init__(self): super(BaseDataset, self).__init__() def name(self): return os.path.basename(self.opt.data_root.strip('/')) def initialize(self, opt): self.opt = opt self.imgs_dir = os.path.join(self.opt.data_root, self.opt.imgs_dir) self.is_train = self.opt.mode == "train" # load images path filename = self.opt.train_csv if self.is_train else self.opt.test_csv self.imgs_name_file = os.path.join(self.opt.data_root, filename) self.imgs_path = self.make_dataset() # load AUs dicitionary aus_pkl = os.path.join(self.opt.data_root, self.opt.aus_pkl) self.aus_dict = self.load_dict(aus_pkl) # load image to tensor transformer self.img2tensor = self.img_transformer() def make_dataset(self): return None def load_dict(self, pkl_path): saved_dict = {} with open(pkl_path, 'rb') as f: saved_dict = pickle.load(f, encoding='latin1') return saved_dict def get_img_by_path(self, img_path): assert os.path.isfile(img_path), "Cannot find image file: %s" % img_path img_type = 'L' if self.opt.img_nc == 1 else 'RGB' return Image.open(img_path).convert(img_type) def get_aus_by_path(self, img_path): return None def img_transformer(self): transform_list = [] if self.opt.resize_or_crop == 'resize_and_crop': transform_list.append(transforms.Resize([self.opt.load_size, self.opt.load_size], Image.BICUBIC)) transform_list.append(transforms.RandomCrop(self.opt.final_size)) elif self.opt.resize_or_crop == 'crop': transform_list.append(transforms.RandomCrop(self.opt.final_size)) elif self.opt.resize_or_crop == 'none': transform_list.append(transforms.Lambda(lambda image: image)) else: raise ValueError("--resize_or_crop %s is not a valid option." % self.opt.resize_or_crop) if self.is_train and not self.opt.no_flip: transform_list.append(transforms.RandomHorizontalFlip()) transform_list.append(transforms.ToTensor()) transform_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) img2tensor = transforms.Compose(transform_list) return img2tensor def __len__(self): return len(self.imgs_path)