gaur3009 committed on
Commit d89a8d6 · verified · 1 Parent(s): 9862b96

Upload 2 files

Files changed (2)
  1. cp_dataset.py +263 -0
  2. grid.png +0 -0
cp_dataset.py ADDED
@@ -0,0 +1,263 @@
+ # coding=utf-8
+ import torch
+ import torch.utils.data as data
+ import torchvision.transforms as transforms
+
+ from PIL import Image
+ from PIL import ImageDraw
+
+ import os.path as osp
+ import numpy as np
+ import json
+
+
+ class CPDataset(data.Dataset):
+     """Dataset for CP-VTON+.
+     """
+
+     def __init__(self, opt):
+         super(CPDataset, self).__init__()
+         # base setting
+         self.opt = opt
+         self.root = opt.dataroot
+         self.datamode = opt.datamode  # train or test or self-defined
+         self.stage = opt.stage  # GMM or TOM
+         self.data_list = opt.data_list
+         self.fine_height = opt.fine_height
+         self.fine_width = opt.fine_width
+         self.radius = opt.radius
+         self.data_path = osp.join(opt.dataroot, opt.datamode)
+         self.transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
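+         # note: this 3-channel Normalize is also applied below to single-channel
+         # ('L') images (body shape, pose maps); that relies on older torchvision
+         # behavior and may raise a channel-mismatch error on recent versions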
+
+         # load data list
+         im_names = []
+         c_names = []
+         with open(osp.join(opt.dataroot, opt.data_list), 'r') as f:
+             for line in f.readlines():
+                 im_name, c_name = line.strip().split()
+                 im_names.append(im_name)
+                 c_names.append(c_name)
+
+         self.im_names = im_names
+         self.c_names = c_names
+
+     def name(self):
+         return "CPDataset"
+
+     def __getitem__(self, index):
+         c_name = self.c_names[index]
+         im_name = self.im_names[index]
+         if self.stage == 'GMM':
+             c = Image.open(osp.join(self.data_path, 'cloth', c_name))
+             cm = Image.open(osp.join(self.data_path, 'cloth-mask', c_name)).convert('L')
+         else:
+             c = Image.open(osp.join(self.data_path, 'warp-cloth', im_name))  # c_name, if that is used when saved
+             cm = Image.open(osp.join(self.data_path, 'warp-mask', im_name)).convert('L')  # c_name, if that is used when saved
+
+         c = self.transform(c)  # [-1,1]
+         cm_array = np.array(cm)
+         cm_array = (cm_array >= 128).astype(np.float32)
+         cm = torch.from_numpy(cm_array)  # [0,1]
+         cm.unsqueeze_(0)
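+         # the cloth mask is binarized at 128 and is now a (1, H, W) float tensor in {0, 1}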
+
+         # person image
+         im = Image.open(osp.join(self.data_path, 'image', im_name))
+         im = self.transform(im)  # [-1,1]
+
+         """
+         LIP labels
+
+         [(0, 0, 0),      # 0=Background
+          (128, 0, 0),    # 1=Hat
+          (255, 0, 0),    # 2=Hair
+          (0, 85, 0),     # 3=Glove
+          (170, 0, 51),   # 4=SunGlasses
+          (255, 85, 0),   # 5=UpperClothes
+          (0, 0, 85),     # 6=Dress
+          (0, 119, 221),  # 7=Coat
+          (85, 85, 0),    # 8=Socks
+          (0, 85, 85),    # 9=Pants
+          (85, 51, 0),    # 10=Jumpsuits
+          (52, 86, 128),  # 11=Scarf
+          (0, 128, 0),    # 12=Skirt
+          (0, 0, 255),    # 13=Face
+          (51, 170, 221), # 14=LeftArm
+          (0, 255, 255),  # 15=RightArm
+          (85, 255, 170), # 16=LeftLeg
+          (170, 255, 85), # 17=RightLeg
+          (255, 255, 0),  # 18=LeftShoe
+          (255, 170, 0),  # 19=RightShoe
+          (170, 170, 50)  # 20=Skin/Neck/Chest (Newly added after running dataset_neck_skin_correction.py)
+          ]
+         """
+
+         # load parsing image
+         parse_name = im_name.replace('.jpg', '.png')
+         im_parse = Image.open(
+             # osp.join(self.data_path, 'image-parse', parse_name)).convert('L')
+             osp.join(self.data_path, 'image-parse-new', parse_name)).convert('L')  # updated new segmentation
+         parse_array = np.array(im_parse)
+         im_mask = Image.open(
+             osp.join(self.data_path, 'image-mask', parse_name)).convert('L')
+         mask_array = np.array(im_mask)
+
+         # parse_shape = (parse_array > 0).astype(np.float32)  # CP-VTON body shape
+         # Get shape from body mask (CP-VTON+)
+         parse_shape = (mask_array > 0).astype(np.float32)
+
+         if self.stage == 'GMM':
+             parse_head = (parse_array == 1).astype(np.float32) + \
+                 (parse_array == 4).astype(np.float32) + \
+                 (parse_array == 13).astype(
+                     np.float32)  # CP-VTON+ GMM input (reserved regions)
+         else:
+             parse_head = (parse_array == 1).astype(np.float32) + \
+                 (parse_array == 2).astype(np.float32) + \
+                 (parse_array == 4).astype(np.float32) + \
+                 (parse_array == 9).astype(np.float32) + \
+                 (parse_array == 12).astype(np.float32) + \
+                 (parse_array == 13).astype(np.float32) + \
+                 (parse_array == 16).astype(np.float32) + \
+                 (parse_array == 17).astype(
+                     np.float32)  # CP-VTON+ TOM input (reserved regions)
+
+         parse_cloth = (parse_array == 5).astype(np.float32) + \
+             (parse_array == 6).astype(np.float32) + \
+             (parse_array == 7).astype(np.float32)  # upper-clothes labels
+
+         # shape downsample
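+         # downsampling 16x and upsampling back deliberately blurs the silhouette,
+         # so the network receives only a coarse body shape, not exact contours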
+         parse_shape_ori = Image.fromarray((parse_shape * 255).astype(np.uint8))
+         parse_shape = parse_shape_ori.resize(
+             (self.fine_width // 16, self.fine_height // 16), Image.BILINEAR)
+         parse_shape = parse_shape.resize(
+             (self.fine_width, self.fine_height), Image.BILINEAR)
+         parse_shape_ori = parse_shape_ori.resize(
+             (self.fine_width, self.fine_height), Image.BILINEAR)
+         shape_ori = self.transform(parse_shape_ori)  # [-1,1]
+         shape = self.transform(parse_shape)  # [-1,1]
+         phead = torch.from_numpy(parse_head)  # [0,1]
+         # phand = torch.from_numpy(parse_hand)  # [0,1]
+         pcm = torch.from_numpy(parse_cloth)  # [0,1]
+
+         # upper cloth
+         im_c = im * pcm + (1 - pcm)  # [-1,1], fill 1 for other parts
+         im_h = im * phead - (1 - phead)  # [-1,1], fill -1 for other parts
+
+         # load pose points
+         pose_name = im_name.replace('.jpg', '_keypoints.json')
+         with open(osp.join(self.data_path, 'pose', pose_name), 'r') as f:
+             pose_label = json.load(f)
+             pose_data = pose_label['people'][0]['pose_keypoints']
+             pose_data = np.array(pose_data)
+             pose_data = pose_data.reshape((-1, 3))
+
+         point_num = pose_data.shape[0]
+         pose_map = torch.zeros(point_num, self.fine_height, self.fine_width)
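+         # one heatmap channel per keypoint; each valid point (x, y > 1) is drawn
+         # as a (2r+1) x (2r+1) white square in its own channel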
+         r = self.radius
+         im_pose = Image.new('L', (self.fine_width, self.fine_height))
+         pose_draw = ImageDraw.Draw(im_pose)
+         for i in range(point_num):
+             one_map = Image.new('L', (self.fine_width, self.fine_height))
+             draw = ImageDraw.Draw(one_map)
+             pointx = pose_data[i, 0]
+             pointy = pose_data[i, 1]
+             if pointx > 1 and pointy > 1:
+                 draw.rectangle((pointx - r, pointy - r, pointx + r, pointy + r),
+                                'white', 'white')
+                 pose_draw.rectangle(
+                     (pointx - r, pointy - r, pointx + r, pointy + r), 'white', 'white')
+             one_map = self.transform(one_map)
+             pose_map[i] = one_map[0]
+
+         # just for visualization
+         im_pose = self.transform(im_pose)
+
+         # cloth-agnostic representation
+         agnostic = torch.cat([shape, im_h, pose_map], 0)
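+         # channel layout: 1 (blurred shape) + 3 (head RGB) + point_num pose maps,
+         # i.e. 22 channels with the 18-point OpenPose COCO format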
+
+         if self.stage == 'GMM':
+             im_g = Image.open('grid.png')
+             im_g = self.transform(im_g)
+         else:
+             im_g = ''
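+             # placeholder for the TOM stage; the default collate_fn batches this
+             # as a list of empty strings rather than a tensor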
+
+         pcm.unsqueeze_(0)  # CP-VTON+
+
+         result = {
+             'c_name': c_name,            # for visualization
+             'im_name': im_name,          # for visualization or ground truth
+             'cloth': c,                  # for input
+             'cloth_mask': cm,            # for input
+             'image': im,                 # for visualization
+             'agnostic': agnostic,        # for input
+             'parse_cloth': im_c,         # for ground truth
+             'shape': shape,              # for visualization
+             'head': im_h,                # for visualization
+             'pose_image': im_pose,       # for visualization
+             'grid_image': im_g,          # for visualization
+             'parse_cloth_mask': pcm,     # for CP-VTON+, TOM input
+             'shape_ori': shape_ori,      # original body shape without resize
+         }
+
+         return result
+
+     def __len__(self):
+         return len(self.im_names)
+
+
+ class CPDataLoader(object):
+     def __init__(self, opt, dataset):
+         super(CPDataLoader, self).__init__()
+
+         if opt.shuffle:
+             train_sampler = torch.utils.data.sampler.RandomSampler(dataset)
+         else:
+             train_sampler = None
+
+         self.data_loader = torch.utils.data.DataLoader(
+             dataset, batch_size=opt.batch_size,
+             shuffle=False,  # shuffling is delegated to train_sampler when opt.shuffle is set
+             num_workers=opt.workers, pin_memory=True, sampler=train_sampler)
+         self.dataset = dataset
+         self.data_iter = iter(self.data_loader)
+
+     def next_batch(self):
+         try:
+             batch = next(self.data_iter)
+         except StopIteration:
+             # restart iteration once the loader is exhausted
+             self.data_iter = iter(self.data_loader)
+             batch = next(self.data_iter)
+
+         return batch
+
+
+ if __name__ == "__main__":
+     print("Check the dataset for geometric matching module!")
+
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--dataroot", default="data")
+     parser.add_argument("--datamode", default="train")
+     parser.add_argument("--stage", default="GMM")
+     parser.add_argument("--data_list", default="train_pairs.txt")
+     parser.add_argument("--fine_width", type=int, default=192)
+     parser.add_argument("--fine_height", type=int, default=256)
+     parser.add_argument("--radius", type=int, default=3)
+     parser.add_argument("--shuffle", action='store_true',
+                         help='shuffle input data')
+     parser.add_argument('-b', '--batch-size', type=int, default=4)
+     parser.add_argument('-j', '--workers', type=int, default=1)
+
+     opt = parser.parse_args()
+     dataset = CPDataset(opt)
+     data_loader = CPDataLoader(opt, dataset)
+
+     print('Size of the dataset: %05d, dataloader: %04d'
+           % (len(dataset), len(data_loader.data_loader)))
+     first_item = dataset[0]  # same as dataset.__getitem__(0)
+     first_batch = data_loader.next_batch()
+
+     from IPython import embed
+     embed()
grid.png ADDED