Upload 9 files
Browse files- data/__init__.py +1 -0
- data/__pycache__/__init__.cpython-310.pyc +0 -0
- data/__pycache__/data_loader_test.cpython-310.pyc +0 -0
- data/aligned_dataset_test.py +56 -0
- data/base_data_loader.py +14 -0
- data/base_dataset.py +113 -0
- data/custom_dataset_data_loader_test.py +31 -0
- data/data_loader_test.py +7 -0
- data/image_folder.py +73 -0
data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# data_init
|
data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (151 Bytes). View file
|
|
data/__pycache__/data_loader_test.cpython-310.pyc
ADDED
Binary file (430 Bytes). View file
|
|
data/aligned_dataset_test.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
from data.base_dataset import BaseDataset, get_params, get_transform
|
3 |
+
from PIL import Image
|
4 |
+
import linecache
|
5 |
+
|
6 |
+
class AlignedDataset(BaseDataset):
    """Paired dataset for virtual try-on: loads (image, clothes, edge) triples.

    Pairing is driven by 'demo.txt' in the working directory, where each line
    holds "<image_name> <clothes_name>" separated by whitespace.
    """

    def initialize(self, opt):
        """Set up directory paths and count samples from the pair list.

        Args:
            opt: parsed options; must provide `dataroot` and `phase`.
        """
        self.opt = opt
        self.root = opt.dataroot

        # Fixed working resolution (height x width) for the person image.
        self.fine_height = 256
        self.fine_width = 192

        # One sample per line of the pair list.
        # FIX: the original `len(open('demo.txt').readlines())` leaked the
        # file handle; a context manager closes it deterministically.
        with open('demo.txt') as pair_file:
            self.dataset_size = sum(1 for _ in pair_file)

        # <dataroot>/<phase>_img: person images
        dir_I = '_img'
        self.dir_I = os.path.join(opt.dataroot, opt.phase + dir_I)

        # <dataroot>/<phase>_clothes: clothing product images
        dir_C = '_clothes'
        self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C)

        # <dataroot>/<phase>_edge: clothing masks/edges (grayscale)
        dir_E = '_edge'
        self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E)

    def __getitem__(self, index):
        """Return {'image', 'clothes', 'edge'} tensors for sample `index`."""
        file_path = 'demo.txt'
        # linecache lines are 1-based; dataset indices are 0-based.
        im_name, c_name = linecache.getline(file_path, index + 1).strip().split()

        I_path = os.path.join(self.dir_I, im_name)
        I = Image.open(I_path).convert('RGB')
        I = I.resize((self.fine_width, self.fine_height))

        # Share the same crop/flip parameters across all three inputs so
        # they stay spatially aligned.
        params = get_params(self.opt, I.size)
        transform = get_transform(self.opt, params)
        # Nearest-neighbor keeps the edge mask binary; no normalization.
        transform_E = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)

        I_tensor = transform(I)

        C_path = os.path.join(self.dir_C, c_name)
        C = Image.open(C_path).convert('RGB')
        C_tensor = transform(C)

        E_path = os.path.join(self.dir_E, c_name)
        E = Image.open(E_path).convert('L')
        E_tensor = transform_E(E)

        input_dict = { 'image': I_tensor,'clothes': C_tensor, 'edge': E_tensor}
        return input_dict

    def __len__(self):
        return self.dataset_size

    def name(self):
        return 'AlignedDataset'
|
data/base_data_loader.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class BaseDataLoader():
    """Minimal data-loader interface; subclasses override the hooks."""

    def __init__(self):
        pass

    def initialize(self, opt):
        """Store the option object for later use by subclasses."""
        self.opt = opt

    def load_data(self):
        """Return the iterable of batches; the base class has none.

        BUG FIX: the original signature omitted `self`, so calling
        `instance.load_data()` raised TypeError.
        """
        return None
