File size: 823 Bytes
ed697ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils


class InferenceDataset(Dataset):
	"""Dataset that loads images from disk for inference.

	Image paths come either from a directory (``root``) or from an explicit
	list (``paths_list``). Each item is opened with PIL, converted to 'RGB'
	or single-channel 'L' mode depending on ``opts.label_nc``, and then
	optionally passed through ``transform``.
	"""

	def __init__(self, root=None, paths_list=None, opts=None, transform=None, return_path=False):
		"""
		Args:
			root: passed to ``data_utils.make_dataset``; used only when
				``paths_list`` is None. The resulting paths are sorted.
			paths_list: passed to ``data_utils.make_dataset_from_paths_list``;
				takes precedence over ``root``. Its order is not re-sorted here.
			opts: options object; only ``opts.label_nc`` is read. If ``opts``
				is None or has no ``label_nc``, images are loaded as 'RGB'.
			transform: optional callable applied to each loaded PIL image.
			return_path: if True, ``__getitem__`` returns ``(image, path)``
				instead of just the image.
		"""
		if paths_list is None:
			self.paths = sorted(data_utils.make_dataset(root))
		else:
			self.paths = data_utils.make_dataset_from_paths_list(paths_list)
		self.transform = transform
		self.opts = opts
		self.return_path = return_path

	def __len__(self):
		"""Return the number of image paths in the dataset."""
		return len(self.paths)

	def __getitem__(self, index):
		"""Load the image at ``index``.

		Returns:
			The (optionally transformed) image, or ``(image, path)`` when
			``return_path`` is True.
		"""
		from_path = self.paths[index]
		# label_nc == 0 selects 'RGB', any other value selects 'L'. A missing
		# opts/label_nc falls back to 0 so the default constructor is usable.
		label_nc = getattr(self.opts, 'label_nc', 0)
		mode = 'RGB' if label_nc == 0 else 'L'
		# Close the file handle promptly instead of leaving it to the GC.
		# convert() returns a new, fully loaded image, so it stays valid
		# after the source file is closed.
		with Image.open(from_path) as img:
			from_im = img.convert(mode)
		if self.transform is not None:
			from_im = self.transform(from_im)
		if self.return_path:
			return from_im, from_path
		return from_im