Ravi21 commited on
Commit
8d9f691
·
verified ·
1 Parent(s): 7f0687b

Upload 9 files

Browse files
data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # data_init
data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes). View file
 
data/__pycache__/data_loader_test.cpython-310.pyc ADDED
Binary file (430 Bytes). View file
 
data/aligned_dataset_test.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from data.base_dataset import BaseDataset, get_params, get_transform
3
+ from PIL import Image
4
+ import linecache
5
+
6
+ class AlignedDataset(BaseDataset):
7
+ def initialize(self, opt):
8
+ self.opt = opt
9
+ self.root = opt.dataroot
10
+
11
+ self.fine_height=256
12
+ self.fine_width=192
13
+
14
+ self.dataset_size = len(open('demo.txt').readlines())
15
+
16
+ dir_I = '_img'
17
+ self.dir_I = os.path.join(opt.dataroot, opt.phase + dir_I)
18
+
19
+ dir_C = '_clothes'
20
+ self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C)
21
+
22
+ dir_E = '_edge'
23
+ self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E)
24
+
25
+ def __getitem__(self, index):
26
+
27
+ file_path ='demo.txt'
28
+ im_name, c_name = linecache.getline(file_path, index+1).strip().split()
29
+
30
+ I_path = os.path.join(self.dir_I,im_name)
31
+ I = Image.open(I_path).convert('RGB')
32
+ # new line
33
+ I = I.resize((self.fine_width, self.fine_height))
34
+
35
+ params = get_params(self.opt, I.size)
36
+ transform = get_transform(self.opt, params)
37
+ transform_E = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
38
+
39
+ I_tensor = transform(I)
40
+
41
+ C_path = os.path.join(self.dir_C,c_name)
42
+ C = Image.open(C_path).convert('RGB')
43
+ C_tensor = transform(C)
44
+
45
+ E_path = os.path.join(self.dir_E,c_name)
46
+ E = Image.open(E_path).convert('L')
47
+ E_tensor = transform_E(E)
48
+
49
+ input_dict = { 'image': I_tensor,'clothes': C_tensor, 'edge': E_tensor}
50
+ return input_dict
51
+
52
+ def __len__(self):
53
+ return self.dataset_size
54
+
55
+ def name(self):
56
+ return 'AlignedDataset'
data/base_data_loader.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ class BaseDataLoader():
3
+ def __init__(self):
4
+ pass
5
+
6
+ def initialize(self, opt):
7
+ self.opt = opt
8
+ pass
9
+
10
+ def load_data():
11
+ return None
12
+
13
+
14
+
data/base_dataset.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.utils.data as data
2
+ from PIL import Image
3
+ import torchvision.transforms as transforms
4
+ import numpy as np
5
+ import random
6
+
7
+ class BaseDataset(data.Dataset):
8
+ def __init__(self):
9
+ super(BaseDataset, self).__init__()
10
+
11
+ def name(self):
12
+ return 'BaseDataset'
13
+
14
+ def initialize(self, opt):
15
+ pass
16
+
17
+ def get_params(opt, size):
18
+ w, h = size
19
+ new_h = h
20
+ new_w = w
21
+ if opt.resize_or_crop == 'resize_and_crop':
22
+ new_h = new_w = opt.loadSize
23
+ elif opt.resize_or_crop == 'scale_width_and_crop':
24
+ new_w = opt.loadSize
25
+ new_h = opt.loadSize * h // w
26
+
27
+ x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
28
+ y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
29
+
30
+ flip = 0
31
+ return {'crop_pos': (x, y), 'flip': flip}
32
+
33
+ def get_transform_resize(opt, params, method=Image.BICUBIC, normalize=True):
34
+ transform_list = []
35
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
36
+ osize = [256,192]
37
+ transform_list.append(transforms.Scale(osize, method))
38
+ if 'crop' in opt.resize_or_crop:
39
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
40
+
41
+ if opt.resize_or_crop == 'none':
42
+ base = float(2 ** opt.n_downsample_global)
43
+ if opt.netG == 'local':
44
+ base *= (2 ** opt.n_local_enhancers)
45
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
46
+
47
+ if opt.isTrain and not opt.no_flip:
48
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
49
+
50
+ transform_list += [transforms.ToTensor()]
51
+
52
+ if normalize:
53
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
54
+ (0.5, 0.5, 0.5))]
55
+ return transforms.Compose(transform_list)
56
+
57
+ def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
58
+ transform_list = []
59
+ if 'resize' in opt.resize_or_crop:
60
+ osize = [opt.loadSize, opt.loadSize]
61
+ transform_list.append(transforms.Scale(osize, method))
62
+ elif 'scale_width' in opt.resize_or_crop:
63
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
64
+ osize = [256,192]
65
+ transform_list.append(transforms.Scale(osize, method))
66
+ if 'crop' in opt.resize_or_crop:
67
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
68
+
69
+ if opt.resize_or_crop == 'none':
70
+ base = float(16)
71
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
72
+
73
+ if opt.isTrain and not opt.no_flip:
74
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
75
+
76
+ transform_list += [transforms.ToTensor()]
77
+
78
+ if normalize:
79
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
80
+ (0.5, 0.5, 0.5))]
81
+ return transforms.Compose(transform_list)
82
+
83
+ def normalize():
84
+ return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
85
+
86
+ def __make_power_2(img, base, method=Image.BICUBIC):
87
+ ow, oh = img.size
88
+ h = int(round(oh / base) * base)
89
+ w = int(round(ow / base) * base)
90
+ if (h == oh) and (w == ow):
91
+ return img
92
+ return img.resize((w, h), method)
93
+
94
+ def __scale_width(img, target_width, method=Image.BICUBIC):
95
+ ow, oh = img.size
96
+ if (ow == target_width):
97
+ return img
98
+ w = target_width
99
+ h = int(target_width * oh / ow)
100
+ return img.resize((w, h), method)
101
+
102
+ def __crop(img, pos, size):
103
+ ow, oh = img.size
104
+ x1, y1 = pos
105
+ tw = th = size
106
+ if (ow > tw or oh > th):
107
+ return img.crop((x1, y1, x1 + tw, y1 + th))
108
+ return img
109
+
110
+ def __flip(img, flip):
111
+ if flip:
112
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
113
+ return img
data/custom_dataset_data_loader_test.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.utils.data
2
+ from data.base_data_loader import BaseDataLoader
3
+
4
+
5
+ def CreateDataset(opt):
6
+ dataset = None
7
+ from data.aligned_dataset_test import AlignedDataset
8
+ dataset = AlignedDataset()
9
+
10
+ print("dataset [%s] was created" % (dataset.name()))
11
+ dataset.initialize(opt)
12
+ return dataset
13
+
14
+ class CustomDatasetDataLoader(BaseDataLoader):
15
+ def name(self):
16
+ return 'CustomDatasetDataLoader'
17
+
18
+ def initialize(self, opt):
19
+ BaseDataLoader.initialize(self, opt)
20
+ self.dataset = CreateDataset(opt)
21
+ self.dataloader = torch.utils.data.DataLoader(
22
+ self.dataset,
23
+ batch_size=opt.batchSize,
24
+ shuffle = False,
25
+ num_workers=int(opt.nThreads))
26
+
27
+ def load_data(self):
28
+ return self.dataloader
29
+
30
+ def __len__(self):
31
+ return min(len(self.dataset), self.opt.max_dataset_size)
data/data_loader_test.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+
2
+ def CreateDataLoader(opt):
3
+ from data.custom_dataset_data_loader_test import CustomDatasetDataLoader
4
+ data_loader = CustomDatasetDataLoader()
5
+ print(data_loader.name())
6
+ data_loader.initialize(opt)
7
+ return data_loader
data/image_folder.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.utils.data as data
2
+ from PIL import Image
3
+ import os
4
+
5
+ IMG_EXTENSIONS = [
6
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
7
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
8
+ ]
9
+
10
+
11
+ def is_image_file(filename):
12
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13
+
14
+ def make_dataset(dir):
15
+ images = []
16
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
17
+
18
+ f = dir.split('/')[-1].split('_')[-1]
19
+ print (dir, f)
20
+ dirs= os.listdir(dir)
21
+ for img in dirs:
22
+
23
+ path = os.path.join(dir, img)
24
+ #print(path)
25
+ images.append(path)
26
+ return images
27
+
28
+ def make_dataset_test(dir):
29
+ images = []
30
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
31
+
32
+ f = dir.split('/')[-1].split('_')[-1]
33
+ for i in range(len([name for name in os.listdir(dir) if os.path.isfile(os.path.join(dir, name))])):
34
+ if f == 'label' or f == 'labelref':
35
+ img = str(i) + '.png'
36
+ else:
37
+ img = str(i) + '.jpg'
38
+ path = os.path.join(dir, img)
39
+ images.append(path)
40
+ return images
41
+
42
+ def default_loader(path):
43
+ return Image.open(path).convert('RGB')
44
+
45
+
46
+ class ImageFolder(data.Dataset):
47
+
48
+ def __init__(self, root, transform=None, return_paths=False,
49
+ loader=default_loader):
50
+ imgs = make_dataset(root)
51
+ if len(imgs) == 0:
52
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
53
+ "Supported image extensions are: " +
54
+ ",".join(IMG_EXTENSIONS)))
55
+
56
+ self.root = root
57
+ self.imgs = imgs
58
+ self.transform = transform
59
+ self.return_paths = return_paths
60
+ self.loader = loader
61
+
62
+ def __getitem__(self, index):
63
+ path = self.imgs[index]
64
+ img = self.loader(path)
65
+ if self.transform is not None:
66
+ img = self.transform(img)
67
+ if self.return_paths:
68
+ return img, path
69
+ else:
70
+ return img
71
+
72
+ def __len__(self):
73
+ return len(self.imgs)