12 |
+
|
13 |
+
|
14 |
+
|
data/base_dataset.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.utils.data as data
|
2 |
+
from PIL import Image
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
|
7 |
+
class BaseDataset(data.Dataset):
    """Abstract dataset base; concrete datasets implement `initialize`."""

    def __init__(self):
        super().__init__()

    def name(self):
        """Human-readable identifier used in log messages."""
        return 'BaseDataset'

    def initialize(self, opt):
        """Hook for subclasses; the base class needs no setup."""
        pass
|
16 |
+
|
17 |
+
def get_params(opt, size):
    """Compute the shared crop position and flip flag for one sample.

    Args:
        opt: options providing `resize_or_crop`, `loadSize`, `fineSize`.
        size: (width, height) of the source image.

    Returns:
        dict with 'crop_pos' as an (x, y) tuple and 'flip' (always 0 here).
    """
    width, height = size
    new_w, new_h = width, height
    if opt.resize_or_crop == 'resize_and_crop':
        new_w = new_h = opt.loadSize
    elif opt.resize_or_crop == 'scale_width_and_crop':
        new_w = opt.loadSize
        new_h = opt.loadSize * height // width

    # Random crop origin, clamped so the fineSize window fits.
    crop_x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    crop_y = random.randint(0, np.maximum(0, new_h - opt.fineSize))

    return {'crop_pos': (crop_x, crop_y), 'flip': 0}
|
32 |
+
|
33 |
+
def get_transform_resize(opt, params, method=Image.BICUBIC, normalize=True):
    """Build a preprocessing pipeline: scale-width, resize to 256x192,
    optional crop/flip, tensor conversion, and optional normalization.

    Args:
        opt: options (`resize_or_crop`, `loadSize`, `fineSize`, `isTrain`,
             `no_flip`; for 'none' also `n_downsample_global`, `netG`,
             `n_local_enhancers`).
        params: dict from get_params with 'crop_pos' and 'flip'.
        method: PIL resampling filter.
        normalize: append mean/std normalization mapping pixels to [-1, 1].
    """
    transform_list = []
    transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
    osize = [256,192]
    # FIX: transforms.Scale was removed from torchvision; Resize is the
    # documented drop-in replacement.
    transform_list.append(transforms.Resize(osize, method))
    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        # Keep spatial dims divisible by the generator's total downsampling.
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
|
56 |
+
|
57 |
+
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    """Build the main preprocessing pipeline selected by opt.resize_or_crop.

    Args:
        opt: options (`resize_or_crop`, `loadSize`, `fineSize`, `isTrain`,
             `no_flip`).
        params: dict from get_params with 'crop_pos' and 'flip'.
        method: PIL resampling filter.
        normalize: append mean/std normalization mapping pixels to [-1, 1].
    """
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        # FIX: transforms.Scale was removed from torchvision; Resize is the
        # documented drop-in replacement.
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
        osize = [256,192]
        transform_list.append(transforms.Resize(osize, method))
    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        # Pad/shrink to a multiple of 16 so the network's downsampling fits.
        base = float(16)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
|
82 |
+
|
83 |
+
def normalize():
    """Return the project's standard [-1, 1] channel normalization."""
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    return transforms.Normalize(mean, std)
|
85 |
+
|
86 |
+
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize `img` so both sides are (rounded) multiples of `base`."""
    ow, oh = img.size
    new_h = int(round(oh / base) * base)
    new_w = int(round(ow / base) * base)
    if (new_w, new_h) == (ow, oh):
        # Already aligned; avoid a pointless resample.
        return img
    return img.resize((new_w, new_h), method)
|
93 |
+
|
94 |
+
def __scale_width(img, target_width, method=Image.BICUBIC):
    """Scale `img` to `target_width`, preserving aspect ratio."""
    ow, oh = img.size
    if ow == target_width:
        return img
    scaled_h = int(target_width * oh / ow)
    return img.resize((target_width, scaled_h), method)
|
101 |
+
|
102 |
+
def __crop(img, pos, size):
    """Crop a size x size window at `pos`; no-op if the image already fits.

    NOTE(review): the guard uses `or`, so a crop happens when EITHER side
    exceeds `size` — presumably intentional, matching the original.
    """
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if ow > tw or oh > th:
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img
|
109 |
+
|
110 |
+
def __flip(img, flip):
    """Horizontally mirror `img` when `flip` is truthy."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
