from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils


class InferenceDataset(Dataset):
    """Map-style dataset that loads one image per index for inference.

    Image paths come either from scanning ``root`` (sorted) or from an explicit
    ``paths_list``. Each item is the loaded (optionally transformed) image, plus
    its path when ``return_path`` is True.
    """

    def __init__(self, root=None, paths_list=None, opts=None, transform=None, return_path=False):
        """Collect the image paths and store loading options.

        Args:
            root: Location handed to ``data_utils.make_dataset``. Used only
                when ``paths_list`` is None.
            paths_list: Explicit paths handed to
                ``data_utils.make_dataset_from_paths_list``. When given, it
                takes precedence over ``root``.
            opts: Options object whose ``label_nc`` attribute selects the
                colour mode (0 -> 'RGB', otherwise 'L'). If None, images are
                loaded as 'RGB'.
            transform: Optional callable applied to the loaded PIL image.
            return_path: If True, ``__getitem__`` returns ``(image, path)``
                instead of just the image.
        """
        if paths_list is None:
            # sorted() gives a deterministic order regardless of the order
            # make_dataset returns paths in.
            self.paths = sorted(data_utils.make_dataset(root))
        else:
            # Preserve the caller-specified order for explicit path lists.
            self.paths = data_utils.make_dataset_from_paths_list(paths_list)
        self.transform = transform
        self.opts = opts
        self.return_path = return_path

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the image at ``index``.

        Returns:
            The image (after ``transform`` if one was given), or
            ``(image, path)`` when ``return_path`` is True.
        """
        from_path = self.paths[index]
        # Previously opts=None (the constructor default) crashed here with
        # AttributeError; treat a missing opts as label_nc == 0 (RGB).
        label_nc = 0 if self.opts is None else self.opts.label_nc
        mode = 'RGB' if label_nc == 0 else 'L'
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image, so it stays valid after the source is closed.
        with Image.open(from_path) as img:
            from_im = img.convert(mode)
        if self.transform:
            from_im = self.transform(from_im)
        if self.return_path:
            return from_im, from_path
        return from_im