|
data/custom_dataset_data_loader_test.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.utils.data
|
2 |
+
from data.base_data_loader import BaseDataLoader
|
3 |
+
|
4 |
+
|
5 |
+
def CreateDataset(opt):
    """Instantiate, announce, and initialize the aligned (paired) dataset."""
    # Imported lazily so the module can be loaded without its dependencies.
    from data.aligned_dataset_test import AlignedDataset

    dataset = AlignedDataset()
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
|
13 |
+
|
14 |
+
class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps the project dataset in a torch DataLoader configured from opt."""

    def name(self):
        """Identifier printed by CreateDataLoader."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and a sequential (non-shuffling) batch loader.

        Args:
            opt: options providing `batchSize` and `nThreads`.
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=False,
            num_workers=int(opt.nThreads),
        )

    def load_data(self):
        """Return the underlying DataLoader."""
        return self.dataloader

    def __len__(self):
        # Report at most opt.max_dataset_size samples.
        return min(len(self.dataset), self.opt.max_dataset_size)
|
data/data_loader_test.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
def CreateDataLoader(opt):
    """Build, announce, and initialize the custom data loader for `opt`."""
    # Imported lazily so the module can be loaded without its dependencies.
    from data.custom_dataset_data_loader_test import CustomDatasetDataLoader

    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt)
    return loader
|
data/image_folder.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.utils.data as data
|
2 |
+
from PIL import Image
|
3 |
+
import os
|
4 |
+
|
5 |
+
# Recognized image-file extensions (case variants listed explicitly).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    """Return True if `filename` ends with a recognized image extension."""
    # str.endswith accepts a tuple of suffixes — one C-level check.
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
13 |
+
|
14 |
+
def make_dataset(dir):
    """Return full paths of every entry directly inside `dir`.

    NOTE(review): unlike classic ImageFolder helpers this does NOT filter by
    is_image_file — every directory entry is returned.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    # Debug output: directory name suffix after the last underscore.
    f = dir.split('/')[-1].split('_')[-1]
    print (dir, f)

    return [os.path.join(dir, entry) for entry in os.listdir(dir)]
|
27 |
+
|
28 |
+
def make_dataset_test(dir):
    """Return sequential paths '0.ext', '1.ext', ... for files in `dir`.

    The extension is '.png' for 'label'/'labelref' directories (picked by the
    directory-name suffix after the last '_'), '.jpg' otherwise. One path is
    produced per regular file found in `dir`.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    suffix = dir.split('/')[-1].split('_')[-1]
    ext = '.png' if suffix in ('label', 'labelref') else '.jpg'
    file_count = len([name for name in os.listdir(dir)
                      if os.path.isfile(os.path.join(dir, name))])
    return [os.path.join(dir, str(i) + ext) for i in range(file_count)]
|
41 |
+
|
42 |
+
def default_loader(path):
    """Open the image at `path` and force RGB mode."""
    image = Image.open(path)
    return image.convert('RGB')
|
44 |
+
|
45 |
+
|
46 |
+
class ImageFolder(data.Dataset):
    """Flat-directory image dataset: every file under `root` is one sample."""

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        """Scan `root` for files and record per-sample settings.

        Raises:
            RuntimeError: when `root` contains no files at all.
        """
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root                  # dataset directory
        self.imgs = imgs                  # list of file paths
        self.transform = transform        # optional callable applied per image
        self.return_paths = return_paths  # also yield the source path?
        self.loader = loader              # path -> image object

    def __getitem__(self, index):
        """Load (and optionally transform) the image at `index`."""
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        return len(self.imgs)
|