init
- APDrawingGAN2/data/__init__.py +75 -0
- APDrawingGAN2/data/aligned_dataset.py +288 -0
- APDrawingGAN2/data/base_data_loader.py +10 -0
- APDrawingGAN2/data/base_dataset.py +103 -0
- APDrawingGAN2/data/image_folder.py +68 -0
- APDrawingGAN2/data/single_dataset.py +176 -0
- APDrawingGAN2/docs/tips.md +8 -0
- APDrawingGAN2/models/__init__.py +39 -0
- APDrawingGAN2/models/apdrawingpp_style_model.py +692 -0
- APDrawingGAN2/models/base_model.py +545 -0
- APDrawingGAN2/models/networks.py +1194 -0
- APDrawingGAN2/models/test_model.py +214 -0
- APDrawingGAN2/options/__init__.py +0 -0
- APDrawingGAN2/options/base_options.py +192 -0
- APDrawingGAN2/options/test_options.py +23 -0
- APDrawingGAN2/options/train_options.py +62 -0
- APDrawingGAN2/preprocess/combine_A_and_B.py +48 -0
- APDrawingGAN2/preprocess/example/img_1701.jpg +0 -0
- APDrawingGAN2/preprocess/example/img_1701_aligned.png +0 -0
- APDrawingGAN2/preprocess/example/img_1701_aligned.txt +5 -0
- APDrawingGAN2/preprocess/example/img_1701_aligned_68lm.txt +68 -0
- APDrawingGAN2/preprocess/example/img_1701_aligned_bgmask.png +0 -0
- APDrawingGAN2/preprocess/example/img_1701_aligned_eyelmask.png +0 -0
- APDrawingGAN2/preprocess/example/img_1701_aligned_eyermask.png +0 -0
- APDrawingGAN2/preprocess/example/img_1701_aligned_facemask.png +0 -0
- APDrawingGAN2/preprocess/example/img_1701_aligned_mouthmask.png +0 -0
- APDrawingGAN2/preprocess/example/img_1701_aligned_nosemask.png +0 -0
- APDrawingGAN2/preprocess/example/img_1701_facial5point.mat +0 -0
- APDrawingGAN2/preprocess/face_align_512.m +55 -0
- APDrawingGAN2/preprocess/get_partmask.py +152 -0
- APDrawingGAN2/preprocess/readme.md +71 -0
- APDrawingGAN2/readme.md +105 -0
- APDrawingGAN2/requirements.txt +10 -0
- APDrawingGAN2/script/test.sh +2 -0
- APDrawingGAN2/script/test_single.sh +2 -0
- APDrawingGAN2/script/train.sh +3 -0
- APDrawingGAN2/test.py +69 -0
- APDrawingGAN2/train.py +67 -0
- APDrawingGAN2/util/__init__.py +0 -0
- APDrawingGAN2/util/get_data.py +115 -0
- APDrawingGAN2/util/html.py +68 -0
- APDrawingGAN2/util/image_pool.py +32 -0
- APDrawingGAN2/util/util.py +60 -0
- APDrawingGAN2/util/visualizer.py +171 -0
- README.md +1 -0
- app.py +210 -0
- packages.txt +2 -0
- requirements.txt +8 -0
APDrawingGAN2/data/__init__.py
ADDED
@@ -0,0 +1,75 @@
import importlib
import torch.utils.data
from data.base_data_loader import BaseDataLoader
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    # Given the option --dataset_mode [datasetname],
    # the file "data/datasetname_dataset.py"
    # will be imported.
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    # In the file, the class called DatasetNameDataset() will
    # be instantiated. It has to be a subclass of BaseDataset,
    # and it is case-insensitive.
    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
        exit(0)

    return dataset


def get_option_setter(dataset_name):
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt):
    dataset = find_dataset_using_name(opt.dataset_mode)
    instance = dataset()
    instance.initialize(opt)
    print("dataset [%s] was created" % (instance.name()))
    return instance


def CreateDataLoader(opt):
    data_loader = CustomDatasetDataLoader()
    data_loader.initialize(opt)
    return data_loader


# Wrapper class of Dataset class that performs
# multi-threaded data loading
class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = create_dataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,  # in training, serial_batches by default is false, shuffle=true
            num_workers=int(opt.num_threads))

    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
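A minimal sketch of how this loader factory is typically driven, assuming `opt` comes from the option parsers under `options/` (the iteration body below is illustrative only, not part of the commit): `CreateDataLoader` wraps the dataset selected by `--dataset_mode` in a multi-threaded PyTorch `DataLoader`.

    # usage sketch, not part of the commit
    from options.test_options import TestOptions
    from data import CreateDataLoader

    opt = TestOptions().parse()          # assumed options parser; fills dataset_mode, dataroot, batch_size, ...
    data_loader = CreateDataLoader(opt)  # builds CustomDatasetDataLoader -> create_dataset(opt)
    dataset = data_loader.load_data()    # the loader object itself is iterable
    for i, data in enumerate(dataset):   # each `data` is the dict returned by the dataset's __getitem__
        print(i, data['A_paths'])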
APDrawingGAN2/data/aligned_dataset.py
ADDED
@@ -0,0 +1,288 @@
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image
import numpy as np
import cv2
import csv

def getfeats(featpath):
    trans_points = np.empty([5,2],dtype=np.int64)
    with open(featpath, 'r') as csvfile:
        reader = csv.reader(csvfile, delimiter=' ')
        for ind,row in enumerate(reader):
            trans_points[ind,:] = row
    return trans_points

def tocv2(ts):
    img = (ts.numpy()/2+0.5)*255
    img = img.astype('uint8')
    img = np.transpose(img,(1,2,0))
    img = img[:,:,::-1]  # rgb->bgr
    return img

def dt(img):
    if(img.shape[2]==3):
        img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
    # convert to BW
    ret1,thresh1 = cv2.threshold(img,127,255,cv2.THRESH_BINARY)
    ret2,thresh2 = cv2.threshold(img,127,255,cv2.THRESH_BINARY_INV)
    dt1 = cv2.distanceTransform(thresh1,cv2.DIST_L2,5)
    dt2 = cv2.distanceTransform(thresh2,cv2.DIST_L2,5)
    dt1 = dt1/dt1.max()  # ->[0,1]
    dt2 = dt2/dt2.max()
    return dt1, dt2

def getSoft(size,xb,yb,boundwidth=5.0):
    xarray = np.tile(np.arange(0,size[1]),(size[0],1))
    yarray = np.tile(np.arange(0,size[0]),(size[1],1)).transpose()
    cxdists = []
    cydists = []
    for i in range(len(xb)):
        xba = np.tile(xb[i],(size[1],1)).transpose()
        yba = np.tile(yb[i],(size[0],1))
        cxdists.append(np.abs(xarray-xba))
        cydists.append(np.abs(yarray-yba))
    xdist = np.minimum.reduce(cxdists)
    ydist = np.minimum.reduce(cydists)
    manhdist = np.minimum.reduce([xdist,ydist])
    im = (manhdist+1) / (boundwidth+1) * 1.0
    im[im>=1.0] = 1.0
    return im

class AlignedDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        imglist = 'datasets/apdrawing_list/%s/%s.txt' % (opt.phase, opt.dataroot)
        if os.path.exists(imglist):
            lines = open(imglist, 'r').read().splitlines()
            lines = sorted(lines)
            self.AB_paths = [line.split()[0] for line in lines]
            if len(lines[0].split()) == 2:
                self.B_paths = [line.split()[1] for line in lines]
        else:
            self.dir_AB = os.path.join(opt.dataroot, opt.phase)
            self.AB_paths = sorted(make_dataset(self.dir_AB))
        assert(opt.resize_or_crop == 'resize_and_crop')

    def __getitem__(self, index):
        AB_path = self.AB_paths[index]
        AB = Image.open(AB_path).convert('RGB')
        w, h = AB.size
        if w/h == 2:
            w2 = int(w / 2)
            A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
            B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        else:  # if w/h != 2, need B_paths
            A = AB.resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
            B = Image.open(self.B_paths[index]).convert('RGB')
            B = B.resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        A = transforms.ToTensor()(A)
        B = transforms.ToTensor()(B)
        w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
        h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))

        A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]  # C,H,W
        B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]

        A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
        B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)

        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        flipped = False
        if (not self.opt.no_flip) and random.random() < 0.5:
            flipped = True
            idx = [i for i in range(A.size(2) - 1, -1, -1)]
            idx = torch.LongTensor(idx)
            A = A.index_select(2, idx)
            B = B.index_select(2, idx)

        if input_nc == 1:  # RGB to gray
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)

        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)

        item = {'A': A, 'B': B,
                'A_paths': AB_path, 'B_paths': AB_path}

        if self.opt.use_local:
            regions = ['eyel','eyer','nose','mouth']
            basen = os.path.basename(AB_path)[:-4]+'.txt'
            if self.opt.region_enm in [0,1]:
                featdir = self.opt.lm_dir
                featpath = os.path.join(featdir,basen)
                feats = getfeats(featpath)
                if flipped:
                    for i in range(5):
                        feats[i,0] = self.opt.fineSize - feats[i,0] - 1
                    tmp = [feats[0,0],feats[0,1]]
                    feats[0,:] = [feats[1,0],feats[1,1]]
                    feats[1,:] = tmp
                mouth_x = int((feats[3,0]+feats[4,0])/2.0)
                mouth_y = int((feats[3,1]+feats[4,1])/2.0)
                ratio = self.opt.fineSize / 256
                EYE_H = self.opt.EYE_H * ratio
                EYE_W = self.opt.EYE_W * ratio
                NOSE_H = self.opt.NOSE_H * ratio
                NOSE_W = self.opt.NOSE_W * ratio
                MOUTH_H = self.opt.MOUTH_H * ratio
                MOUTH_W = self.opt.MOUTH_W * ratio
                center = torch.IntTensor([[feats[0,0],feats[0,1]-4*ratio],[feats[1,0],feats[1,1]-4*ratio],[feats[2,0],feats[2,1]-NOSE_H/2+16*ratio],[mouth_x,mouth_y]])
                item['center'] = center
                rhs = [int(EYE_H),int(EYE_H),int(NOSE_H),int(MOUTH_H)]
                rws = [int(EYE_W),int(EYE_W),int(NOSE_W),int(MOUTH_W)]
                if self.opt.soft_border:
                    soft_border_mask4 = []
                    for i in range(4):
                        xb = [np.zeros(rhs[i]),np.ones(rhs[i])*(rws[i]-1)]
                        yb = [np.zeros(rws[i]),np.ones(rws[i])*(rhs[i]-1)]
                        soft_border_mask = getSoft([rhs[i],rws[i]],xb,yb)
                        soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0))
                        item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i]
                for i in range(4):
                    item[regions[i]+'_A'] = A[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
                    item[regions[i]+'_B'] = B[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
                    if self.opt.soft_border:
                        item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(int(input_nc/output_nc),1,1)
                        item[regions[i]+'_B'] = item[regions[i]+'_B'] * soft_border_mask4[i]
            if self.opt.compactmask:
                cmasks0 = []
                cmasks = []
                for i in range(4):
                    if flipped and i in [0,1]:
                        cmaskpath = os.path.join(self.opt.cmask_dir,regions[1-i],basen[:-4]+'.png')
                    else:
                        cmaskpath = os.path.join(self.opt.cmask_dir,regions[i],basen[:-4]+'.png')
                    im_cmask = Image.open(cmaskpath)
                    cmask0 = transforms.ToTensor()(im_cmask)
                    if flipped:
                        cmask0 = cmask0.index_select(2, idx)
                    if output_nc == 1 and cmask0.shape[0] == 3:
                        tmp = cmask0[0, ...] * 0.299 + cmask0[1, ...] * 0.587 + cmask0[2, ...] * 0.114
                        cmask0 = tmp.unsqueeze(0)
                    cmask0 = (cmask0 >= 0.5).float()
                    cmasks0.append(cmask0)
                    cmask = cmask0.clone()
                    if self.opt.region_enm in [0,1]:
                        cmask = cmask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
                    elif self.opt.region_enm in [2]:  # need to multiply cmask
                        item[regions[i]+'_A'] = (A/2+0.5) * cmask * 2 - 1
                        item[regions[i]+'_B'] = (B/2+0.5) * cmask * 2 - 1
                    cmasks.append(cmask)
                item['cmaskel'] = cmasks[0]
                item['cmasker'] = cmasks[1]
                item['cmask'] = cmasks[2]
                item['cmaskmo'] = cmasks[3]
            if self.opt.hair_local:
                mask = torch.ones(B.shape)
                if self.opt.region_enm == 0:
                    for i in range(4):
                        mask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] = 0
                    if self.opt.soft_border:
                        imgsize = self.opt.fineSize
                        maskn = mask[0].numpy()
                        masks = [np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize])]
                        masks[0][1:] = maskn[:-1]
                        masks[1][:-1] = maskn[1:]
                        masks[2][:,1:] = maskn[:,:-1]
                        masks[3][:,:-1] = maskn[:,1:]
                        masks2 = [maskn-e for e in masks]
                        bound = np.minimum.reduce(masks2)
                        bound = -bound
                        xb = []
                        yb = []
                        for i in range(4):
                            xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1]
                            ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1]
                            for j in range(2):
                                maskx = bound[:,xbi[j]]
                                masky = bound[ybi[j],:]
                                tmp_a = torch.from_numpy(maskx)*xbi[j].double()
                                tmp_b = torch.from_numpy(1-maskx)
                                xb += [tmp_b*10000 + tmp_a]

                                tmp_a = torch.from_numpy(masky)*ybi[j].double()
                                tmp_b = torch.from_numpy(1-masky)
                                yb += [tmp_b*10000 + tmp_a]
                        soft = 1-getSoft([imgsize,imgsize],xb,yb)
                        soft = torch.Tensor(soft).unsqueeze(0)
                        mask = (torch.ones(mask.shape)-mask)*soft + mask
                elif self.opt.region_enm == 1:
                    for i in range(4):
                        cmask0 = cmasks0[i]
                        rec = torch.zeros(B.shape)
                        rec[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] = 1
                        mask = mask * (torch.ones(B.shape) - cmask0 * rec)
                elif self.opt.region_enm == 2:
                    for i in range(4):
                        cmask0 = cmasks0[i]
                        mask = mask * (torch.ones(B.shape) - cmask0)
                hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * 2 - 1
                hair_B = (B/2+0.5) * mask * 2 - 1
                item['hair_A'] = hair_A
                item['hair_B'] = hair_B
                item['mask'] = mask  # mask out eyes, nose, mouth
            if self.opt.bg_local:
                bgdir = self.opt.bg_dir
                bgpath = os.path.join(bgdir,basen[:-4]+'.png')
                im_bg = Image.open(bgpath)
                mask2 = transforms.ToTensor()(im_bg)  # mask out background
                if flipped:
                    mask2 = mask2.index_select(2, idx)
                mask2 = (mask2 >= 0.5).float()
                hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * mask2.repeat(int(input_nc/output_nc),1,1) * 2 - 1
                hair_B = (B/2+0.5) * mask * mask2 * 2 - 1
                bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(int(input_nc/output_nc),1,1) * 2 - 1
                bg_B = (B/2+0.5) * (torch.ones(mask2.shape)-mask2) * 2 - 1
                item['hair_A'] = hair_A
                item['hair_B'] = hair_B
                item['bg_A'] = bg_A
                item['bg_B'] = bg_B
                item['mask'] = mask
                item['mask2'] = mask2

        if (self.opt.isTrain and self.opt.chamfer_loss):
            if self.opt.which_direction == 'AtoB':
                img = tocv2(B)
            else:
                img = tocv2(A)
            dt1, dt2 = dt(img)
            dt1 = torch.from_numpy(dt1)
            dt2 = torch.from_numpy(dt2)
            dt1 = dt1.unsqueeze(0)
            dt2 = dt2.unsqueeze(0)
            item['dt1gt'] = dt1
            item['dt2gt'] = dt2

        if self.opt.isTrain and self.opt.emphasis_conti_face:
            face_mask_path = os.path.join(self.opt.facemask_dir,basen[:-4]+'.png')
            face_mask = Image.open(face_mask_path)
            face_mask = transforms.ToTensor()(face_mask)  # [0,1]
            if flipped:
                face_mask = face_mask.index_select(2, idx)
            item['face_mask'] = face_mask

        return item

    def __len__(self):
        return len(self.AB_paths)

    def name(self):
        return 'AlignedDataset'
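The dt() helper above prepares the ground-truth distance maps used by the chamfer loss: the drawing is binarized and two L2 distance transforms are computed (distance to the nearest stroke pixel and to the nearest background pixel), each normalized to [0,1]. A self-contained sketch of the same computation on a synthetic stroke image (illustrative only, not part of the commit):

    import cv2
    import numpy as np

    img = np.full((64, 64), 255, np.uint8)           # white background
    cv2.line(img, (10, 10), (50, 50), 0, 2)          # one black stroke
    _, thresh1 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)      # strokes -> 0
    _, thresh2 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)  # strokes -> 255
    dt1 = cv2.distanceTransform(thresh1, cv2.DIST_L2, 5)  # distance to nearest stroke pixel
    dt2 = cv2.distanceTransform(thresh2, cv2.DIST_L2, 5)  # distance to nearest background pixel
    dt1, dt2 = dt1 / dt1.max(), dt2 / dt2.max()           # normalize to [0,1], as in dt() above
    print(dt1.shape, dt1.max(), dt2.max())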
APDrawingGAN2/data/base_data_loader.py
ADDED
@@ -0,0 +1,10 @@
class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt
        pass

    def load_data():
        return None
APDrawingGAN2/data/base_dataset.py
ADDED
@@ -0,0 +1,103 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        pass

    def __len__(self):
        return 0


def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.fineSize]
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'none':
        transform_list.append(transforms.Lambda(
            lambda img: __adjust(img)))
    else:
        raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

# just modify the width and height to be multiple of 4
def __adjust(img):
    ow, oh = img.size

    # the size needs to be a multiple of this number,
    # because going through generator network may change img size
    # and eventually cause size mismatch error
    mult = 4
    if ow % mult == 0 and oh % mult == 0:
        return img
    w = (ow - 1) // mult
    w = (w + 1) * mult
    h = (oh - 1) // mult
    h = (h + 1) * mult

    if ow != w or oh != h:
        __print_size_warning(ow, oh, w, h)

    return img.resize((w, h), Image.BICUBIC)


def __scale_width(img, target_width):
    ow, oh = img.size

    # the size needs to be a multiple of this number,
    # because going through generator network may change img size
    # and eventually cause size mismatch error
    mult = 4
    assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
    if (ow == target_width and oh % mult == 0):
        return img
    w = target_width
    target_height = int(target_width * oh / ow)
    m = (target_height - 1) // mult
    h = (m + 1) * mult

    if target_height != h:
        __print_size_warning(target_width, target_height, w, h)

    return img.resize((w, h), Image.BICUBIC)


def __print_size_warning(ow, oh, w, h):
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True
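A small sketch of get_transform in isolation, with an illustrative option set (only the fields the function reads are filled in; real values come from the parsers under options/): with resize_or_crop='resize_and_crop' it resizes to loadSize x fineSize, random-crops to fineSize, optionally flips during training, and normalizes to [-1, 1].

    # illustrative sketch, not part of the commit
    from argparse import Namespace
    from PIL import Image
    from data.base_dataset import get_transform

    opt = Namespace(resize_or_crop='resize_and_crop', loadSize=286, fineSize=256,
                    isTrain=True, no_flip=False)     # assumed values for illustration
    transform = get_transform(opt)
    img = Image.new('RGB', (512, 512))               # stand-in for a loaded photo
    tensor = transform(img)                          # shape (3, 256, 256), values in [-1, 1]
    print(tensor.shape, tensor.min().item(), tensor.max().item())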
APDrawingGAN2/data/image_folder.py
ADDED
@@ -0,0 +1,68 @@
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
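make_dataset walks a directory (including subdirectories) and returns every path with a recognized image extension; ImageFolder wraps that list as a PyTorch dataset. A brief usage sketch, where the directory name is a hypothetical example (not part of the commit):

    # illustrative sketch, not part of the commit
    import torchvision.transforms as transforms
    from data.image_folder import make_dataset, ImageFolder

    paths = make_dataset('datasets/test_imgs')       # hypothetical folder of photos
    print(len(paths), paths[:3])

    folder = ImageFolder('datasets/test_imgs', transform=transforms.ToTensor(), return_paths=True)
    img, path = folder[0]                            # (C, H, W) tensor plus its source path
    print(img.shape, path)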
APDrawingGAN2/data/single_dataset.py
ADDED
@@ -0,0 +1,176 @@
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import numpy as np
import csv
import torch
import torchvision.transforms as transforms

def getfeats(featpath):
    trans_points = np.empty([5,2],dtype=np.int64)
    with open(featpath, 'r') as csvfile:
        reader = csv.reader(csvfile, delimiter=' ')
        for ind,row in enumerate(reader):
            trans_points[ind,:] = row
    return trans_points

def getSoft(size,xb,yb,boundwidth=5.0):
    xarray = np.tile(np.arange(0,size[1]),(size[0],1))
    yarray = np.tile(np.arange(0,size[0]),(size[1],1)).transpose()
    cxdists = []
    cydists = []
    for i in range(len(xb)):
        xba = np.tile(xb[i],(size[1],1)).transpose()
        yba = np.tile(yb[i],(size[0],1))
        cxdists.append(np.abs(xarray-xba))
        cydists.append(np.abs(yarray-yba))
    xdist = np.minimum.reduce(cxdists)
    ydist = np.minimum.reduce(cydists)
    manhdist = np.minimum.reduce([xdist,ydist])
    im = (manhdist+1) / (boundwidth+1) * 1.0
    im[im>=1.0] = 1.0
    return im

class SingleDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot)
        imglist = 'datasets/apdrawing_list/%s/%s.txt' % (opt.phase, opt.dataroot)
        if os.path.exists(imglist):
            lines = open(imglist, 'r').read().splitlines()
            self.A_paths = sorted(lines)
        else:
            self.A_paths = make_dataset(self.dir_A)
            self.A_paths = sorted(self.A_paths)
        self.transform = get_transform(opt)  # this function uses NO_FLIP; aligned dataset do not use this, aligned dataset manually transform

    def __getitem__(self, index):
        A_path = self.A_paths[index]
        A_img = Image.open(A_path).convert('RGB')
        A = self.transform(A_img)
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        if input_nc == 1:  # RGB to gray
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)

        item = {'A': A, 'A_paths': A_path}

        if self.opt.use_local:
            regions = ['eyel','eyer','nose','mouth']
            basen = os.path.basename(A_path)[:-4]+'.txt'
            featdir = self.opt.lm_dir
            featpath = os.path.join(featdir,basen)
            feats = getfeats(featpath)
            mouth_x = int((feats[3,0]+feats[4,0])/2.0)
            mouth_y = int((feats[3,1]+feats[4,1])/2.0)
            ratio = self.opt.fineSize / 256
            EYE_H = self.opt.EYE_H * ratio
            EYE_W = self.opt.EYE_W * ratio
            NOSE_H = self.opt.NOSE_H * ratio
            NOSE_W = self.opt.NOSE_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            center = torch.IntTensor([[feats[0,0],feats[0,1]-4*ratio],[feats[1,0],feats[1,1]-4*ratio],[feats[2,0],feats[2,1]-NOSE_H/2+16*ratio],[mouth_x,mouth_y]])
            item['center'] = center
            rhs = [int(EYE_H),int(EYE_H),int(NOSE_H),int(MOUTH_H)]
            rws = [int(EYE_W),int(EYE_W),int(NOSE_W),int(MOUTH_W)]
            if self.opt.soft_border:
                soft_border_mask4 = []
                for i in range(4):
                    xb = [np.zeros(rhs[i]),np.ones(rhs[i])*(rws[i]-1)]
                    yb = [np.zeros(rws[i]),np.ones(rws[i])*(rhs[i]-1)]
                    soft_border_mask = getSoft([rhs[i],rws[i]],xb,yb)
                    soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0))
                    item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i]
            for i in range(4):
                item[regions[i]+'_A'] = A[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
                if self.opt.soft_border:
                    item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(int(input_nc/output_nc),1,1)
            if self.opt.compactmask:
                cmasks0 = []
                cmasks = []
                for i in range(4):
                    cmaskpath = os.path.join(self.opt.cmask_dir,regions[i],basen[:-4]+'.png')
                    im_cmask = Image.open(cmaskpath)
                    cmask0 = transforms.ToTensor()(im_cmask)
                    if output_nc == 1 and cmask0.shape[0] == 3:
                        tmp = cmask0[0, ...] * 0.299 + cmask0[1, ...] * 0.587 + cmask0[2, ...] * 0.114
                        cmask0 = tmp.unsqueeze(0)
                    cmask0 = (cmask0 >= 0.5).float()
                    cmasks0.append(cmask0)
                    cmask = cmask0.clone()
                    cmask = cmask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2]
                    cmasks.append(cmask)
                item['cmaskel'] = cmasks[0]
                item['cmasker'] = cmasks[1]
                item['cmask'] = cmasks[2]
                item['cmaskmo'] = cmasks[3]
            if self.opt.hair_local:
                output_nc = self.opt.output_nc
                mask = torch.ones([output_nc,A.shape[1],A.shape[2]])
                for i in range(4):
                    mask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] = 0
                if self.opt.soft_border:
                    imgsize = self.opt.fineSize
                    maskn = mask[0].numpy()
                    masks = [np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize])]
                    masks[0][1:] = maskn[:-1]
                    masks[1][:-1] = maskn[1:]
                    masks[2][:,1:] = maskn[:,:-1]
                    masks[3][:,:-1] = maskn[:,1:]
                    masks2 = [maskn-e for e in masks]
                    bound = np.minimum.reduce(masks2)
                    bound = -bound
                    xb = []
                    yb = []
                    for i in range(4):
                        xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1]
                        ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1]
                        for j in range(2):
                            maskx = bound[:,xbi[j]]
                            masky = bound[ybi[j],:]
                            tmp_a = torch.from_numpy(maskx)*xbi[j].double()
                            tmp_b = torch.from_numpy(1-maskx)
                            xb += [tmp_b*10000 + tmp_a]

                            tmp_a = torch.from_numpy(masky)*ybi[j].double()
                            tmp_b = torch.from_numpy(1-masky)
                            yb += [tmp_b*10000 + tmp_a]
                    soft = 1-getSoft([imgsize,imgsize],xb,yb)
                    soft = torch.Tensor(soft).unsqueeze(0)
                    mask = (torch.ones(mask.shape)-mask)*soft + mask
                hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * 2 - 1
                item['hair_A'] = hair_A
                item['mask'] = mask
            if self.opt.bg_local:
                bgdir = self.opt.bg_dir
                bgpath = os.path.join(bgdir,basen[:-4]+'.png')
                im_bg = Image.open(bgpath)
                mask2 = transforms.ToTensor()(im_bg)  # mask out background
                mask2 = (mask2 >= 0.5).float()
                hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * mask2.repeat(int(input_nc/output_nc),1,1) * 2 - 1
                bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(int(input_nc/output_nc),1,1) * 2 - 1
                item['hair_A'] = hair_A
                item['bg_A'] = bg_A
                item['mask'] = mask
                item['mask2'] = mask2

        return item

    def __len__(self):
        return len(self.A_paths)

    def name(self):
        return 'SingleImageDataset'
APDrawingGAN2/docs/tips.md
ADDED
@@ -0,0 +1,8 @@
## Training/test Tips
- Flags: see `options/train_options.py` and `options/base_options.py` for the training flags; see `options/test_options.py` and `options/base_options.py` for the test flags. The default values of these options are sometimes adjusted in the model files.

- CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batch_size 32`) to benefit from multiple GPUs.

- Visualization: during training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom`, set `--display_id -1`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`.

- Fine-tuning/Resume training: to fine-tune a pre-trained model, or resume the previous training, use the `--continue_train` flag. The program will then load the model based on `which_epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count <int>` to specify a different starting epoch count.
APDrawingGAN2/models/__init__.py
ADDED
@@ -0,0 +1,39 @@
import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    # Given the option --model [modelname],
    # the file "models/modelname_model.py"
    # will be imported.
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    # In the file, the class called ModelNameModel() will
    # be instantiated. It has to be a subclass of BaseModel,
    # and it is case-insensitive.
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    model = find_model_using_name(opt.model)
    instance = model()
    instance.initialize(opt)
    print("model [%s] was created" % (instance.name()))
    return instance
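This factory mirrors data/__init__.py: create_model imports models/<name>_model.py for --model <name> and instantiates the matching BaseModel subclass (for --model apdrawingpp_style that is APDrawingPPStyleModel below). A sketch of how a training script would typically drive it, assuming the usual pix2pix-style BaseModel helpers (setup, optimize_parameters) defined in models/base_model.py but not shown in this excerpt:

    # illustrative sketch, not part of the commit
    from options.train_options import TrainOptions
    from data import CreateDataLoader
    from models import create_model

    opt = TrainOptions().parse()         # assumed options parser; --model selects models/<model>_model.py
    dataset = CreateDataLoader(opt).load_data()
    model = create_model(opt)
    model.setup(opt)                     # assumed helper: load/print networks, set schedulers
    for data in dataset:
        model.set_input(data)
        model.optimize_parameters()      # assumed helper: one generator/discriminator update step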
APDrawingGAN2/models/apdrawingpp_style_model.py
ADDED
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from util.image_pool import ImagePool
|
3 |
+
from .base_model import BaseModel
|
4 |
+
from . import networks
|
5 |
+
import os
|
6 |
+
import math
|
7 |
+
|
8 |
+
W = 11
|
9 |
+
aa = int(math.floor(512./W))
|
10 |
+
res = 512 - W*aa
|
11 |
+
|
12 |
+
|
13 |
+
def padpart(A,part,centers,opt,device):
|
14 |
+
IMAGE_SIZE = opt.fineSize
|
15 |
+
bs,nc,_,_ = A.shape
|
16 |
+
ratio = IMAGE_SIZE / 256
|
17 |
+
NOSE_W = opt.NOSE_W * ratio
|
18 |
+
NOSE_H = opt.NOSE_H * ratio
|
19 |
+
EYE_W = opt.EYE_W * ratio
|
20 |
+
EYE_H = opt.EYE_H * ratio
|
21 |
+
MOUTH_W = opt.MOUTH_W * ratio
|
22 |
+
MOUTH_H = opt.MOUTH_H * ratio
|
23 |
+
A_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(device)
|
24 |
+
padvalue = -1 # black
|
25 |
+
for i in range(bs):
|
26 |
+
center = centers[i]
|
27 |
+
if part == 'nose':
|
28 |
+
A_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)),padvalue)(A[i])
|
29 |
+
elif part == 'eyel':
|
30 |
+
A_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)),padvalue)(A[i])
|
31 |
+
elif part == 'eyer':
|
32 |
+
A_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)),padvalue)(A[i])
|
33 |
+
elif part == 'mouth':
|
34 |
+
A_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)),padvalue)(A[i])
|
35 |
+
return A_p
|
36 |
+
|
37 |
+
import numpy as np
|
38 |
+
def nonlinearDt(dt,type='atan',xmax=torch.Tensor([10.0])):#dt in [0,1], first multiply xmax(>1), then remap to [0,1]
|
39 |
+
if type == 'atan':
|
40 |
+
nldt = torch.atan(dt*xmax) / torch.atan(xmax)
|
41 |
+
elif type == 'sigmoid':
|
42 |
+
nldt = (torch.sigmoid(dt*xmax)-0.5) / (torch.sigmoid(xmax)-0.5)
|
43 |
+
elif type == 'tanh':
|
44 |
+
nldt = torch.tanh(dt*xmax) / torch.tanh(xmax)
|
45 |
+
elif type == 'pow':
|
46 |
+
nldt = torch.pow(dt*xmax,2) / torch.pow(xmax,2)
|
47 |
+
elif type == 'exp':
|
48 |
+
if xmax.item()>1:
|
49 |
+
xmax = xmax / 3
|
50 |
+
nldt = (torch.exp(dt*xmax)-1) / (torch.exp(xmax)-1)
|
51 |
+
#print("remap dt:", type, xmax.item())
|
52 |
+
return nldt
|
53 |
+
|
54 |
+
class APDrawingPPStyleModel(BaseModel):
|
55 |
+
def name(self):
|
56 |
+
return 'APDrawingPPStyleModel'
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def modify_commandline_options(parser, is_train=True):
|
60 |
+
|
61 |
+
# changing the default values to match the pix2pix paper
|
62 |
+
# (https://phillipi.github.io/pix2pix/)
|
63 |
+
parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')# no_lsgan=True, use_lsgan=False
|
64 |
+
parser.set_defaults(dataset_mode='aligned')
|
65 |
+
parser.set_defaults(auxiliary_root='auxiliaryeye2o')
|
66 |
+
parser.set_defaults(use_local=True, hair_local=True, bg_local=True)
|
67 |
+
parser.set_defaults(discriminator_local=True, gan_loss_strategy=2)
|
68 |
+
parser.set_defaults(chamfer_loss=True, dt_nonlinear='exp', lambda_chamfer=0.35, lambda_chamfer2=0.35)
|
69 |
+
parser.set_defaults(nose_ae=True, others_ae=True, compactmask=True, MOUTH_H=56)
|
70 |
+
parser.set_defaults(soft_border=1, batch_size=1, save_epoch_freq=25)
|
71 |
+
parser.add_argument('--nnG_hairc', type=int, default=6, help='nnG for hair classifier')
|
72 |
+
parser.add_argument('--use_resnet', action='store_true', help='use resnet for generator')
|
73 |
+
parser.add_argument('--regarch', type=int, default=4, help='architecture for netRegressor')
|
74 |
+
if is_train:
|
75 |
+
parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
|
76 |
+
parser.add_argument('--lambda_local', type=float, default=25.0, help='weight for Local loss')
|
77 |
+
parser.set_defaults(netG_dt='unet_512')
|
78 |
+
parser.set_defaults(netG_line='unet_512')
|
79 |
+
|
80 |
+
return parser
|
81 |
+
|
82 |
+
def initialize(self, opt):
|
83 |
+
BaseModel.initialize(self, opt)
|
84 |
+
self.isTrain = opt.isTrain
|
85 |
+
# specify the training losses you want to print out. The program will call base_model.get_current_losses
|
86 |
+
self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
|
87 |
+
if self.isTrain and self.opt.no_l1_loss:
|
88 |
+
self.loss_names = ['G_GAN', 'D_real', 'D_fake']
|
89 |
+
if self.isTrain and self.opt.use_local and not self.opt.no_G_local_loss:
|
90 |
+
self.loss_names.append('G_local')
|
91 |
+
self.loss_names.append('G_hair_local')
|
92 |
+
self.loss_names.append('G_bg_local')
|
93 |
+
if self.isTrain and self.opt.discriminator_local:
|
94 |
+
self.loss_names.append('D_real_local')
|
95 |
+
self.loss_names.append('D_fake_local')
|
96 |
+
self.loss_names.append('G_GAN_local')
|
97 |
+
if self.isTrain and self.opt.chamfer_loss:
|
98 |
+
self.loss_names.append('G_chamfer')
|
99 |
+
self.loss_names.append('G_chamfer2')
|
100 |
+
if self.isTrain and self.opt.continuity_loss:
|
101 |
+
self.loss_names.append('G_continuity')
|
102 |
+
self.loss_names.append('G')
|
103 |
+
print('loss_names', self.loss_names)
|
104 |
+
# specify the images you want to save/display. The program will call base_model.get_current_visuals
|
105 |
+
self.visual_names = ['real_A', 'fake_B', 'real_B']
|
106 |
+
if self.opt.use_local:
|
107 |
+
self.visual_names += ['fake_B0', 'fake_B1']
|
108 |
+
self.visual_names += ['fake_B_hair', 'real_B_hair', 'real_A_hair']
|
109 |
+
self.visual_names += ['fake_B_bg', 'real_B_bg', 'real_A_bg']
|
110 |
+
if self.opt.region_enm in [0,1]:
|
111 |
+
if self.opt.nose_ae:
|
112 |
+
self.visual_names += ['fake_B_nose_v','fake_B_nose_v1','fake_B_nose_v2','cmask1no']
|
113 |
+
if self.opt.others_ae:
|
114 |
+
self.visual_names += ['fake_B_eyel_v','fake_B_eyel_v1','fake_B_eyel_v2','cmask1el']
|
115 |
+
self.visual_names += ['fake_B_eyer_v','fake_B_eyer_v1','fake_B_eyer_v2','cmask1er']
|
116 |
+
self.visual_names += ['fake_B_mouth_v','fake_B_mouth_v1','fake_B_mouth_v2','cmask1mo']
|
117 |
+
elif self.opt.region_enm in [2]:
|
118 |
+
self.visual_names += ['fake_B_nose','fake_B_eyel','fake_B_eyer','fake_B_mouth']
|
119 |
+
if self.isTrain and self.opt.chamfer_loss:
|
120 |
+
self.visual_names += ['dt1', 'dt2']
|
121 |
+
self.visual_names += ['dt1gt', 'dt2gt']
|
122 |
+
if self.isTrain and self.opt.soft_border:
|
123 |
+
self.visual_names += ['mask']
|
124 |
+
if not self.isTrain and self.opt.save2:
|
125 |
+
self.visual_names = ['real_A', 'fake_B']
|
126 |
+
print('visuals', self.visual_names)
|
127 |
+
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
|
128 |
+
self.auxiliary_model_names = []
|
129 |
+
if self.isTrain:
|
130 |
+
self.model_names = ['G', 'D']
|
131 |
+
if self.opt.discriminator_local:
|
132 |
+
self.model_names += ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG']
|
133 |
+
# auxiliary nets for loss calculation
|
134 |
+
if self.opt.chamfer_loss:
|
135 |
+
self.auxiliary_model_names += ['DT1', 'DT2']
|
136 |
+
self.auxiliary_model_names += ['Line1', 'Line2']
|
137 |
+
if self.opt.continuity_loss:
|
138 |
+
self.auxiliary_model_names += ['Regressor']
|
139 |
+
else: # during test time, only load Gs
|
140 |
+
self.model_names = ['G']
|
141 |
+
if self.opt.test_continuity_loss:
|
142 |
+
self.auxiliary_model_names += ['Regressor']
|
143 |
+
if self.opt.use_local:
|
144 |
+
self.model_names += ['GLEyel','GLEyer','GLNose','GLMouth','GLHair','GLBG','GCombine']
|
145 |
+
self.auxiliary_model_names += ['CLm','CLh']
|
146 |
+
# auxiliary nets for local output refinement
|
147 |
+
if self.opt.nose_ae:
|
148 |
+
self.auxiliary_model_names += ['AE']
|
149 |
+
if self.opt.others_ae:
|
150 |
+
self.auxiliary_model_names += ['AEel','AEer','AEmowhite','AEmoblack']
|
151 |
+
print('model_names', self.model_names)
|
152 |
+
print('auxiliary_model_names', self.auxiliary_model_names)
|
153 |
+
# load/define networks
|
154 |
+
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
|
155 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
156 |
+
opt.nnG)
|
157 |
+
print('netG', opt.netG)
|
158 |
+
|
159 |
+
if self.isTrain:
|
160 |
+
use_sigmoid = opt.no_lsgan
|
161 |
+
self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
162 |
+
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
|
163 |
+
print('netD', opt.netD, opt.n_layers_D)
|
164 |
+
if self.opt.discriminator_local:
|
165 |
+
self.netDLEyel = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
166 |
+
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
|
167 |
+
self.netDLEyer = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
168 |
+
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
|
169 |
+
self.netDLNose = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
170 |
+
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
|
171 |
+
self.netDLMouth = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
172 |
+
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
|
173 |
+
self.netDLHair = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
174 |
+
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
|
175 |
+
self.netDLBG = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
176 |
+
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
|
177 |
+
|
178 |
+
|
179 |
+
if self.opt.use_local:
|
180 |
+
netlocal1 = 'partunet' if self.opt.use_resnet == 0 else 'resnet_nblocks'
|
181 |
+
netlocal2 = 'partunet2' if self.opt.use_resnet == 0 else 'resnet_6blocks'
|
182 |
+
netlocal2_style = 'partunet2style' if self.opt.use_resnet == 0 else 'resnet_style2_6blocks'
|
183 |
+
self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
|
184 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
|
185 |
+
self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
|
186 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
|
187 |
+
self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
|
188 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
|
189 |
+
self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
|
190 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
|
191 |
+
self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2_style, opt.norm,
|
192 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4,
|
193 |
+
extra_channel=3)
|
194 |
+
self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2, opt.norm,
|
195 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4)
|
196 |
+
# by default combiner_type is combiner, which uses resnet
|
197 |
+
print('combiner_type', self.opt.combiner_type)
|
198 |
+
self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, self.opt.combiner_type, opt.norm,
|
199 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 2)
|
200 |
+
# auxiliary classifiers for mouth and hair
|
201 |
+
ratio = self.opt.fineSize / 256
|
202 |
+
self.MOUTH_H = int(self.opt.MOUTH_H * ratio)
|
203 |
+
self.MOUTH_W = int(self.opt.MOUTH_W * ratio)
|
204 |
+
self.netCLm = networks.define_G(opt.input_nc, 2, opt.ngf, 'classifier', opt.norm,
|
205 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
206 |
+
nnG = 3, ae_h = self.MOUTH_H, ae_w = self.MOUTH_W)
|
207 |
+
self.netCLh = networks.define_G(opt.input_nc, 3, opt.ngf, 'classifier', opt.norm,
|
208 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
209 |
+
nnG = opt.nnG_hairc, ae_h = opt.fineSize, ae_w = opt.fineSize)
|
210 |
+
|
211 |
+
|
212 |
+
if self.isTrain:
|
213 |
+
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizers = []
            if not self.opt.use_local:
                print('G_params 1 components')
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            else:
                G_params = list(self.netG.parameters()) + list(self.netGLEyel.parameters()) + list(self.netGLEyer.parameters()) + list(self.netGLNose.parameters()) + list(self.netGLMouth.parameters()) + list(self.netGCombine.parameters()) + list(self.netGLHair.parameters()) + list(self.netGLBG.parameters())
                print('G_params 8 components')
                self.optimizer_G = torch.optim.Adam(G_params,
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))

            if not self.opt.discriminator_local:
                print('D_params 1 components')
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            else:  # self.opt.discriminator_local == True
                D_params = list(self.netD.parameters()) + list(self.netDLEyel.parameters()) + list(self.netDLEyer.parameters()) + list(self.netDLNose.parameters()) + list(self.netDLMouth.parameters()) + list(self.netDLHair.parameters()) + list(self.netDLBG.parameters())
                print('D_params 7 components')
                self.optimizer_D = torch.optim.Adam(D_params,
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        # ================================== auxiliary nets (loaded, parameters fixed) =============================
        if self.opt.use_local and self.opt.nose_ae:
            ratio = self.opt.fineSize / 256
            NOSE_H = self.opt.NOSE_H * ratio
            NOSE_W = self.opt.NOSE_W * ratio
            self.netAE = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                           not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                           latent_dim=self.opt.ae_latentno, ae_h=NOSE_H, ae_w=NOSE_W)
            self.set_requires_grad(self.netAE, False)
        if self.opt.use_local and self.opt.others_ae:
            ratio = self.opt.fineSize / 256
            EYE_H = self.opt.EYE_H * ratio
            EYE_W = self.opt.EYE_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            self.netAEel = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                             not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                             latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
            self.netAEer = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                             not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                             latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
            self.netAEmowhite = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                                  latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
            self.netAEmoblack = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                                  latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
            self.set_requires_grad(self.netAEel, False)
            self.set_requires_grad(self.netAEer, False)
            self.set_requires_grad(self.netAEmowhite, False)
            self.set_requires_grad(self.netAEmoblack, False)

        if self.isTrain and self.opt.continuity_loss:
            self.nc = 1
            self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm,
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p,
                                                  nnG=opt.regarch)
            self.set_requires_grad(self.netRegressor, False)

        if self.isTrain and self.opt.chamfer_loss:
            self.nc = 1
            self.netDT1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
            self.netDT2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
            self.set_requires_grad(self.netDT1, False)
            self.set_requires_grad(self.netDT2, False)
            self.netLine1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm,
                                              not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
            self.netLine2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm,
                                              not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
            self.set_requires_grad(self.netLine1, False)
            self.set_requires_grad(self.netLine2, False)

        # ================================== for test (nets loaded, parameters fixed) =============================
        if not self.isTrain and self.opt.test_continuity_loss:
            self.nc = 1
            self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm,
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                                  nnG=opt.regarch)
            self.set_requires_grad(self.netRegressor, False)
    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        self.batch_size = len(self.image_paths)
        if self.opt.use_local:
            self.real_A_eyel = input['eyel_A'].to(self.device)
            self.real_A_eyer = input['eyer_A'].to(self.device)
            self.real_A_nose = input['nose_A'].to(self.device)
            self.real_A_mouth = input['mouth_A'].to(self.device)
            self.real_B_eyel = input['eyel_B'].to(self.device)
            self.real_B_eyer = input['eyer_B'].to(self.device)
            self.real_B_nose = input['nose_B'].to(self.device)
            self.real_B_mouth = input['mouth_B'].to(self.device)
            if self.opt.region_enm in [0, 1]:
                self.center = input['center']
            if self.opt.soft_border:
                self.softel = input['soft_eyel_mask'].to(self.device)
                self.softer = input['soft_eyer_mask'].to(self.device)
                self.softno = input['soft_nose_mask'].to(self.device)
                self.softmo = input['soft_mouth_mask'].to(self.device)
            if self.opt.compactmask:
                self.cmask = input['cmask'].to(self.device)
                self.cmask1 = self.cmask * 2 - 1  # [0,1] -> [-1,1]
                self.cmaskel = input['cmaskel'].to(self.device)
                self.cmask1el = self.cmaskel * 2 - 1
                self.cmasker = input['cmasker'].to(self.device)
                self.cmask1er = self.cmasker * 2 - 1
                self.cmaskmo = input['cmaskmo'].to(self.device)
                self.cmask1mo = self.cmaskmo * 2 - 1
            self.real_A_hair = input['hair_A'].to(self.device)
            self.real_B_hair = input['hair_B'].to(self.device)
            self.mask = input['mask'].to(self.device)    # mask for non-eyes/nose/mouth
            self.mask2 = input['mask2'].to(self.device)  # mask for non-bg
            self.real_A_bg = input['bg_A'].to(self.device)
            self.real_B_bg = input['bg_B'].to(self.device)
        if self.isTrain and self.opt.chamfer_loss:
            self.dt1gt = input['dt1gt'].to(self.device)
            self.dt2gt = input['dt2gt'].to(self.device)
        if self.isTrain and self.opt.emphasis_conti_face:
            self.face_mask = input['face_mask'].cuda(self.gpu_ids_p[0])
    def getonehot(self, outputs, classes):
        [maxv, index] = torch.max(outputs, 1)
        y = torch.unsqueeze(index, 1)
        onehot = torch.FloatTensor(self.batch_size, classes).to(self.device)
        onehot.zero_()
        onehot.scatter_(1, y, 1)
        return onehot
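    # Illustrative sketch (standalone usage, not part of the original file): getonehot() above
    # takes the argmax of the classifier scores along dim 1 and scatters a 1 at that index,
    # producing one one-hot row per sample. The same pattern with made-up logits:
    #
    #     import torch
    #     outputs = torch.tensor([[0.2, 1.3], [2.0, -1.0]])
    #     _, index = torch.max(outputs, 1)
    #     onehot = torch.zeros(2, 2).scatter_(1, index.unsqueeze(1), 1.0)
    #     # -> tensor([[0., 1.], [1., 0.]])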
    def forward(self):
        if not self.opt.use_local:
            self.fake_B = self.netG(self.real_A)
        else:
            self.fake_B0 = self.netG(self.real_A)
            # EYES, MOUTH
            outputs1 = self.netCLm(self.real_A_mouth)
            onehot1 = self.getonehot(outputs1, 2)

            if not self.opt.others_ae:
                fake_B_eyel = self.netGLEyel(self.real_A_eyel)
                fake_B_eyer = self.netGLEyer(self.real_A_eyer)
                fake_B_mouth = self.netGLMouth(self.real_A_mouth)
            else:  # use AEs that only contain the compact region, need cmask!
                self.fake_B_eyel1 = self.netGLEyel(self.real_A_eyel)
                self.fake_B_eyer1 = self.netGLEyer(self.real_A_eyer)
                self.fake_B_mouth1 = self.netGLMouth(self.real_A_mouth)
                self.fake_B_eyel2, _ = self.netAEel(self.fake_B_eyel1)
                self.fake_B_eyer2, _ = self.netAEer(self.fake_B_eyer1)
                # USE 2 AEs for the mouth, selected per sample by the mouth classifier
                self.fake_B_mouth2 = torch.FloatTensor(self.batch_size, self.opt.output_nc, self.MOUTH_H, self.MOUTH_W).to(self.device)
                for i in range(self.batch_size):
                    if onehot1[i][0] == 1:
                        self.fake_B_mouth2[i], _ = self.netAEmowhite(self.fake_B_mouth1[i].unsqueeze(0))
                        # print('AEmowhite')
                    elif onehot1[i][1] == 1:
                        self.fake_B_mouth2[i], _ = self.netAEmoblack(self.fake_B_mouth1[i].unsqueeze(0))
                        # print('AEmoblack')
                fake_B_eyel = self.add_with_mask(self.fake_B_eyel2, self.fake_B_eyel1, self.cmaskel)
                fake_B_eyer = self.add_with_mask(self.fake_B_eyer2, self.fake_B_eyer1, self.cmasker)
                fake_B_mouth = self.add_with_mask(self.fake_B_mouth2, self.fake_B_mouth1, self.cmaskmo)
            # NOSE
            if not self.opt.nose_ae:
                fake_B_nose = self.netGLNose(self.real_A_nose)
            else:  # use an AE that only contains the compact region, need cmask!
                self.fake_B_nose1 = self.netGLNose(self.real_A_nose)
                self.fake_B_nose2, _ = self.netAE(self.fake_B_nose1)
                fake_B_nose = self.add_with_mask(self.fake_B_nose2, self.fake_B_nose1, self.cmask)

            # for visuals and later local loss
            if self.opt.region_enm in [0, 1]:
                self.fake_B_nose = fake_B_nose
                self.fake_B_eyel = fake_B_eyel
                self.fake_B_eyer = fake_B_eyer
                self.fake_B_mouth = fake_B_mouth
                # for soft borders of the 4 rectangular facial features
                if self.opt.region_enm == 0 and self.opt.soft_border:
                    self.fake_B_nose = self.masked(fake_B_nose, self.softno)
                    self.fake_B_eyel = self.masked(fake_B_eyel, self.softel)
                    self.fake_B_eyer = self.masked(fake_B_eyer, self.softer)
                    self.fake_B_mouth = self.masked(fake_B_mouth, self.softmo)
            elif self.opt.region_enm in [2]:  # need to multiply cmask
                self.fake_B_nose = self.masked(fake_B_nose, self.cmask)
                self.fake_B_eyel = self.masked(fake_B_eyel, self.cmaskel)
                self.fake_B_eyer = self.masked(fake_B_eyer, self.cmasker)
                self.fake_B_mouth = self.masked(fake_B_mouth, self.cmaskmo)

            # HAIR, BG AND PARTCOMBINE
            outputs2 = self.netCLh(self.real_A_hair)
            onehot2 = self.getonehot(outputs2, 3)

            if not self.isTrain:
                opt = self.opt
                if opt.imagefolder == 'images':
                    file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'styleonehot.txt')
                else:
                    file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), opt.imagefolder, 'styleonehot.txt')
                message = '%s [%d %d] [%d %d %d]' % (self.image_paths[0], onehot1[0][0], onehot1[0][1],
                                                     onehot2[0][0], onehot2[0][1], onehot2[0][2])
                with open(file_name, 'a+') as s_file:
                    s_file.write(message)
                    s_file.write('\n')

            fake_B_hair = self.netGLHair(self.real_A_hair, onehot2)
            fake_B_bg = self.netGLBG(self.real_A_bg)
            self.fake_B_hair = self.masked(fake_B_hair, self.mask * self.mask2)
            self.fake_B_bg = self.masked(fake_B_bg, self.inverse_mask(self.mask2))
            if not self.opt.compactmask:
                self.fake_B1 = self.partCombiner2_bg(fake_B_eyel, fake_B_eyer, fake_B_nose, fake_B_mouth, fake_B_hair, fake_B_bg, self.mask * self.mask2, self.inverse_mask(self.mask2), self.opt.comb_op)
            else:
                self.fake_B1 = self.partCombiner2_bg(fake_B_eyel, fake_B_eyer, fake_B_nose, fake_B_mouth, fake_B_hair, fake_B_bg, self.mask * self.mask2, self.inverse_mask(self.mask2), self.opt.comb_op, self.opt.region_enm, self.cmaskel, self.cmasker, self.cmask, self.cmaskmo)

            self.fake_B = self.netGCombine(torch.cat([self.fake_B0, self.fake_B1], 1))

            # for AE visuals
            if self.opt.region_enm in [0, 1]:
                if self.opt.nose_ae:
                    self.fake_B_nose_v = padpart(self.fake_B_nose, 'nose', self.center, self.opt, self.device)
                    self.fake_B_nose_v1 = padpart(self.fake_B_nose1, 'nose', self.center, self.opt, self.device)
                    self.fake_B_nose_v2 = padpart(self.fake_B_nose2, 'nose', self.center, self.opt, self.device)
                    self.cmask1no = padpart(self.cmask1, 'nose', self.center, self.opt, self.device)
                if self.opt.others_ae:
                    self.fake_B_eyel_v = padpart(self.fake_B_eyel, 'eyel', self.center, self.opt, self.device)
                    self.fake_B_eyel_v1 = padpart(self.fake_B_eyel1, 'eyel', self.center, self.opt, self.device)
                    self.fake_B_eyel_v2 = padpart(self.fake_B_eyel2, 'eyel', self.center, self.opt, self.device)
                    self.cmask1el = padpart(self.cmask1el, 'eyel', self.center, self.opt, self.device)
                    self.fake_B_eyer_v = padpart(self.fake_B_eyer, 'eyer', self.center, self.opt, self.device)
                    self.fake_B_eyer_v1 = padpart(self.fake_B_eyer1, 'eyer', self.center, self.opt, self.device)
                    self.fake_B_eyer_v2 = padpart(self.fake_B_eyer2, 'eyer', self.center, self.opt, self.device)
                    self.cmask1er = padpart(self.cmask1er, 'eyer', self.center, self.opt, self.device)
                    self.fake_B_mouth_v = padpart(self.fake_B_mouth, 'mouth', self.center, self.opt, self.device)
                    self.fake_B_mouth_v1 = padpart(self.fake_B_mouth1, 'mouth', self.center, self.opt, self.device)
                    self.fake_B_mouth_v2 = padpart(self.fake_B_mouth2, 'mouth', self.center, self.opt, self.device)
                    self.cmask1mo = padpart(self.cmask1mo, 'mouth', self.center, self.opt, self.device)

        if not self.isTrain and self.opt.test_continuity_loss:
            self.ContinuityForTest(real=1)
    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        # print('fake_AB', fake_AB.shape)  # (1,4,512,512)
        pred_fake = self.netD(fake_AB.detach())  # detached, so the D loss does not affect G's gradients
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        if self.opt.discriminator_local:
            fake_AB_parts = self.getLocalParts(fake_AB)
            local_names = ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
            self.loss_D_fake_local = 0
            for i in range(len(fake_AB_parts)):
                net = getattr(self, 'net' + local_names[i])
                pred_fake_tmp = net(fake_AB_parts[i].detach())
                addw = self.getaddw(local_names[i])
                self.loss_D_fake_local = self.loss_D_fake_local + self.criterionGAN(pred_fake_tmp, False) * addw
            self.loss_D_fake = self.loss_D_fake + self.loss_D_fake_local

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        if self.opt.discriminator_local:
            real_AB_parts = self.getLocalParts(real_AB)
            local_names = ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
            self.loss_D_real_local = 0
            for i in range(len(real_AB_parts)):
                net = getattr(self, 'net' + local_names[i])
                pred_real_tmp = net(real_AB_parts[i])
                addw = self.getaddw(local_names[i])
                self.loss_D_real_local = self.loss_D_real_local + self.criterionGAN(pred_real_tmp, True) * addw
            self.loss_D_real = self.loss_D_real + self.loss_D_real_local

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()
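    # Illustrative sketch (standalone, not part of the original file): the .detach() calls in
    # backward_D() are what keep the discriminator update from pushing gradients back into the
    # generator. A minimal demonstration with stand-in linear "networks":
    #
    #     import torch
    #     import torch.nn as nn
    #     g, d = nn.Linear(4, 4), nn.Linear(4, 1)
    #     fake = g(torch.randn(2, 4))
    #     d(fake.detach()).mean().backward()
    #     print(g.weight.grad)   # None: no gradient reached the generator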
    def backward_G(self):
        # First, G(A) should fool the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)  # (1,4,512,512) -> (1,1,30,30)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        if self.opt.discriminator_local:
            fake_AB_parts = self.getLocalParts(fake_AB)
            local_names = ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
            self.loss_G_GAN_local = 0  # G_GAN_local is then added into G_GAN
            for i in range(len(fake_AB_parts)):
                net = getattr(self, 'net' + local_names[i])
                pred_fake_tmp = net(fake_AB_parts[i])
                addw = self.getaddw(local_names[i])
                self.loss_G_GAN_local = self.loss_G_GAN_local + self.criterionGAN(pred_fake_tmp, True) * addw
            if self.opt.gan_loss_strategy == 1:
                self.loss_G_GAN = (self.loss_G_GAN + self.loss_G_GAN_local) / (len(fake_AB_parts) + 1)
            elif self.opt.gan_loss_strategy == 2:
                self.loss_G_GAN_local = self.loss_G_GAN_local * 0.25
                self.loss_G_GAN = self.loss_G_GAN + self.loss_G_GAN_local

        # Second, G(A) = B
        if not self.opt.no_l1_loss:
            self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1

        if self.opt.use_local and not self.opt.no_G_local_loss:
            local_names = ['eyel', 'eyer', 'nose', 'mouth']
            self.loss_G_local = 0
            for i in range(len(local_names)):
                fakeblocal = getattr(self, 'fake_B_' + local_names[i])
                realblocal = getattr(self, 'real_B_' + local_names[i])
                addw = self.getaddw(local_names[i])
                self.loss_G_local = self.loss_G_local + self.criterionL1(fakeblocal, realblocal) * self.opt.lambda_local * addw
            self.loss_G_hair_local = self.criterionL1(self.fake_B_hair, self.real_B_hair) * self.opt.lambda_local * self.opt.addw_hair
            self.loss_G_bg_local = self.criterionL1(self.fake_B_bg, self.real_B_bg) * self.opt.lambda_local * self.opt.addw_bg

        # Third, chamfer matching (assumes chamfer_2way and chamfer_only_line are true)
        if self.opt.chamfer_loss:
            if self.fake_B.shape[1] == 3:
                tmp = self.fake_B[:, 0, ...] * 0.299 + self.fake_B[:, 1, ...] * 0.587 + self.fake_B[:, 2, ...] * 0.114
                fake_B_gray = tmp.unsqueeze(1)
            else:
                fake_B_gray = self.fake_B
            if self.real_B.shape[1] == 3:
                tmp = self.real_B[:, 0, ...] * 0.299 + self.real_B[:, 1, ...] * 0.587 + self.real_B[:, 2, ...] * 0.114
                real_B_gray = tmp.unsqueeze(1)
            else:
                real_B_gray = self.real_B

            gpu_p = self.opt.gpu_ids_p[0]
            gpu = self.opt.gpu_ids[0]
            if gpu_p != gpu:
                fake_B_gray = fake_B_gray.cuda(gpu_p)
                real_B_gray = real_B_gray.cuda(gpu_p)

            # d_CM(a_i, G(p_i))
            self.dt1 = self.netDT1(fake_B_gray)
            self.dt2 = self.netDT2(fake_B_gray)
            dt1 = self.dt1 / 2.0 + 0.5  # [-1,1] -> [0,1]
            dt2 = self.dt2 / 2.0 + 0.5
            if self.opt.dt_nonlinear != '':
                dt_xmax = torch.Tensor([self.opt.dt_xmax]).cuda(gpu_p)
                dt1 = nonlinearDt(dt1, self.opt.dt_nonlinear, dt_xmax)
                dt2 = nonlinearDt(dt2, self.opt.dt_nonlinear, dt_xmax)
                # print('dt1dt2', torch.min(dt1).item(), torch.max(dt1).item(), torch.min(dt2).item(), torch.max(dt2).item())

            bs = real_B_gray.shape[0]
            real_B_gray_line1 = self.netLine1(real_B_gray)
            real_B_gray_line2 = self.netLine2(real_B_gray)
            self.loss_G_chamfer = (dt1[(real_B_gray < 0) & (real_B_gray_line1 < 0)].sum() + dt2[(real_B_gray >= 0) & (real_B_gray_line2 >= 0)].sum()) / bs * self.opt.lambda_chamfer
            if gpu_p != gpu:
                self.loss_G_chamfer = self.loss_G_chamfer.cuda(gpu)

            # d_CM(G(p_i), a_i)
            if gpu_p != gpu:
                dt1gt = self.dt1gt.cuda(gpu_p)
                dt2gt = self.dt2gt.cuda(gpu_p)
            else:
                dt1gt = self.dt1gt
                dt2gt = self.dt2gt
            if self.opt.dt_nonlinear != '':
                dt1gt = nonlinearDt(dt1gt, self.opt.dt_nonlinear, dt_xmax)
                dt2gt = nonlinearDt(dt2gt, self.opt.dt_nonlinear, dt_xmax)
                # print('dt1gtdt2gt', torch.min(dt1gt).item(), torch.max(dt1gt).item(), torch.min(dt2gt).item(), torch.max(dt2gt).item())
            self.dt1gt = (self.dt1gt - 0.5) * 2
            self.dt2gt = (self.dt2gt - 0.5) * 2

            fake_B_gray_line1 = self.netLine1(fake_B_gray)
            fake_B_gray_line2 = self.netLine2(fake_B_gray)
            self.loss_G_chamfer2 = (dt1gt[(fake_B_gray < 0) & (fake_B_gray_line1 < 0)].sum() + dt2gt[(fake_B_gray >= 0) & (fake_B_gray_line2 >= 0)].sum()) / bs * self.opt.lambda_chamfer2
            if gpu_p != gpu:
                self.loss_G_chamfer2 = self.loss_G_chamfer2.cuda(gpu)

        # Fourth, line continuity loss, constrained on the synthesized drawing
        if self.opt.continuity_loss:
            # Patch-based
            self.get_patches()
            self.outputs = self.netRegressor(self.fake_B_patches)
            if not self.opt.emphasis_conti_face:
                self.loss_G_continuity = (1.0 - torch.mean(self.outputs)).cuda(gpu) * self.opt.lambda_continuity
            else:
                self.loss_G_continuity = torch.mean((1.0 - self.outputs) * self.conti_weights).cuda(gpu) * self.opt.lambda_continuity

        self.loss_G = self.loss_G_GAN
        if 'G_L1' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_L1
        if 'G_local' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_local
        if 'G_hair_local' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_hair_local
        if 'G_bg_local' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_bg_local
        if 'G_chamfer' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_chamfer
        if 'G_chamfer2' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_chamfer2
        if 'G_continuity' in self.loss_names:
            self.loss_G = self.loss_G + self.loss_G_continuity

        self.loss_G.backward()
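    # Illustrative sketch (standalone, not part of the original file): the chamfer branch above
    # reduces RGB drawings to one luminance channel with the usual Rec.601 weights
    # 0.299 R + 0.587 G + 0.114 B before feeding them to the DT/line networks:
    #
    #     import torch
    #     rgb = torch.rand(1, 3, 512, 512)
    #     gray = (rgb[:, 0, ...] * 0.299 + rgb[:, 1, ...] * 0.587 + rgb[:, 2, ...] * 0.114).unsqueeze(1)
    #     print(gray.shape)   # torch.Size([1, 1, 512, 512])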
    def optimize_parameters(self):
        self.forward()
        # update D
        self.set_requires_grad(self.netD, True)
        if self.opt.discriminator_local:
            self.set_requires_grad(self.netDLEyel, True)
            self.set_requires_grad(self.netDLEyer, True)
            self.set_requires_grad(self.netDLNose, True)
            self.set_requires_grad(self.netDLMouth, True)
            self.set_requires_grad(self.netDLHair, True)
            self.set_requires_grad(self.netDLBG, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.netD, False)
        if self.opt.discriminator_local:
            self.set_requires_grad(self.netDLEyel, False)
            self.set_requires_grad(self.netDLEyer, False)
            self.set_requires_grad(self.netDLNose, False)
            self.set_requires_grad(self.netDLMouth, False)
            self.set_requires_grad(self.netDLHair, False)
            self.set_requires_grad(self.netDLBG, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
    def get_patches(self):
        gpu_p = self.opt.gpu_ids_p[0]
        gpu = self.opt.gpu_ids[0]
        if gpu_p != gpu:
            self.fake_B = self.fake_B.cuda(gpu_p)
        # [1,1,512,512] -> [bs,1,11,11]
        patches = []
        if self.isTrain and self.opt.emphasis_conti_face:
            weights = []
        W2 = int(W / 2)
        t = np.random.randint(res, size=2)
        for i in range(aa):
            for j in range(aa):
                p = self.fake_B[:, :, t[0]+i*W:t[0]+(i+1)*W, t[1]+j*W:t[1]+(j+1)*W]
                whitenum = torch.sum(p >= 0.0)
                # if whitenum < 5 or whitenum > W*W-5:
                if whitenum < 1 or whitenum > W*W-1:
                    continue
                patches.append(p)
                if self.isTrain and self.opt.emphasis_conti_face:
                    weights.append(self.face_mask[:, :, t[0]+i*W+W2, t[1]+j*W+W2])
        self.fake_B_patches = torch.cat(patches, dim=0)
        if self.isTrain and self.opt.emphasis_conti_face:
            self.conti_weights = torch.cat(weights, dim=0) + 1  # 0 -> 1, 1 -> 2
    def get_patches_real(self):
        # [1,1,512,512] -> [bs,1,11,11]
        patches = []
        t = np.random.randint(res, size=2)
        for i in range(aa):
            for j in range(aa):
                p = self.real_B[:, :, t[0]+i*W:t[0]+(i+1)*W, t[1]+j*W:t[1]+(j+1)*W]
                whitenum = torch.sum(p >= 0.0)
                # if whitenum < 5 or whitenum > W*W-5:
                if whitenum < 1 or whitenum > W*W-1:
                    continue
                patches.append(p)
        self.real_B_patches = torch.cat(patches, dim=0)
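    # Illustrative sketch (standalone, not part of the original file): both get_patches methods
    # tile the 512x512 drawing into non-overlapping W x W crops at a random offset and stack the
    # kept crops along the batch dimension. With hypothetical values for the module-level
    # constants W, aa and res (their real definitions live elsewhere in this file):
    #
    #     import torch
    #     import numpy as np
    #     W, aa = 11, 46
    #     res = 512 - W * aa
    #     img = torch.randn(1, 1, 512, 512)
    #     t = np.random.randint(res, size=2)
    #     patches = [img[:, :, t[0]+i*W:t[0]+(i+1)*W, t[1]+j*W:t[1]+(j+1)*W]
    #                for i in range(aa) for j in range(aa)]
    #     print(torch.cat(patches, dim=0).shape)   # torch.Size([2116, 1, 11, 11])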
APDrawingGAN2/models/base_model.py
ADDED
@@ -0,0 +1,545 @@
import os
import numpy as np                    # assumed missing import, used by tocv2() below
import torch
from collections import OrderedDict
from torchvision import transforms    # assumed missing import, used by totor() below
from . import networks


class BaseModel():

    # modify the parser to add command line options,
    # and also change the default values if needed
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.gpu_ids_p = opt.gpu_ids_p
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.device_p = torch.device('cuda:{}'.format(self.gpu_ids_p[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.auxiliary_dir = os.path.join(opt.checkpoints_dir, opt.auxiliary_root)
        if opt.resize_or_crop != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # load and print networks; create schedulers
    def setup(self, opt, parser=None):
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        if len(self.auxiliary_model_names) > 0:
            self.load_auxiliary_networks()
        self.print_networks(opt.verbose)

    # put models in eval mode at test time
    def eval(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    # used at test time, wrapping `forward` in no_grad() so we don't save
    # intermediate activations for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        return self.image_paths
    def optimize_parameters(self):
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # return visualization images. train.py will display these images and save them to an HTML file
    def get_current_visuals(self):
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    # return training losses/errors. train.py will print these out as debugging information
    def get_current_losses(self):
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensors and plain floats
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret
    # save models to the disk
    def save_networks(self, which_epoch):
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (which_epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def save_networks2(self, which_epoch):
        gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
        dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch))
        dict_gen = {}
        dict_dis = {}
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    state_dict = net.module.cpu().state_dict()
                    net.cuda(self.gpu_ids[0])
                else:
                    state_dict = net.cpu().state_dict()

                if name[0] == 'G':
                    dict_gen[name] = state_dict
                elif name[0] == 'D':
                    dict_dis[name] = state_dict
                else:
                    save_filename = '%s_net_%s.pth' % (which_epoch, name)
                    save_path = os.path.join(self.save_dir, save_filename)
                    torch.save(state_dict, save_path)
        if dict_gen:
            torch.save(dict_gen, gen_name)
        if dict_dis:
            torch.save(dict_dis, dis_name)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
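    # Illustrative sketch (standalone, not part of the original file): when a net is wrapped in
    # nn.DataParallel, its weights live on `.module`, so saving `.module.cpu().state_dict()` as
    # above produces checkpoints without the 'module.' key prefix:
    #
    #     import torch.nn as nn
    #     net = nn.DataParallel(nn.Linear(8, 2))
    #     print(list(net.state_dict().keys()))          # ['module.weight', 'module.bias']
    #     print(list(net.module.state_dict().keys()))   # ['weight', 'bias']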
    # load models from the disk
    def load_networks(self, which_epoch):
        gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
        if os.path.exists(gen_name):
            self.load_networks2(which_epoch)
            return
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (which_epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints created prior to PyTorch 0.4
                for key in list(state_dict.keys()):  # copy keys because we mutate the dict in the loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def load_networks2(self, which_epoch):
        gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
        gen_state_dict = torch.load(gen_name, map_location=str(self.device))
        if self.isTrain and self.opt.model != 'apdrawing_style_nogan':
            dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch))
            dis_state_dict = torch.load(dis_name, map_location=str(self.device))
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                if name[0] == 'G':
                    print('loading the model %s from %s' % (name, gen_name))
                    state_dict = gen_state_dict[name]
                elif name[0] == 'D':
                    print('loading the model %s from %s' % (name, dis_name))
                    state_dict = dis_state_dict[name]

                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata
                # patch InstanceNorm checkpoints created prior to PyTorch 0.4
                for key in list(state_dict.keys()):  # copy keys because we mutate the dict in the loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    # load auxiliary net models from the disk
    def load_auxiliary_networks(self):
        for name in self.auxiliary_model_names:
            if isinstance(name, str):
                if 'AE' in name and self.opt.ae_small:
                    load_filename = '%s_net_%s_small.pth' % ('latest', name)
                elif 'Regressor' in name:
                    load_filename = '%s_net_%s%d.pth' % ('latest', name, self.opt.regarch)
                else:
                    load_filename = '%s_net_%s.pth' % ('latest', name)
                load_path = os.path.join(self.auxiliary_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                if name in ['DT1', 'DT2', 'Line1', 'Line2', 'Continuity1', 'Continuity2', 'Regressor', 'Regressorhair', 'Regressorface']:
                    state_dict = torch.load(load_path, map_location=str(self.device_p))
                else:
                    state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints created prior to PyTorch 0.4
                for key in list(state_dict.keys()):  # copy keys because we mutate the dict in the loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)
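    # Illustrative sketch (standalone, not part of the original file): map_location is what lets
    # a checkpoint written on one device be materialized on another, which is why the loaders
    # above pass str(self.device) / str(self.device_p):
    #
    #     import io
    #     import torch
    #     import torch.nn as nn
    #     net = nn.Linear(8, 2)
    #     buf = io.BytesIO()
    #     torch.save(net.state_dict(), buf)
    #     buf.seek(0)
    #     net.load_state_dict(torch.load(buf, map_location='cpu'))   # or 'cuda:0'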
    # print network information
    def print_networks(self, verbose):
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    # set requires_grad=False to avoid unnecessary computation
    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    # =============================================================================================================
    def inverse_mask(self, mask):
        return torch.ones(mask.shape).to(self.device) - mask

    def masked(self, A, mask):
        return (A/2+0.5)*mask*2 - 1

    def add_with_mask(self, A, B, mask):
        return ((A/2+0.5)*mask + (B/2+0.5)*(torch.ones(mask.shape).to(self.device)-mask))*2 - 1

    def addone_with_mask(self, A, mask):
        return ((A/2+0.5)*mask + (torch.ones(mask.shape).to(self.device)-mask))*2 - 1
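    # Illustrative sketch (standalone, not part of the original file): the helpers above operate
    # on images in [-1, 1]; add_with_mask(A, B, mask) rescales both to [0, 1], keeps A where the
    # mask is 1 and B where it is 0, then maps the blend back to [-1, 1]:
    #
    #     import torch
    #     A = torch.full((1, 1, 2, 2), 1.0)
    #     B = torch.full((1, 1, 2, 2), -1.0)
    #     mask = torch.tensor([[[[1., 0.], [0., 1.]]]])
    #     blend = ((A / 2 + 0.5) * mask + (B / 2 + 0.5) * (1 - mask)) * 2 - 1
    #     # -> 1 where mask == 1, -1 where mask == 0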
    def partCombiner(self, eyel, eyer, nose, mouth, average_pos=False, comb_op=1, region_enm=0, cmaskel=None, cmasker=None, cmaskno=None, cmaskmo=None):
        '''
          x       y
        100.571 123.429
        155.429 123.429
        128.000 155.886
        103.314 185.417
        152.686 185.417
        this is the mean location of the 5 landmarks (for 256x256)
        Pad2d order: Left, Right, Top, Down
        '''
        if comb_op == 0:
            # use max pooling, pad black for eyes etc.
            padvalue = -1
            if region_enm in [1, 2]:
                eyel = eyel * cmaskel
                eyer = eyer * cmasker
                nose = nose * cmaskno
                mouth = mouth * cmaskmo
        else:
            # use min pooling, pad white for eyes etc.
            padvalue = 1
            if region_enm in [1, 2]:
                eyel = self.addone_with_mask(eyel, cmaskel)
                eyer = self.addone_with_mask(eyer, cmasker)
                nose = self.addone_with_mask(nose, cmaskno)
                mouth = self.addone_with_mask(mouth, cmaskmo)
        if region_enm in [0, 1]:  # need to pad to full size
            IMAGE_SIZE = self.opt.fineSize
            ratio = IMAGE_SIZE / 256
            EYE_W = self.opt.EYE_W * ratio
            EYE_H = self.opt.EYE_H * ratio
            NOSE_W = self.opt.NOSE_W * ratio
            NOSE_H = self.opt.NOSE_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
            bs, nc, _, _ = eyel.shape
            eyel_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            eyer_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            nose_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            mouth_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            for i in range(bs):
                if not average_pos:
                    center = self.center[i]  # x,y
                else:  # average_pos == True
                    center = torch.tensor([[101, 123-4], [155, 123-4], [128, 156-NOSE_H/2+16], [128, 185]])
                eyel_p[i] = torch.nn.ConstantPad2d((int(center[0,0] - EYE_W / 2 - 1), int(IMAGE_SIZE - (center[0,0]+EYE_W/2-1)), int(center[0,1] - EYE_H / 2 - 1), int(IMAGE_SIZE - (center[0,1]+EYE_H/2 - 1))), -1)(eyel[i])
                eyer_p[i] = torch.nn.ConstantPad2d((int(center[1,0] - EYE_W / 2 - 1), int(IMAGE_SIZE - (center[1,0]+EYE_W/2-1)), int(center[1,1] - EYE_H / 2 - 1), int(IMAGE_SIZE - (center[1,1]+EYE_H/2 - 1))), -1)(eyer[i])
                nose_p[i] = torch.nn.ConstantPad2d((int(center[2,0] - NOSE_W / 2 - 1), int(IMAGE_SIZE - (center[2,0]+NOSE_W/2-1)), int(center[2,1] - NOSE_H / 2 - 1), int(IMAGE_SIZE - (center[2,1]+NOSE_H/2 - 1))), -1)(nose[i])
                mouth_p[i] = torch.nn.ConstantPad2d((int(center[3,0] - MOUTH_W / 2 - 1), int(IMAGE_SIZE - (center[3,0]+MOUTH_W/2-1)), int(center[3,1] - MOUTH_H / 2 - 1), int(IMAGE_SIZE - (center[3,1]+MOUTH_H/2 - 1))), -1)(mouth[i])
        elif region_enm in [2]:
            eyel_p = eyel
            eyer_p = eyer
            nose_p = nose
            mouth_p = mouth
        if comb_op == 0:
            # use max pooling
            eyes = torch.max(eyel_p, eyer_p)
            eye_nose = torch.max(eyes, nose_p)
            result = torch.max(eye_nose, mouth_p)
        else:
            # use min pooling
            eyes = torch.min(eyel_p, eyer_p)
            eye_nose = torch.min(eyes, nose_p)
            result = torch.min(eye_nose, mouth_p)
        return result

    def partCombiner2(self, eyel, eyer, nose, mouth, hair, mask, comb_op=1, region_enm=0, cmaskel=None, cmasker=None, cmaskno=None, cmaskmo=None):
        if comb_op == 0:
            # use max pooling, pad black for eyes etc.
            padvalue = -1
            hair = self.masked(hair, mask)
            if region_enm in [1, 2]:
                eyel = eyel * cmaskel
                eyer = eyer * cmasker
                nose = nose * cmaskno
                mouth = mouth * cmaskmo
        else:
            # use min pooling, pad white for eyes etc.
            padvalue = 1
            hair = self.addone_with_mask(hair, mask)
            if region_enm in [1, 2]:
                eyel = self.addone_with_mask(eyel, cmaskel)
                eyer = self.addone_with_mask(eyer, cmasker)
                nose = self.addone_with_mask(nose, cmaskno)
                mouth = self.addone_with_mask(mouth, cmaskmo)
        if region_enm in [0, 1]:  # need to pad to full size
            IMAGE_SIZE = self.opt.fineSize
            ratio = IMAGE_SIZE / 256
            EYE_W = self.opt.EYE_W * ratio
            EYE_H = self.opt.EYE_H * ratio
            NOSE_W = self.opt.NOSE_W * ratio
            NOSE_H = self.opt.NOSE_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
            bs, nc, _, _ = eyel.shape
            eyel_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            eyer_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            nose_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            mouth_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            for i in range(bs):
                center = self.center[i]  # x,y
                eyel_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)), padvalue)(eyel[i])
                eyer_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)), padvalue)(eyer[i])
                nose_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)), padvalue)(nose[i])
                mouth_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)), padvalue)(mouth[i])
        elif region_enm in [2]:
            eyel_p = eyel
            eyer_p = eyer
            nose_p = nose
            mouth_p = mouth
        if comb_op == 0:
            # use max pooling
            eyes = torch.max(eyel_p, eyer_p)
            eye_nose = torch.max(eyes, nose_p)
            eye_nose_mouth = torch.max(eye_nose, mouth_p)
            result = torch.max(hair, eye_nose_mouth)
        else:
            # use min pooling
            eyes = torch.min(eyel_p, eyer_p)
            eye_nose = torch.min(eyes, nose_p)
            eye_nose_mouth = torch.min(eye_nose, mouth_p)
            result = torch.min(hair, eye_nose_mouth)
        return result

    def partCombiner2_bg(self, eyel, eyer, nose, mouth, hair, bg, maskh, maskb, comb_op=1, region_enm=0, cmaskel=None, cmasker=None, cmaskno=None, cmaskmo=None):
        if comb_op == 0:
            # use max pooling, pad black for eyes etc.
            padvalue = -1
            hair = self.masked(hair, maskh)
            bg = self.masked(bg, maskb)
            if region_enm in [1, 2]:
                eyel = eyel * cmaskel
                eyer = eyer * cmasker
                nose = nose * cmaskno
                mouth = mouth * cmaskmo
        else:
            # use min pooling, pad white for eyes etc.
            padvalue = 1
            hair = self.addone_with_mask(hair, maskh)
            bg = self.addone_with_mask(bg, maskb)
            if region_enm in [1, 2]:
                eyel = self.addone_with_mask(eyel, cmaskel)
                eyer = self.addone_with_mask(eyer, cmasker)
                nose = self.addone_with_mask(nose, cmaskno)
                mouth = self.addone_with_mask(mouth, cmaskmo)
        if region_enm in [0, 1]:  # need to pad to full size
            IMAGE_SIZE = self.opt.fineSize
            ratio = IMAGE_SIZE / 256
            EYE_W = self.opt.EYE_W * ratio
            EYE_H = self.opt.EYE_H * ratio
            NOSE_W = self.opt.NOSE_W * ratio
            NOSE_H = self.opt.NOSE_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
            bs, nc, _, _ = eyel.shape
            eyel_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            eyer_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            nose_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            mouth_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            for i in range(bs):
                center = self.center[i]  # x,y
                eyel_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)), padvalue)(eyel[i])
                eyer_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)), padvalue)(eyer[i])
                nose_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)), padvalue)(nose[i])
                mouth_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)), padvalue)(mouth[i])
        elif region_enm in [2]:
            eyel_p = eyel
            eyer_p = eyer
            nose_p = nose
            mouth_p = mouth
        if comb_op == 0:
            eyes = torch.max(eyel_p, eyer_p)
            eye_nose = torch.max(eyes, nose_p)
            eye_nose_mouth = torch.max(eye_nose, mouth_p)
            eye_nose_mouth_hair = torch.max(hair, eye_nose_mouth)
            result = torch.max(bg, eye_nose_mouth_hair)
        else:
            eyes = torch.min(eyel_p, eyer_p)
            eye_nose = torch.min(eyes, nose_p)
            eye_nose_mouth = torch.min(eye_nose, mouth_p)
            eye_nose_mouth_hair = torch.min(hair, eye_nose_mouth)
            result = torch.min(bg, eye_nose_mouth_hair)
        return result
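    # Illustrative sketch (standalone, not part of the original file): with comb_op == 1 the part
    # combiners fuse the padded part images by element-wise min, so the darkest ("inked") value
    # wins at every pixel; comb_op == 0 does the same with max on black-padded parts:
    #
    #     import torch
    #     a = torch.tensor([[ 1.0, -1.0], [ 1.0,  1.0]])   # one part: white except one stroke
    #     b = torch.tensor([[ 1.0,  1.0], [-1.0,  1.0]])   # another part
    #     print(torch.min(a, b))                            # strokes from both parts survive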
    def partCombiner3(self, face, hair, maskf, maskh, comb_op=1):
        if comb_op == 0:
            # use max pooling, pad black etc.
            padvalue = -1
            face = self.masked(face, maskf)
            hair = self.masked(hair, maskh)
        else:
            # use min pooling, pad white etc.
            padvalue = 1
            face = self.addone_with_mask(face, maskf)
            hair = self.addone_with_mask(hair, maskh)
        if comb_op == 0:
            result = torch.max(face, hair)
        else:
            result = torch.min(face, hair)
        return result

    def tocv2(ts):
        img = (ts.numpy()/2+0.5)*255
        img = img.astype('uint8')
        img = np.transpose(img, (1, 2, 0))
        img = img[:, :, ::-1]  # rgb -> bgr
        return img

    def totor(img):
        img = img[:, :, ::-1]
        tor = transforms.ToTensor()(img)
        tor = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tor)
        return tor

    def ContinuityForTest(self, real=0):
        # Patch-based
        self.get_patches()
        self.outputs = self.netRegressor(self.fake_B_patches)
        line_continuity = torch.mean(self.outputs)
        opt = self.opt
        file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity.txt')
        message = '%s %.04f' % (self.image_paths[0], line_continuity)
        with open(file_name, 'a+') as c_file:
            c_file.write(message)
            c_file.write('\n')
        if real == 1:
            self.get_patches_real()
            self.outputs2 = self.netRegressor(self.real_B_patches)
            line_continuity2 = torch.mean(self.outputs2)
            file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity-r.txt')
            message = '%s %.04f' % (self.image_paths[0], line_continuity2)
            with open(file_name, 'a+') as c_file:
                c_file.write(message)
                c_file.write('\n')

    def getLocalParts(self, fakeAB):
        bs, nc, _, _ = fakeAB.shape  # dtype torch.float32
        ncr = int(nc / self.opt.output_nc)
        if self.opt.region_enm in [0, 1]:
            ratio = self.opt.fineSize / 256
            EYE_H = self.opt.EYE_H * ratio
            EYE_W = self.opt.EYE_W * ratio
            NOSE_H = self.opt.NOSE_H * ratio
            NOSE_W = self.opt.NOSE_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            eyel = torch.ones((bs, nc, int(EYE_H), int(EYE_W))).to(self.device)
            eyer = torch.ones((bs, nc, int(EYE_H), int(EYE_W))).to(self.device)
            nose = torch.ones((bs, nc, int(NOSE_H), int(NOSE_W))).to(self.device)
            mouth = torch.ones((bs, nc, int(MOUTH_H), int(MOUTH_W))).to(self.device)
            for i in range(bs):
                center = self.center[i]
                eyel[i] = fakeAB[i, :, center[0,1]-EYE_H/2:center[0,1]+EYE_H/2, center[0,0]-EYE_W/2:center[0,0]+EYE_W/2]
                eyer[i] = fakeAB[i, :, center[1,1]-EYE_H/2:center[1,1]+EYE_H/2, center[1,0]-EYE_W/2:center[1,0]+EYE_W/2]
                nose[i] = fakeAB[i, :, center[2,1]-NOSE_H/2:center[2,1]+NOSE_H/2, center[2,0]-NOSE_W/2:center[2,0]+NOSE_W/2]
                mouth[i] = fakeAB[i, :, center[3,1]-MOUTH_H/2:center[3,1]+MOUTH_H/2, center[3,0]-MOUTH_W/2:center[3,0]+MOUTH_W/2]
        elif self.opt.region_enm in [2]:
            eyel = (fakeAB/2+0.5) * self.cmaskel.repeat(1, ncr, 1, 1) * 2 - 1
            eyer = (fakeAB/2+0.5) * self.cmasker.repeat(1, ncr, 1, 1) * 2 - 1
            nose = (fakeAB/2+0.5) * self.cmask.repeat(1, ncr, 1, 1) * 2 - 1
            mouth = (fakeAB/2+0.5) * self.cmaskmo.repeat(1, ncr, 1, 1) * 2 - 1
        hair = (fakeAB/2+0.5) * self.mask.repeat(1, ncr, 1, 1) * self.mask2.repeat(1, ncr, 1, 1) * 2 - 1
        bg = (fakeAB/2+0.5) * (torch.ones(fakeAB.shape).to(self.device) - self.mask2.repeat(1, ncr, 1, 1)) * 2 - 1
        return eyel, eyer, nose, mouth, hair, bg

    def getaddw(self, local_name):
        addw = 1
        if local_name in ['DLEyel', 'DLEyer', 'eyel', 'eyer', 'DLFace', 'face']:
            addw = self.opt.addw_eye
        elif local_name in ['DLNose', 'nose']:
            addw = self.opt.addw_nose
        elif local_name in ['DLMouth', 'mouth']:
            addw = self.opt.addw_mouth
        elif local_name in ['DLHair', 'hair']:
            addw = self.opt.addw_hair
        elif local_name in ['DLBG', 'bg']:
            addw = self.opt.addw_bg
        return addw
APDrawingGAN2/models/networks.py
ADDED
@@ -0,0 +1,1194 @@
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler

###############################################################################
# Helper Functions
###############################################################################


def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
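# Illustrative sketch (standalone, not part of the original file): get_norm_layer() returns a
# layer *factory*, so the generator/discriminator builders below can instantiate the right
# normalization flavour given only a channel count:
#
#     import functools
#     import torch.nn as nn
#     norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
#     block = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), norm_layer(64), nn.ReLU(True))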
def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
def init_weights(net, init_type='normal', gain=0.02):
|
42 |
+
def init_func(m):
|
43 |
+
classname = m.__class__.__name__
|
44 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
45 |
+
if init_type == 'normal':
|
46 |
+
init.normal_(m.weight.data, 0.0, gain)
|
47 |
+
elif init_type == 'xavier':
|
48 |
+
init.xavier_normal_(m.weight.data, gain=gain)
|
49 |
+
elif init_type == 'kaiming':
|
50 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
51 |
+
elif init_type == 'orthogonal':
|
52 |
+
init.orthogonal_(m.weight.data, gain=gain)
|
53 |
+
else:
|
54 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
55 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
56 |
+
init.constant_(m.bias.data, 0.0)
|
57 |
+
elif classname.find('BatchNorm2d') != -1:
|
58 |
+
init.normal_(m.weight.data, 1.0, gain)
|
59 |
+
init.constant_(m.bias.data, 0.0)
|
60 |
+
|
61 |
+
print('initialize network with %s' % init_type)
|
62 |
+
net.apply(init_func)
|
63 |
+
|
64 |
+
|
65 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
|
66 |
+
if len(gpu_ids) > 0:
|
67 |
+
assert(torch.cuda.is_available())
|
68 |
+
net.to(gpu_ids[0])
|
69 |
+
net = torch.nn.DataParallel(net, gpu_ids)
|
70 |
+
init_weights(net, init_type, gain=init_gain)
|
71 |
+
return net
|
72 |
+
|
73 |
+
|
74 |
+
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], nnG=9, multiple=2, latent_dim=1024, ae_h=96, ae_w=96, extra_channel=2, nres=1):
|
75 |
+
net = None
|
76 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
77 |
+
|
78 |
+
if netG == 'autoencoder':
|
79 |
+
net = AutoEncoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
80 |
+
elif netG == 'autoencoderfc':
|
81 |
+
net = AutoEncoderWithFC(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
|
82 |
+
multiple=multiple, latent_dim=latent_dim, h=ae_h, w=ae_w)
|
83 |
+
elif netG == 'autoencoderfc2':
|
84 |
+
net = AutoEncoderWithFC2(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
|
85 |
+
multiple=multiple, latent_dim=latent_dim, h=ae_h, w=ae_w)
|
86 |
+
elif netG == 'vae':
|
87 |
+
net = VAE(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
|
88 |
+
multiple=multiple, latent_dim=latent_dim, h=ae_h, w=ae_w)
|
89 |
+
elif netG == 'classifier':
|
90 |
+
net = Classifier(input_nc, output_nc, ngf, num_downs=nnG, norm_layer=norm_layer, use_dropout=use_dropout, h=ae_h, w=ae_w)
|
91 |
+
elif netG == 'regressor':
|
92 |
+
net = Regressor(input_nc, ngf, norm_layer=norm_layer, arch=nnG)
|
93 |
+
elif netG == 'resnet_9blocks':#default for cyclegan
|
94 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
|
95 |
+
elif netG == 'resnet_6blocks':
|
96 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
|
97 |
+
elif netG == 'resnet_nblocks':
|
98 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=nnG)
|
99 |
+
elif netG == 'resnet_style2_9blocks':
|
100 |
+
net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=0, extra_channel=extra_channel)
|
101 |
+
elif netG == 'resnet_style2_6blocks':
|
102 |
+
net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, model0_res=0, extra_channel=extra_channel)
|
103 |
+
elif netG == 'resnet_style2_nblocks':
|
104 |
+
net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=nnG, model0_res=0, extra_channel=extra_channel)
|
105 |
+
elif netG == 'unet_128':
|
106 |
+
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
107 |
+
elif netG == 'unet_256':#default for pix2pix
|
108 |
+
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
109 |
+
elif netG == 'unet_512':
|
110 |
+
net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
111 |
+
elif netG == 'unet_ndown':
|
112 |
+
net = UnetGenerator(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
113 |
+
elif netG == 'unetres_ndown':
|
114 |
+
net = UnetResGenerator(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout, nres=nres)
|
115 |
+
elif netG == 'partunet':
|
116 |
+
net = PartUnet(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
117 |
+
elif netG == 'partunet2':
|
118 |
+
net = PartUnet2(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
119 |
+
elif netG == 'partunetres':
|
120 |
+
net = PartUnetRes(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout,nres=nres)
|
121 |
+
elif netG == 'partunet2res':
|
122 |
+
net = PartUnet2Res(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout,nres=nres)
|
123 |
+
elif netG == 'partunet2style':
|
124 |
+
net = PartUnet2Style(input_nc, output_nc, nnG, ngf, extra_channel=extra_channel, norm_layer=norm_layer, use_dropout=use_dropout)
|
125 |
+
elif netG == 'partunet2resstyle':
|
126 |
+
net = PartUnet2ResStyle(input_nc, output_nc, nnG, ngf, extra_channel=extra_channel, norm_layer=norm_layer, use_dropout=use_dropout,nres=nres)
|
127 |
+
elif netG == 'combiner':
|
128 |
+
net = Combiner(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=2)
|
129 |
+
elif netG == 'combiner2':
|
130 |
+
net = Combiner2(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
131 |
+
else:
|
132 |
+
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
|
133 |
+
return init_net(net, init_type, init_gain, gpu_ids)
|
134 |
+
|
135 |
+
|
136 |
+
def define_D(input_nc, ndf, netD,
|
137 |
+
n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
|
138 |
+
net = None
|
139 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
140 |
+
|
141 |
+
if netD == 'basic':
|
142 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
|
143 |
+
elif netD == 'n_layers':
|
144 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
|
145 |
+
elif netD == 'pixel':
|
146 |
+
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
|
147 |
+
else:
|
148 |
+
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
|
149 |
+
return init_net(net, init_type, init_gain, gpu_ids)
|
150 |
+
|
151 |
+
|
152 |
+
##############################################################################
|
153 |
+
# Classes
|
154 |
+
##############################################################################
|
155 |
+
|
156 |
+
|
157 |
+
# Defines the GAN loss which uses either LSGAN or the regular GAN.
|
158 |
+
# When LSGAN is used, it is basically the same as MSELoss,
|
159 |
+
# but it abstracts away the need to create the target label tensor
|
160 |
+
# that has the same size as the input
|
161 |
+
class GANLoss(nn.Module):
|
162 |
+
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
|
163 |
+
super(GANLoss, self).__init__()
|
164 |
+
self.register_buffer('real_label', torch.tensor(target_real_label))
|
165 |
+
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
166 |
+
if use_lsgan:
|
167 |
+
self.loss = nn.MSELoss()
|
168 |
+
else:#no_lsgan
|
169 |
+
self.loss = nn.BCELoss()
|
170 |
+
|
171 |
+
def get_target_tensor(self, input, target_is_real):
|
172 |
+
if target_is_real:
|
173 |
+
target_tensor = self.real_label
|
174 |
+
else:
|
175 |
+
target_tensor = self.fake_label
|
176 |
+
return target_tensor.expand_as(input)
|
177 |
+
|
178 |
+
def __call__(self, input, target_is_real):
|
179 |
+
target_tensor = self.get_target_tensor(input, target_is_real)
|
180 |
+
return self.loss(input, target_tensor)
|
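A minimal usage sketch of the loss above (not part of this file): pred_fake stands in for a discriminator patch map, and the boolean selects which broadcast target the loss is computed against.

import torch

criterion = GANLoss(use_lsgan=True)        # MSE against a broadcast target map
pred_fake = torch.randn(1, 1, 62, 62)      # hypothetical PatchGAN output
loss_D_fake = criterion(pred_fake, False)  # compared against the 0.0 target
loss_G = criterion(pred_fake, True)        # generator is rewarded for fooling D
print(loss_D_fake.item(), loss_G.item())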
181 |
+
|
182 |
+
|
183 |
+
class AutoEncoderMNIST(nn.Module):
|
184 |
+
def __init__(self):
|
185 |
+
super(AutoEncoderMNIST, self).__init__()
|
186 |
+
self.encoder = nn.Sequential(
|
187 |
+
nn.Conv2d(1, 16, 3, stride=3, padding=1), # b, 16, 10, 10
|
188 |
+
nn.ReLU(True),
|
189 |
+
nn.MaxPool2d(2, stride=2), # b, 16, 5, 5
|
190 |
+
nn.Conv2d(16, 8, 3, stride=2, padding=1), # b, 8, 3, 3
|
191 |
+
nn.ReLU(True),
|
192 |
+
nn.MaxPool2d(2, stride=1) # b, 8, 2, 2
|
193 |
+
)
|
194 |
+
self.decoder = nn.Sequential(
|
195 |
+
nn.ConvTranspose2d(8, 16, 3, stride=2), # b, 16, 5, 5
|
196 |
+
nn.ReLU(True),
|
197 |
+
nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1), # b, 8, 15, 15
|
198 |
+
nn.ReLU(True),
|
199 |
+
nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1), # b, 1, 28, 28
|
200 |
+
nn.Tanh()
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(self, x):
|
204 |
+
x = self.encoder(x)
|
205 |
+
x = self.decoder(x)
|
206 |
+
return x
|
207 |
+
|
208 |
+
class AutoEncoder(nn.Module):
|
209 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, padding_type='reflect'):
|
210 |
+
super(AutoEncoder, self).__init__()
|
211 |
+
self.input_nc = input_nc
|
212 |
+
self.output_nc = output_nc
|
213 |
+
self.ngf = ngf
|
214 |
+
if type(norm_layer) == functools.partial:
|
215 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
216 |
+
else:
|
217 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
218 |
+
|
219 |
+
model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
220 |
+
n_downsampling = 3
|
221 |
+
for i in range(n_downsampling):
|
222 |
+
mult = 2**i
|
223 |
+
model += [nn.LeakyReLU(0.2),
|
224 |
+
nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=4,
|
225 |
+
stride=2, padding=1, bias=use_bias),
|
226 |
+
norm_layer(ngf * mult * 2)]
|
227 |
+
self.encoder = nn.Sequential(*model)
|
228 |
+
|
229 |
+
model2 = []
|
230 |
+
for i in range(n_downsampling):
|
231 |
+
mult = 2**(n_downsampling - i)
|
232 |
+
model2 += [nn.ReLU(),
|
233 |
+
nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
234 |
+
kernel_size=4, stride=2,
|
235 |
+
padding=1, bias=use_bias),
|
236 |
+
norm_layer(int(ngf * mult / 2))]
|
237 |
+
model2 += [nn.ReLU()]
|
238 |
+
model2 += [nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
239 |
+
model2 += [nn.Tanh()]
|
240 |
+
self.decoder = nn.Sequential(*model2)
|
241 |
+
|
242 |
+
def forward(self, x):
|
243 |
+
ax = self.encoder(x) # b, 512, 6, 6
|
244 |
+
y = self.decoder(ax)
|
245 |
+
return y, ax
|
246 |
+
|
247 |
+
class AutoEncoderWithFC(nn.Module):
|
248 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, multiple=2,latent_dim=1024, h=96, w=96):
|
249 |
+
super(AutoEncoderWithFC, self).__init__()
|
250 |
+
self.input_nc = input_nc
|
251 |
+
self.output_nc = output_nc
|
252 |
+
self.ngf = ngf
|
253 |
+
if type(norm_layer) == functools.partial:
|
254 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
255 |
+
else:
|
256 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
257 |
+
|
258 |
+
model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
259 |
+
n_downsampling = 3
|
260 |
+
#multiple = 2
|
261 |
+
for i in range(n_downsampling):
|
262 |
+
mult = multiple**i
|
263 |
+
model += [nn.LeakyReLU(0.2),
|
264 |
+
nn.Conv2d(int(ngf * mult), int(ngf * mult * multiple), kernel_size=4,
|
265 |
+
stride=2, padding=1, bias=use_bias),
|
266 |
+
norm_layer(int(ngf * mult * multiple))]
|
267 |
+
self.encoder = nn.Sequential(*model)
|
268 |
+
self.fc1 = nn.Linear(int(ngf*(multiple**n_downsampling)*h/16*w/16),latent_dim)
|
269 |
+
self.relu = nn.ReLU(latent_dim)
|
270 |
+
self.fc2 = nn.Linear(latent_dim,int(ngf*(multiple**n_downsampling)*h/16*w/16))
|
271 |
+
self.rh = int(h/16)
|
272 |
+
self.rw = int(w/16)
|
273 |
+
model2 = []
|
274 |
+
for i in range(n_downsampling):
|
275 |
+
mult = multiple**(n_downsampling - i)
|
276 |
+
model2 += [nn.ReLU(),
|
277 |
+
nn.ConvTranspose2d(int(ngf * mult), int(ngf * mult / multiple),
|
278 |
+
kernel_size=4, stride=2,
|
279 |
+
padding=1, bias=use_bias),
|
280 |
+
norm_layer(int(ngf * mult / multiple))]
|
281 |
+
model2 += [nn.ReLU()]
|
282 |
+
model2 += [nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
283 |
+
model2 += [nn.Tanh()]
|
284 |
+
self.decoder = nn.Sequential(*model2)
|
285 |
+
|
286 |
+
def forward(self, x):
|
287 |
+
ax = self.encoder(x) # b, 512, 6, 6
|
288 |
+
ax = ax.view(ax.size(0), -1) # view -- reshape
|
289 |
+
ax = self.relu(self.fc1(ax))
|
290 |
+
ax = self.fc2(ax)
|
291 |
+
ax = ax.view(ax.size(0),-1,self.rh,self.rw)
|
292 |
+
y = self.decoder(ax)
|
293 |
+
return y, ax
|
294 |
+
|
295 |
+
class AutoEncoderWithFC2(nn.Module):
|
296 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, multiple=2,latent_dim=1024, h=96, w=96):
|
297 |
+
super(AutoEncoderWithFC2, self).__init__()
|
298 |
+
self.input_nc = input_nc
|
299 |
+
self.output_nc = output_nc
|
300 |
+
self.ngf = ngf
|
301 |
+
if type(norm_layer) == functools.partial:
|
302 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
303 |
+
else:
|
304 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
305 |
+
|
306 |
+
model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
307 |
+
n_downsampling = 2
|
308 |
+
#multiple = 2
|
309 |
+
for i in range(n_downsampling):
|
310 |
+
mult = multiple**i
|
311 |
+
model += [nn.LeakyReLU(0.2),
|
312 |
+
nn.Conv2d(int(ngf * mult), int(ngf * mult * multiple), kernel_size=4,
|
313 |
+
stride=2, padding=1, bias=use_bias),
|
314 |
+
norm_layer(int(ngf * mult * multiple))]
|
315 |
+
self.encoder = nn.Sequential(*model)
|
316 |
+
self.fc1 = nn.Linear(int(ngf*(multiple**n_downsampling)*h/8*w/8),latent_dim)
|
317 |
+
self.relu = nn.ReLU(latent_dim)
|
318 |
+
self.fc2 = nn.Linear(latent_dim,int(ngf*(multiple**n_downsampling)*h/8*w/8))
|
319 |
+
self.rh = int(h/8)
|
320 |
+
self.rw = int(w/8)
|
321 |
+
model2 = []
|
322 |
+
for i in range(n_downsampling):
|
323 |
+
mult = multiple**(n_downsampling - i)
|
324 |
+
model2 += [nn.ReLU(),
|
325 |
+
nn.ConvTranspose2d(int(ngf * mult), int(ngf * mult / multiple),
|
326 |
+
kernel_size=4, stride=2,
|
327 |
+
padding=1, bias=use_bias),
|
328 |
+
norm_layer(int(ngf * mult / multiple))]
|
329 |
+
model2 += [nn.ReLU()]
|
330 |
+
model2 += [nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
331 |
+
model2 += [nn.Tanh()]
|
332 |
+
self.decoder = nn.Sequential(*model2)
|
333 |
+
|
334 |
+
def forward(self, x):
|
335 |
+
ax = self.encoder(x) # b, 256, 12, 12
|
336 |
+
ax = ax.view(ax.size(0), -1) # view -- reshape
|
337 |
+
ax = self.relu(self.fc1(ax))
|
338 |
+
ax = self.fc2(ax)
|
339 |
+
ax = ax.view(ax.size(0),-1,self.rh,self.rw)
|
340 |
+
y = self.decoder(ax)
|
341 |
+
return y, ax
|
342 |
+
|
343 |
+
class VAE(nn.Module):
|
344 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, multiple=2,latent_dim=1024, h=96, w=96):
|
345 |
+
super(VAE, self).__init__()
|
346 |
+
self.input_nc = input_nc
|
347 |
+
self.output_nc = output_nc
|
348 |
+
self.ngf = ngf
|
349 |
+
if type(norm_layer) == functools.partial:
|
350 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
351 |
+
else:
|
352 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
353 |
+
|
354 |
+
model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
355 |
+
n_downsampling = 3
|
356 |
+
for i in range(n_downsampling):
|
357 |
+
mult = multiple**i
|
358 |
+
model += [nn.LeakyReLU(0.2),
|
359 |
+
nn.Conv2d(int(ngf * mult), int(ngf * mult * multiple), kernel_size=4,
|
360 |
+
stride=2, padding=1, bias=use_bias),
|
361 |
+
norm_layer(int(ngf * mult * multiple))]
|
362 |
+
self.encoder_cnn = nn.Sequential(*model)
|
363 |
+
|
364 |
+
self.c_dim = int(ngf*(multiple**n_downsampling)*h/16*w/16)
|
365 |
+
self.rh = int(h/16)
|
366 |
+
self.rw = int(w/16)
|
367 |
+
self.fc1 = nn.Linear(self.c_dim,latent_dim)
|
368 |
+
self.fc2 = nn.Linear(self.c_dim,latent_dim)
|
369 |
+
self.fc3 = nn.Linear(latent_dim,self.c_dim)
|
370 |
+
self.relu = nn.ReLU()
|
371 |
+
|
372 |
+
model2 = []
|
373 |
+
for i in range(n_downsampling):
|
374 |
+
mult = multiple**(n_downsampling - i)
|
375 |
+
model2 += [nn.ReLU(),
|
376 |
+
nn.ConvTranspose2d(int(ngf * mult), int(ngf * mult / multiple),
|
377 |
+
kernel_size=4, stride=2,
|
378 |
+
padding=1, bias=use_bias),
|
379 |
+
norm_layer(int(ngf * mult / multiple))]
|
380 |
+
model2 += [nn.ReLU()]
|
381 |
+
model2 += [nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
382 |
+
model2 += [nn.Tanh()]#[-1,1]
|
383 |
+
self.decoder_cnn = nn.Sequential(*model2)
|
384 |
+
|
385 |
+
def encode(self, x):
|
386 |
+
h1 = self.encoder_cnn(x)
|
387 |
+
r1 = h1.view(h1.size(0), -1)
|
388 |
+
return self.fc1(r1), self.fc2(r1)
|
389 |
+
|
390 |
+
def reparameterize(self, mu, logvar):# not deterministic for test mode
|
391 |
+
std = torch.exp(0.5*logvar)
|
392 |
+
eps = torch.randn_like(std)# torch.randn_like returns a tensor with the same size as input,
|
393 |
+
# that is filled with random numbers from a normal distribution N(0,1)
|
394 |
+
return eps.mul(std).add_(mu)
|
395 |
+
|
396 |
+
def decode(self, z):
|
397 |
+
h4 = self.relu(self.fc3(z))
|
398 |
+
r3 = h4.view(h4.size(0),-1,self.rh,self.rw)
|
399 |
+
return self.decoder_cnn(r3)
|
400 |
+
|
401 |
+
def forward(self, x):
|
402 |
+
mu, logvar = self.encode(x)
|
403 |
+
z = self.reparameterize(mu, logvar)
|
404 |
+
reconx = self.decode(z)
|
405 |
+
return reconx, mu, logvar
|
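The reparameterize step above is the usual VAE trick: a sample from N(mu, sigma^2) is rewritten as mu + sigma * eps with eps ~ N(0, 1), so gradients reach mu and logvar. A minimal sketch (tensor shapes are illustrative only, not taken from this repository):

import torch

mu = torch.zeros(4, 8, requires_grad=True)
logvar = torch.zeros(4, 8, requires_grad=True)  # sigma = exp(0.5 * logvar) = 1
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)                     # stochastic, carries no gradient
z = mu + eps * std                              # same value as eps.mul(std).add_(mu)
z.sum().backward()                              # gradients flow into mu and logvar
print(mu.grad.shape, logvar.grad.shape)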
406 |
+
|
407 |
+
class Classifier(nn.Module):
|
408 |
+
def __init__(self, input_nc, classes, ngf=64, num_downs=3, norm_layer=nn.BatchNorm2d, use_dropout=False,
|
409 |
+
h=96, w=96):
|
410 |
+
super(Classifier, self).__init__()
|
411 |
+
self.input_nc = input_nc
|
412 |
+
self.ngf = ngf
|
413 |
+
if type(norm_layer) == functools.partial:
|
414 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
415 |
+
else:
|
416 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
417 |
+
|
418 |
+
model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
419 |
+
multiple = 2
|
420 |
+
for i in range(num_downs):
|
421 |
+
mult = multiple**i
|
422 |
+
model += [nn.LeakyReLU(0.2),
|
423 |
+
nn.Conv2d(int(ngf * mult), int(ngf * mult * multiple), kernel_size=4,
|
424 |
+
stride=2, padding=1, bias=use_bias),
|
425 |
+
norm_layer(int(ngf * mult * multiple))]
|
426 |
+
self.encoder = nn.Sequential(*model)
|
427 |
+
strides = 2**(num_downs+1)
|
428 |
+
self.fc1 = nn.Linear(int(ngf*h*w/(strides*2)), classes)
|
429 |
+
|
430 |
+
def forward(self, x):
|
431 |
+
ax = self.encoder(x) # b, 512, 6, 6
|
432 |
+
ax = ax.view(ax.size(0), -1) # view -- reshape
|
433 |
+
return self.fc1(ax)
|
434 |
+
|
435 |
+
class Regressor(nn.Module):
|
436 |
+
def __init__(self, input_nc, ngf=64, norm_layer=nn.BatchNorm2d, arch=1):
|
437 |
+
super(Regressor, self).__init__()
|
438 |
+
# if use BatchNorm2d,
|
439 |
+
# no need to use bias as BatchNorm2d has affine parameters
|
440 |
+
|
441 |
+
self.arch = arch
|
442 |
+
|
443 |
+
if arch == 1:
|
444 |
+
use_bias = True
|
445 |
+
sequence = [
|
446 |
+
nn.Conv2d(input_nc, ngf, kernel_size=3, stride=2, padding=0, bias=use_bias),#11->5
|
447 |
+
nn.LeakyReLU(0.2, True),
|
448 |
+
nn.Conv2d(ngf, 1, kernel_size=5, stride=1, padding=0, bias=use_bias),#5->1
|
449 |
+
]
|
450 |
+
elif arch == 2:
|
451 |
+
if type(norm_layer) == functools.partial:
|
452 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
453 |
+
else:
|
454 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
455 |
+
sequence = [
|
456 |
+
nn.Conv2d(input_nc, ngf, kernel_size=3, stride=1, padding=0, bias=use_bias),#11->9
|
457 |
+
nn.LeakyReLU(0.2, True),
|
458 |
+
nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=1, padding=0, bias=use_bias),#9->7
|
459 |
+
norm_layer(ngf*2),
|
460 |
+
nn.LeakyReLU(0.2, True),
|
461 |
+
nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=1, padding=0, bias=use_bias),#7->5
|
462 |
+
norm_layer(ngf*4),
|
463 |
+
nn.LeakyReLU(0.2, True),
|
464 |
+
nn.Conv2d(ngf*4, 1, kernel_size=5, stride=1, padding=0, bias=use_bias),#5->1
|
465 |
+
]
|
466 |
+
elif arch == 3:
|
467 |
+
use_bias = True
|
468 |
+
sequence = [
|
469 |
+
nn.Conv2d(input_nc, ngf, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
|
470 |
+
nn.LeakyReLU(0.2, True),
|
471 |
+
nn.Conv2d(ngf, 1, kernel_size=11, stride=1, padding=0, bias=use_bias),#11->1
|
472 |
+
]
|
473 |
+
elif arch == 4:
|
474 |
+
use_bias = True
|
475 |
+
sequence = [
|
476 |
+
nn.Conv2d(input_nc, ngf, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
|
477 |
+
nn.LeakyReLU(0.2, True),
|
478 |
+
nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
|
479 |
+
nn.LeakyReLU(0.2, True),
|
480 |
+
nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
|
481 |
+
nn.LeakyReLU(0.2, True),
|
482 |
+
nn.Conv2d(ngf*4, 1, kernel_size=11, stride=1, padding=0, bias=use_bias),#11->1
|
483 |
+
]
|
484 |
+
elif arch == 5:
|
485 |
+
use_bias = True
|
486 |
+
sequence = [
|
487 |
+
nn.Conv2d(input_nc, ngf, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
|
488 |
+
nn.LeakyReLU(0.2, True),
|
489 |
+
nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
|
490 |
+
nn.LeakyReLU(0.2, True),
|
491 |
+
nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=1, padding=1, bias=use_bias),#11->11
|
492 |
+
nn.LeakyReLU(0.2, True),
|
493 |
+
]
|
494 |
+
fc = [
|
495 |
+
nn.Linear(ngf*4*11*11, 4096),
|
496 |
+
nn.ReLU(True),
|
497 |
+
nn.Dropout(),
|
498 |
+
nn.Linear(4096, 1),
|
499 |
+
]
|
500 |
+
self.fc = nn.Sequential(*fc)
|
501 |
+
|
502 |
+
self.model = nn.Sequential(*sequence)
|
503 |
+
|
504 |
+
def forward(self, x):
|
505 |
+
if self.arch <= 4:
|
506 |
+
return self.model(x)
|
507 |
+
else:
|
508 |
+
x = self.model(x)
|
509 |
+
x = x.view(x.size(0), -1)
|
510 |
+
x = self.fc(x)
|
511 |
+
return x
|
512 |
+
|
513 |
+
|
514 |
+
# Defines the generator that consists of Resnet blocks between a few
|
515 |
+
# downsampling/upsampling operations.
|
516 |
+
# Code and idea originally from Justin Johnson's architecture.
|
517 |
+
# https://github.com/jcjohnson/fast-neural-style/
|
518 |
+
class ResnetGenerator(nn.Module):
|
519 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
|
520 |
+
assert(n_blocks >= 0)
|
521 |
+
super(ResnetGenerator, self).__init__()
|
522 |
+
self.input_nc = input_nc
|
523 |
+
self.output_nc = output_nc
|
524 |
+
self.ngf = ngf
|
525 |
+
if type(norm_layer) == functools.partial:
|
526 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
527 |
+
else:
|
528 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
529 |
+
|
530 |
+
model = [nn.ReflectionPad2d(3),
|
531 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
|
532 |
+
bias=use_bias),
|
533 |
+
norm_layer(ngf),
|
534 |
+
nn.ReLU(True)]
|
535 |
+
|
536 |
+
n_downsampling = 2
|
537 |
+
for i in range(n_downsampling):
|
538 |
+
mult = 2**i
|
539 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
|
540 |
+
stride=2, padding=1, bias=use_bias),
|
541 |
+
norm_layer(ngf * mult * 2),
|
542 |
+
nn.ReLU(True)]
|
543 |
+
|
544 |
+
mult = 2**n_downsampling
|
545 |
+
for i in range(n_blocks):
|
546 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
547 |
+
|
548 |
+
for i in range(n_downsampling):
|
549 |
+
mult = 2**(n_downsampling - i)
|
550 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
551 |
+
kernel_size=3, stride=2,
|
552 |
+
padding=1, output_padding=1,
|
553 |
+
bias=use_bias),
|
554 |
+
norm_layer(int(ngf * mult / 2)),
|
555 |
+
nn.ReLU(True)]
|
556 |
+
model += [nn.ReflectionPad2d(3)]
|
557 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
558 |
+
model += [nn.Tanh()]
|
559 |
+
|
560 |
+
self.model = nn.Sequential(*model)
|
561 |
+
|
562 |
+
def forward(self, input):
|
563 |
+
return self.model(input)
|
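A minimal shape check for the generator above (the values are illustrative and not taken from the training scripts): the two stride-2 downsamplings are undone by two stride-2 transposed convolutions, so the output keeps the input resolution.

import torch

netG = ResnetGenerator(input_nc=3, output_nc=1, ngf=64, n_blocks=9)  # resnet_9blocks
photo = torch.randn(1, 3, 256, 256)   # hypothetical aligned face photo
drawing = netG(photo)
print(drawing.shape)                  # torch.Size([1, 1, 256, 256])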
564 |
+
|
565 |
+
class ResnetStyle2Generator(nn.Module):
|
566 |
+
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
|
567 |
+
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
|
568 |
+
"""
|
569 |
+
|
570 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0):
|
571 |
+
"""Construct a Resnet-based generator
|
572 |
+
|
573 |
+
Parameters:
|
574 |
+
input_nc (int) -- the number of channels in input images
|
575 |
+
output_nc (int) -- the number of channels in output images
|
576 |
+
ngf (int) -- the number of filters in the last conv layer
|
577 |
+
norm_layer -- normalization layer
|
578 |
+
use_dropout (bool) -- if use dropout layers
|
579 |
+
n_blocks (int) -- the number of ResNet blocks
|
580 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
581 |
+
"""
|
582 |
+
assert(n_blocks >= 0)
|
583 |
+
super(ResnetStyle2Generator, self).__init__()
|
584 |
+
if type(norm_layer) == functools.partial:
|
585 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
586 |
+
else:
|
587 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
588 |
+
|
589 |
+
model0 = [nn.ReflectionPad2d(3),
|
590 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
591 |
+
norm_layer(ngf),
|
592 |
+
nn.ReLU(True)]
|
593 |
+
|
594 |
+
n_downsampling = 2
|
595 |
+
for i in range(n_downsampling): # add downsampling layers
|
596 |
+
mult = 2 ** i
|
597 |
+
model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
598 |
+
norm_layer(ngf * mult * 2),
|
599 |
+
nn.ReLU(True)]
|
600 |
+
|
601 |
+
mult = 2 ** n_downsampling
|
602 |
+
for i in range(model0_res): # add ResNet blocks
|
603 |
+
model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
604 |
+
|
605 |
+
model = []
|
606 |
+
model += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
|
607 |
+
norm_layer(ngf * mult),
|
608 |
+
nn.ReLU(True)]
|
609 |
+
|
610 |
+
for i in range(n_blocks-model0_res): # add ResNet blocks
|
611 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
612 |
+
|
613 |
+
for i in range(n_downsampling): # add upsampling layers
|
614 |
+
mult = 2 ** (n_downsampling - i)
|
615 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
616 |
+
kernel_size=3, stride=2,
|
617 |
+
padding=1, output_padding=1,
|
618 |
+
bias=use_bias),
|
619 |
+
norm_layer(int(ngf * mult / 2)),
|
620 |
+
nn.ReLU(True)]
|
621 |
+
model += [nn.ReflectionPad2d(3)]
|
622 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
623 |
+
model += [nn.Tanh()]
|
624 |
+
|
625 |
+
self.model0 = nn.Sequential(*model0)
|
626 |
+
self.model = nn.Sequential(*model)
|
627 |
+
print(list(self.modules()))
|
628 |
+
|
629 |
+
def forward(self, input1, input2): # input2 [bs,c]
|
630 |
+
"""Standard forward"""
|
631 |
+
f1 = self.model0(input1)
|
632 |
+
[bs,c,h,w] = f1.shape
|
633 |
+
input2 = input2.repeat(h,w,1,1).permute([2,3,0,1])
|
634 |
+
y1 = torch.cat([f1, input2], 1)
|
635 |
+
return self.model(y1)
|
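In the forward pass above, the style vector input2 of shape [bs, c] is tiled over every spatial location before being concatenated with the features. A small sketch (shapes chosen arbitrarily) showing that the repeat/permute is equivalent to a broadcast:

import torch

bs, c, h, w = 2, 3, 4, 5
style = torch.randn(bs, c)
tiled = style.repeat(h, w, 1, 1).permute([2, 3, 0, 1])    # [bs, c, h, w]
expanded = style[:, :, None, None].expand(bs, c, h, w)    # equivalent broadcast
print(tiled.shape, torch.equal(tiled, expanded))          # torch.Size([2, 3, 4, 5]) True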
636 |
+
|
637 |
+
class Combiner(nn.Module):
|
638 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
|
639 |
+
assert(n_blocks >= 0)
|
640 |
+
super(Combiner, self).__init__()
|
641 |
+
self.input_nc = input_nc
|
642 |
+
self.output_nc = output_nc
|
643 |
+
self.ngf = ngf
|
644 |
+
if type(norm_layer) == functools.partial:
|
645 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
646 |
+
else:
|
647 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
648 |
+
|
649 |
+
model = [nn.ReflectionPad2d(3),
|
650 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
|
651 |
+
bias=use_bias),
|
652 |
+
norm_layer(ngf),
|
653 |
+
nn.ReLU(True)]
|
654 |
+
|
655 |
+
for i in range(n_blocks):
|
656 |
+
model += [ResnetBlock(ngf, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
657 |
+
|
658 |
+
model += [nn.ReflectionPad2d(3)]
|
659 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
660 |
+
model += [nn.Tanh()]
|
661 |
+
|
662 |
+
self.model = nn.Sequential(*model)
|
663 |
+
|
664 |
+
def forward(self, input):
|
665 |
+
return self.model(input)
|
666 |
+
|
667 |
+
class Combiner2(nn.Module):
|
668 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
669 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
670 |
+
super(Combiner2, self).__init__()
|
671 |
+
|
672 |
+
# construct unet structure
|
673 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
|
674 |
+
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
675 |
+
|
676 |
+
self.model = unet_block
|
677 |
+
|
678 |
+
def forward(self, input):
|
679 |
+
return self.model(input)
|
680 |
+
|
681 |
+
|
682 |
+
# Define a resnet block
|
683 |
+
class ResnetBlock(nn.Module):
|
684 |
+
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
685 |
+
super(ResnetBlock, self).__init__()
|
686 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
|
687 |
+
|
688 |
+
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
689 |
+
conv_block = []
|
690 |
+
p = 0
|
691 |
+
if padding_type == 'reflect':
|
692 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
693 |
+
elif padding_type == 'replicate':
|
694 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
695 |
+
elif padding_type == 'zero':
|
696 |
+
p = 1
|
697 |
+
else:
|
698 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
699 |
+
|
700 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
701 |
+
norm_layer(dim),
|
702 |
+
nn.ReLU(True)]
|
703 |
+
if use_dropout:
|
704 |
+
conv_block += [nn.Dropout(0.5)]
|
705 |
+
|
706 |
+
p = 0
|
707 |
+
if padding_type == 'reflect':
|
708 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
709 |
+
elif padding_type == 'replicate':
|
710 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
711 |
+
elif padding_type == 'zero':
|
712 |
+
p = 1
|
713 |
+
else:
|
714 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
715 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
716 |
+
norm_layer(dim)]
|
717 |
+
|
718 |
+
return nn.Sequential(*conv_block)
|
719 |
+
|
720 |
+
def forward(self, x):
|
721 |
+
out = x + self.conv_block(x)
|
722 |
+
return out
|
723 |
+
|
724 |
+
|
725 |
+
# Defines the Unet generator.
|
726 |
+
# |num_downs|: number of downsamplings in UNet. For example,
|
727 |
+
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
|
728 |
+
# at the bottleneck
|
729 |
+
class UnetGenerator(nn.Module):
|
730 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
731 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
732 |
+
super(UnetGenerator, self).__init__()
|
733 |
+
|
734 |
+
# construct unet structure
|
735 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
|
736 |
+
for i in range(num_downs - 5):
|
737 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
738 |
+
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
739 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
740 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
741 |
+
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
742 |
+
|
743 |
+
self.model = unet_block
|
744 |
+
|
745 |
+
def forward(self, input):
|
746 |
+
return self.model(input)
|
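Following the num_downs comment above, the unet_512 configuration uses nine downsamplings, so a 512x512 input reaches a 1x1 bottleneck and is restored to full resolution on the way up. A quick sketch (channel counts are illustrative):

import torch

netG = UnetGenerator(input_nc=3, output_nc=1, num_downs=9, ngf=64)  # unet_512
out = netG(torch.randn(1, 3, 512, 512))
print(out.shape)   # torch.Size([1, 1, 512, 512])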
747 |
+
|
748 |
+
class UnetResGenerator(nn.Module):
|
749 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
750 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
|
751 |
+
super(UnetResGenerator, self).__init__()
|
752 |
+
|
753 |
+
# construct unet structure
|
754 |
+
unet_block = UnetSkipConnectionResBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, nres=nres)
|
755 |
+
for i in range(num_downs - 5):
|
756 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
757 |
+
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
758 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
759 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
760 |
+
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
761 |
+
|
762 |
+
self.model = unet_block
|
763 |
+
|
764 |
+
def forward(self, input):
|
765 |
+
return self.model(input)
|
766 |
+
|
767 |
+
class PartUnet(nn.Module):
|
768 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
769 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
770 |
+
super(PartUnet, self).__init__()
|
771 |
+
|
772 |
+
# construct unet structure
|
773 |
+
# 3 downs
|
774 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
|
775 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
776 |
+
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
777 |
+
|
778 |
+
self.model = unet_block
|
779 |
+
|
780 |
+
def forward(self, input):
|
781 |
+
return self.model(input)
|
782 |
+
|
783 |
+
class PartUnetRes(nn.Module):
|
784 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
785 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
|
786 |
+
super(PartUnetRes, self).__init__()
|
787 |
+
|
788 |
+
# construct unet structure
|
789 |
+
# 3 downs
|
790 |
+
unet_block = UnetSkipConnectionResBlock(ngf * 2, ngf * 4, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, nres=nres)
|
791 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
792 |
+
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
793 |
+
|
794 |
+
self.model = unet_block
|
795 |
+
|
796 |
+
def forward(self, input):
|
797 |
+
return self.model(input)
|
798 |
+
|
799 |
+
class PartUnet2(nn.Module):
|
800 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
801 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
802 |
+
super(PartUnet2, self).__init__()
|
803 |
+
|
804 |
+
# construct unet structure
|
805 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
|
806 |
+
for i in range(num_downs - 3):
|
807 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
808 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
809 |
+
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
810 |
+
|
811 |
+
self.model = unet_block
|
812 |
+
|
813 |
+
def forward(self, input):
|
814 |
+
return self.model(input)
|
815 |
+
|
816 |
+
class PartUnet2Res(nn.Module):
|
817 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
818 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
|
819 |
+
super(PartUnet2Res, self).__init__()
|
820 |
+
|
821 |
+
# construct unet structure
|
822 |
+
unet_block = UnetSkipConnectionResBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, nres=nres)
|
823 |
+
for i in range(num_downs - 3):
|
824 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
825 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
826 |
+
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
827 |
+
|
828 |
+
self.model = unet_block
|
829 |
+
|
830 |
+
def forward(self, input):
|
831 |
+
return self.model(input)
|
832 |
+
|
833 |
+
class PartUnet2Style(nn.Module):
|
834 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, extra_channel=2,
|
835 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
836 |
+
super(PartUnet2Style, self).__init__()
|
837 |
+
# construct unet structure
|
838 |
+
unet_block = UnetSkipConnectionStyleBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, extra_channel=extra_channel)
|
839 |
+
for i in range(num_downs - 3):
|
840 |
+
unet_block = UnetSkipConnectionStyleBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout, extra_channel=extra_channel)
|
841 |
+
unet_block = UnetSkipConnectionStyleBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, extra_channel=extra_channel)
|
842 |
+
unet_block = UnetSkipConnectionStyleBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, extra_channel=extra_channel)
|
843 |
+
|
844 |
+
self.model = unet_block
|
845 |
+
|
846 |
+
def forward(self, input, cate):
|
847 |
+
return self.model(input, cate)
|
848 |
+
|
849 |
+
class PartUnet2ResStyle(nn.Module):
|
850 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, extra_channel=2,
|
851 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
|
852 |
+
super(PartUnet2ResStyle, self).__init__()
|
853 |
+
# construct unet structure
|
854 |
+
unet_block = UnetSkipConnectionResStyleBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, extra_channel=extra_channel, nres=nres)
|
855 |
+
for i in range(num_downs - 3):
|
856 |
+
unet_block = UnetSkipConnectionStyleBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout, extra_channel=extra_channel)
|
857 |
+
unet_block = UnetSkipConnectionStyleBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, extra_channel=extra_channel)
|
858 |
+
unet_block = UnetSkipConnectionStyleBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, extra_channel=extra_channel)
|
859 |
+
|
860 |
+
self.model = unet_block
|
861 |
+
|
862 |
+
def forward(self, input, cate):
|
863 |
+
return self.model(input, cate)
|
864 |
+
|
865 |
+
|
866 |
+
# Defines the submodule with skip connection.
|
867 |
+
# X -------------------identity---------------------- X
|
868 |
+
# |-- downsampling -- |submodule| -- upsampling --|
|
869 |
+
class UnetSkipConnectionBlock(nn.Module):
|
870 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
871 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
872 |
+
super(UnetSkipConnectionBlock, self).__init__()
|
873 |
+
self.outermost = outermost
|
874 |
+
if type(norm_layer) == functools.partial:
|
875 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
876 |
+
else:
|
877 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
878 |
+
if input_nc is None:
|
879 |
+
input_nc = outer_nc
|
880 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
881 |
+
stride=2, padding=1, bias=use_bias)
|
882 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
883 |
+
downnorm = norm_layer(inner_nc)
|
884 |
+
uprelu = nn.ReLU(True)
|
885 |
+
upnorm = norm_layer(outer_nc)
|
886 |
+
|
887 |
+
if outermost:
|
888 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
889 |
+
kernel_size=4, stride=2,
|
890 |
+
padding=1)
|
891 |
+
down = [downconv]
|
892 |
+
up = [uprelu, upconv, nn.Tanh()]
|
893 |
+
model = down + [submodule] + up
|
894 |
+
elif innermost:
|
895 |
+
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
|
896 |
+
kernel_size=4, stride=2,
|
897 |
+
padding=1, bias=use_bias)
|
898 |
+
down = [downrelu, downconv]
|
899 |
+
up = [uprelu, upconv, upnorm]
|
900 |
+
model = down + up
|
901 |
+
else:
|
902 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
903 |
+
kernel_size=4, stride=2,
|
904 |
+
padding=1, bias=use_bias)
|
905 |
+
down = [downrelu, downconv, downnorm]
|
906 |
+
up = [uprelu, upconv, upnorm]
|
907 |
+
|
908 |
+
if use_dropout:
|
909 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
910 |
+
else:
|
911 |
+
model = down + [submodule] + up
|
912 |
+
|
913 |
+
self.model = nn.Sequential(*model)
|
914 |
+
|
915 |
+
def forward(self, x):
|
916 |
+
if self.outermost:
|
917 |
+
return self.model(x)
|
918 |
+
else:
|
919 |
+
return torch.cat([x, self.model(x)], 1)
|
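A small sketch of the skip connection described above (sizes are illustrative): every non-outermost block returns its input concatenated with the processed features, so the channel count seen by the enclosing block doubles.

import torch

block = UnetSkipConnectionBlock(outer_nc=64, inner_nc=128, innermost=True)
x = torch.randn(1, 64, 16, 16)
y = block(x)
print(y.shape)   # torch.Size([1, 128, 16, 16]) -- x (64ch) concatenated with the 64ch output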
920 |
+
|
921 |
+
class UnetSkipConnectionResBlock(nn.Module):
|
922 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
923 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
|
924 |
+
super(UnetSkipConnectionResBlock, self).__init__()
|
925 |
+
self.outermost = outermost
|
926 |
+
if type(norm_layer) == functools.partial:
|
927 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
928 |
+
else:
|
929 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
930 |
+
if input_nc is None:
|
931 |
+
input_nc = outer_nc
|
932 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
933 |
+
stride=2, padding=1, bias=use_bias)
|
934 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
935 |
+
downnorm = norm_layer(inner_nc)
|
936 |
+
uprelu = nn.ReLU(True)
|
937 |
+
upnorm = norm_layer(outer_nc)
|
938 |
+
|
939 |
+
if outermost:
|
940 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
941 |
+
kernel_size=4, stride=2,
|
942 |
+
padding=1)
|
943 |
+
down = [downconv]
|
944 |
+
up = [uprelu, upconv, nn.Tanh()]
|
945 |
+
model = down + [submodule] + up
|
946 |
+
elif innermost:
|
947 |
+
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
|
948 |
+
kernel_size=4, stride=2,
|
949 |
+
padding=1, bias=use_bias)
|
950 |
+
down = [downrelu, downconv, downrelu]
|
951 |
+
up = [upconv, upnorm]
|
952 |
+
model = down
|
953 |
+
# resblock: conv norm relu conv norm +
|
954 |
+
for i in range(nres):
|
955 |
+
model += [ResnetBlock(inner_nc, padding_type='reflect', norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
956 |
+
model += up
|
957 |
+
#model = down + [submodule] + up
|
958 |
+
print('UnetSkipConnectionResBlock','nres',nres,'inner_nc',inner_nc)
|
959 |
+
else:
|
960 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
961 |
+
kernel_size=4, stride=2,
|
962 |
+
padding=1, bias=use_bias)
|
963 |
+
down = [downrelu, downconv, downnorm]
|
964 |
+
up = [uprelu, upconv, upnorm]
|
965 |
+
|
966 |
+
if use_dropout:
|
967 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
968 |
+
else:
|
969 |
+
model = down + [submodule] + up
|
970 |
+
|
971 |
+
self.model = nn.Sequential(*model)
|
972 |
+
|
973 |
+
def forward(self, x):
|
974 |
+
if self.outermost:
|
975 |
+
return self.model(x)
|
976 |
+
else:
|
977 |
+
return torch.cat([x, self.model(x)], 1)
|
978 |
+
|
979 |
+
class UnetSkipConnectionStyleBlock(nn.Module):
|
980 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
981 |
+
submodule=None, outermost=False, innermost=False,
|
982 |
+
extra_channel=2, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
983 |
+
super(UnetSkipConnectionStyleBlock, self).__init__()
|
984 |
+
self.outermost = outermost
|
985 |
+
self.innermost = innermost
|
986 |
+
self.extra_channel = extra_channel
|
987 |
+
if type(norm_layer) == functools.partial:
|
988 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
989 |
+
else:
|
990 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
991 |
+
if input_nc is None:
|
992 |
+
input_nc = outer_nc
|
993 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
994 |
+
stride=2, padding=1, bias=use_bias)
|
995 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
996 |
+
downnorm = norm_layer(inner_nc)
|
997 |
+
uprelu = nn.ReLU(True)
|
998 |
+
upnorm = norm_layer(outer_nc)
|
999 |
+
|
1000 |
+
if outermost:
|
1001 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1002 |
+
kernel_size=4, stride=2,
|
1003 |
+
padding=1)
|
1004 |
+
down = [downconv]
|
1005 |
+
up = [uprelu, upconv, nn.Tanh()]
|
1006 |
+
model = down + [submodule] + up
|
1007 |
+
elif innermost:
|
1008 |
+
upconv = nn.ConvTranspose2d(inner_nc+extra_channel, outer_nc,
|
1009 |
+
kernel_size=4, stride=2,
|
1010 |
+
padding=1, bias=use_bias)
|
1011 |
+
down = [downrelu, downconv]
|
1012 |
+
up = [uprelu, upconv, upnorm]
|
1013 |
+
model = down + up
|
1014 |
+
else:
|
1015 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1016 |
+
kernel_size=4, stride=2,
|
1017 |
+
padding=1, bias=use_bias)
|
1018 |
+
down = [downrelu, downconv, downnorm]
|
1019 |
+
up = [uprelu, upconv, upnorm]
|
1020 |
+
|
1021 |
+
if use_dropout:
|
1022 |
+
up = up + [nn.Dropout(0.5)]
|
1023 |
+
model = down + [submodule] + up
|
1024 |
+
|
1025 |
+
self.model = nn.Sequential(*model)
|
1026 |
+
|
1027 |
+
self.downmodel = nn.Sequential(*down)
|
1028 |
+
self.upmodel = nn.Sequential(*up)
|
1029 |
+
self.submodule = submodule
|
1030 |
+
|
1031 |
+
def forward(self, x, cate):# cate [bs,c]
|
1032 |
+
if self.innermost:
|
1033 |
+
y1 = self.downmodel(x)
|
1034 |
+
[bs,c,h,w] = y1.shape
|
1035 |
+
map = cate.repeat(h,w,1,1).permute([2,3,0,1])
|
1036 |
+
y2 = torch.cat([y1,map], 1)
|
1037 |
+
y3 = self.upmodel(y2)
|
1038 |
+
return torch.cat([x, y3], 1)
|
1039 |
+
else:
|
1040 |
+
y1 = self.downmodel(x)
|
1041 |
+
y2 = self.submodule(y1,cate)
|
1042 |
+
y3 = self.upmodel(y2)
|
1043 |
+
if self.outermost:
|
1044 |
+
return y3
|
1045 |
+
else:
|
1046 |
+
return torch.cat([x, y3], 1)
|
1047 |
+
|
1048 |
+
class UnetSkipConnectionResStyleBlock(nn.Module):
|
1049 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
1050 |
+
submodule=None, outermost=False, innermost=False,
|
1051 |
+
extra_channel=2, norm_layer=nn.BatchNorm2d, use_dropout=False, nres=1):
|
1052 |
+
super(UnetSkipConnectionResStyleBlock, self).__init__()
|
1053 |
+
self.outermost = outermost
|
1054 |
+
self.innermost = innermost
|
1055 |
+
self.extra_channel = extra_channel
|
1056 |
+
if type(norm_layer) == functools.partial:
|
1057 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1058 |
+
else:
|
1059 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1060 |
+
if input_nc is None:
|
1061 |
+
input_nc = outer_nc
|
1062 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
1063 |
+
stride=2, padding=1, bias=use_bias)
|
1064 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
1065 |
+
downnorm = norm_layer(inner_nc)
|
1066 |
+
uprelu = nn.ReLU(True)
|
1067 |
+
upnorm = norm_layer(outer_nc)
|
1068 |
+
|
1069 |
+
if outermost:
|
1070 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1071 |
+
kernel_size=4, stride=2,
|
1072 |
+
padding=1)
|
1073 |
+
down = [downconv]
|
1074 |
+
up = [uprelu, upconv, nn.Tanh()]
|
1075 |
+
model = down + [submodule] + up
|
1076 |
+
elif innermost:
|
1077 |
+
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
|
1078 |
+
kernel_size=4, stride=2,
|
1079 |
+
padding=1, bias=use_bias)
|
1080 |
+
down = [downrelu, downconv, downrelu]
|
1081 |
+
up = [nn.Conv2d(inner_nc+extra_channel, inner_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
|
1082 |
+
norm_layer(inner_nc),
|
1083 |
+
nn.ReLU(True)]
|
1084 |
+
for i in range(nres):
|
1085 |
+
up += [ResnetBlock(inner_nc, padding_type='reflect', norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
1086 |
+
up += [ upconv, upnorm]
|
1087 |
+
model = down + up
|
1088 |
+
print('UnetSkipConnectionResStyleBlock','nres',nres,'inner_nc',inner_nc)
|
1089 |
+
else:
|
1090 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1091 |
+
kernel_size=4, stride=2,
|
1092 |
+
padding=1, bias=use_bias)
|
1093 |
+
down = [downrelu, downconv, downnorm]
|
1094 |
+
up = [uprelu, upconv, upnorm]
|
1095 |
+
|
1096 |
+
if use_dropout:
|
1097 |
+
up = up + [nn.Dropout(0.5)]
|
1098 |
+
model = down + [submodule] + up
|
1099 |
+
|
1100 |
+
self.model = nn.Sequential(*model)
|
1101 |
+
|
1102 |
+
self.downmodel = nn.Sequential(*down)
|
1103 |
+
self.upmodel = nn.Sequential(*up)
|
1104 |
+
self.submodule = submodule
|
1105 |
+
|
1106 |
+
def forward(self, x, cate):# cate [bs,c]
|
1107 |
+
# concatenate the style map in the innermost block
|
1108 |
+
if self.innermost:
|
1109 |
+
y1 = self.downmodel(x)
|
1110 |
+
[bs,c,h,w] = y1.shape
|
1111 |
+
map = cate.repeat(h,w,1,1).permute([2,3,0,1])
|
1112 |
+
y2 = torch.cat([y1,map], 1)
|
1113 |
+
y3 = self.upmodel(y2)
|
1114 |
+
return torch.cat([x, y3], 1)
|
1115 |
+
else:
|
1116 |
+
y1 = self.downmodel(x)
|
1117 |
+
y2 = self.submodule(y1,cate)
|
1118 |
+
y3 = self.upmodel(y2)
|
1119 |
+
if self.outermost:
|
1120 |
+
return y3
|
1121 |
+
else:
|
1122 |
+
return torch.cat([x, y3], 1)
|
1123 |
+
|
1124 |
+
# Defines the PatchGAN discriminator with the specified arguments.
|
1125 |
+
class NLayerDiscriminator(nn.Module):
|
1126 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
|
1127 |
+
super(NLayerDiscriminator, self).__init__()
|
1128 |
+
if type(norm_layer) == functools.partial:
|
1129 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1130 |
+
else:
|
1131 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1132 |
+
|
1133 |
+
kw = 4
|
1134 |
+
padw = 1
|
1135 |
+
sequence = [
|
1136 |
+
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
1137 |
+
nn.LeakyReLU(0.2, True)
|
1138 |
+
]
|
1139 |
+
|
1140 |
+
nf_mult = 1
|
1141 |
+
nf_mult_prev = 1
|
1142 |
+
for n in range(1, n_layers):
|
1143 |
+
nf_mult_prev = nf_mult
|
1144 |
+
nf_mult = min(2**n, 8)
|
1145 |
+
sequence += [
|
1146 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
|
1147 |
+
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
1148 |
+
norm_layer(ndf * nf_mult),
|
1149 |
+
nn.LeakyReLU(0.2, True)
|
1150 |
+
]
|
1151 |
+
|
1152 |
+
nf_mult_prev = nf_mult
|
1153 |
+
nf_mult = min(2**n_layers, 8)
|
1154 |
+
sequence += [
|
1155 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
|
1156 |
+
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
1157 |
+
norm_layer(ndf * nf_mult),
|
1158 |
+
nn.LeakyReLU(0.2, True)
|
1159 |
+
]
|
1160 |
+
|
1161 |
+
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
|
1162 |
+
|
1163 |
+
if use_sigmoid:#no_lsgan, use sigmoid before calculating bceloss(binary cross entropy)
|
1164 |
+
sequence += [nn.Sigmoid()]
|
1165 |
+
|
1166 |
+
self.model = nn.Sequential(*sequence)
|
1167 |
+
|
1168 |
+
def forward(self, input):
|
1169 |
+
return self.model(input)
|
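A minimal sketch of the PatchGAN above (the channel split is an assumption, mirroring a conditional pix2pix-style input): the discriminator emits a map of patch scores rather than a single scalar.

import torch

netD = NLayerDiscriminator(input_nc=3 + 1, ndf=64, n_layers=3)
pair = torch.randn(1, 4, 512, 512)   # hypothetical photo (3ch) concatenated with drawing (1ch)
patch_scores = netD(pair)
print(patch_scores.shape)            # torch.Size([1, 1, 62, 62]); one score per overlapping patch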
1170 |
+
|
1171 |
+
|
1172 |
+
class PixelDiscriminator(nn.Module):
|
1173 |
+
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
|
1174 |
+
super(PixelDiscriminator, self).__init__()
|
1175 |
+
if type(norm_layer) == functools.partial:
|
1176 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1177 |
+
else:
|
1178 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1179 |
+
|
1180 |
+
self.net = [
|
1181 |
+
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
|
1182 |
+
nn.LeakyReLU(0.2, True),
|
1183 |
+
nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
|
1184 |
+
norm_layer(ndf * 2),
|
1185 |
+
nn.LeakyReLU(0.2, True),
|
1186 |
+
nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
|
1187 |
+
|
1188 |
+
if use_sigmoid:
|
1189 |
+
self.net.append(nn.Sigmoid())
|
1190 |
+
|
1191 |
+
self.net = nn.Sequential(*self.net)
|
1192 |
+
|
1193 |
+
def forward(self, input):
|
1194 |
+
return self.net(input)
|
APDrawingGAN2/models/test_model.py
ADDED
@@ -0,0 +1,214 @@
1 |
+
from .base_model import BaseModel
|
2 |
+
from . import networks
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class TestModel(BaseModel):
|
7 |
+
def name(self):
|
8 |
+
return 'TestModel'
|
9 |
+
|
10 |
+
@staticmethod
|
11 |
+
def modify_commandline_options(parser, is_train=True):
|
12 |
+
assert not is_train, 'TestModel cannot be used in train mode'
|
13 |
+
# uncomment because default CycleGAN did not use dropout ( parser.set_defaults(no_dropout=True) )
|
14 |
+
# parser = CycleGANModel.modify_commandline_options(parser, is_train=False)
|
15 |
+
parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')# no_lsgan=True, use_lsgan=False
|
16 |
+
parser.set_defaults(dataset_mode='single')
|
17 |
+
parser.set_defaults(auxiliary_root='auxiliaryeye2o')
|
18 |
+
parser.set_defaults(use_local=True, hair_local=True, bg_local=True)
|
19 |
+
parser.set_defaults(nose_ae=True, others_ae=True, compactmask=True, MOUTH_H=56)
|
20 |
+
parser.set_defaults(soft_border=1)
|
21 |
+
parser.add_argument('--nnG_hairc', type=int, default=6, help='nnG for hair classifier')
|
22 |
+
parser.add_argument('--use_resnet', action='store_true', help='use resnet for generator')
|
23 |
+
|
24 |
+
parser.add_argument('--model_suffix', type=str, default='',
|
25 |
+
help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will'
|
26 |
+
' be loaded as the generator of TestModel')
|
27 |
+
|
28 |
+
return parser
|
29 |
+
|
30 |
+
def initialize(self, opt):
|
31 |
+
assert(not opt.isTrain)
|
32 |
+
BaseModel.initialize(self, opt)
|
33 |
+
|
34 |
+
# specify the training losses you want to print out. The program will call base_model.get_current_losses
|
35 |
+
self.loss_names = []
|
36 |
+
# specify the images you want to save/display. The program will call base_model.get_current_visuals
|
37 |
+
self.visual_names = ['real_A', 'fake_B']
|
38 |
+
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
|
39 |
+
self.model_names = ['G' + opt.model_suffix]
|
40 |
+
self.auxiliary_model_names = []
|
41 |
+
if self.opt.use_local:
|
42 |
+
self.model_names += ['GLEyel','GLEyer','GLNose','GLMouth','GLHair','GLBG','GCombine']
|
43 |
+
self.auxiliary_model_names += ['CLm','CLh']
|
44 |
+
# auxiliary nets for local output refinement
|
45 |
+
if self.opt.nose_ae:
|
46 |
+
self.auxiliary_model_names += ['AE']
|
47 |
+
if self.opt.others_ae:
|
48 |
+
self.auxiliary_model_names += ['AEel','AEer','AEmowhite','AEmoblack']
|
49 |
+
print('model_names', self.model_names)
|
50 |
+
print('auxiliary_model_names', self.auxiliary_model_names)
|
51 |
+
|
52 |
+
# load/define networks
|
53 |
+
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
|
54 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
55 |
+
opt.nnG)
|
56 |
+
print('netG', opt.netG)
|
57 |
+
if self.opt.use_local:
|
58 |
+
netlocal1 = 'partunet' if self.opt.use_resnet == 0 else 'resnet_nblocks'
|
59 |
+
netlocal2 = 'partunet2' if self.opt.use_resnet == 0 else 'resnet_6blocks'
|
60 |
+
netlocal2_style = 'partunet2style' if self.opt.use_resnet == 0 else 'resnet_style2_6blocks'
|
61 |
+
self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
|
62 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
|
63 |
+
self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
|
64 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
|
65 |
+
self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
|
66 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
|
67 |
+
self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
|
68 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
|
69 |
+
self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2_style, opt.norm,
|
70 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4,
|
71 |
+
extra_channel=3)
|
72 |
+
self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2, opt.norm,
|
73 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4)
|
74 |
+
# by default combiner_type is combiner, which uses resnet
|
75 |
+
print('combiner_type', self.opt.combiner_type)
|
76 |
+
self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, self.opt.combiner_type, opt.norm,
|
77 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 2)
|
78 |
+
# auxiliary classifiers for mouth and hair
|
79 |
+
ratio = self.opt.fineSize / 256
|
80 |
+
self.MOUTH_H = int(self.opt.MOUTH_H * ratio)
|
81 |
+
self.MOUTH_W = int(self.opt.MOUTH_W * ratio)
|
82 |
+
self.netCLm = networks.define_G(opt.input_nc, 2, opt.ngf, 'classifier', opt.norm,
|
83 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
84 |
+
nnG = 3, ae_h = self.MOUTH_H, ae_w = self.MOUTH_W)
|
85 |
+
self.netCLh = networks.define_G(opt.input_nc, 3, opt.ngf, 'classifier', opt.norm,
|
86 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
87 |
+
nnG = opt.nnG_hairc, ae_h = opt.fineSize, ae_w = opt.fineSize)
|
88 |
+
# ==================================auxiliary nets (loaded, parameters fixed)=============================
|
89 |
+
if self.opt.use_local and self.opt.nose_ae:
|
90 |
+
ratio = self.opt.fineSize / 256
|
91 |
+
NOSE_H = self.opt.NOSE_H * ratio
|
92 |
+
NOSE_W = self.opt.NOSE_W * ratio
|
93 |
+
self.netAE = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
|
94 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
95 |
+
latent_dim=self.opt.ae_latentno, ae_h=NOSE_H, ae_w=NOSE_W)
|
96 |
+
self.set_requires_grad(self.netAE, False)
|
97 |
+
if self.opt.use_local and self.opt.others_ae:
|
98 |
+
ratio = self.opt.fineSize / 256
|
99 |
+
EYE_H = self.opt.EYE_H * ratio
|
100 |
+
EYE_W = self.opt.EYE_W * ratio
|
101 |
+
MOUTH_H = self.opt.MOUTH_H * ratio
|
102 |
+
MOUTH_W = self.opt.MOUTH_W * ratio
|
103 |
+
self.netAEel = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
|
104 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
105 |
+
latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
|
106 |
+
self.netAEer = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
|
107 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
108 |
+
latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
|
109 |
+
self.netAEmowhite = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
|
110 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
111 |
+
latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
|
112 |
+
self.netAEmoblack = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
|
113 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
|
114 |
+
latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
|
115 |
+
self.set_requires_grad(self.netAEel, False)
|
116 |
+
self.set_requires_grad(self.netAEer, False)
|
117 |
+
self.set_requires_grad(self.netAEmowhite, False)
|
118 |
+
self.set_requires_grad(self.netAEmoblack, False)
|
119 |
+
|
120 |
+
# assigns the model to self.netG_[suffix] so that it can be loaded
|
121 |
+
# please see BaseModel.load_networks
|
122 |
+
setattr(self, 'netG' + opt.model_suffix, self.netG)
|
123 |
+
|
124 |
+
def set_input(self, input):
|
125 |
+
# we need to use single_dataset mode
|
126 |
+
self.real_A = input['A'].to(self.device)
|
127 |
+
self.image_paths = input['A_paths']
|
128 |
+
self.batch_size = len(self.image_paths)
|
129 |
+
if self.opt.use_local:
|
130 |
+
self.real_A_eyel = input['eyel_A'].to(self.device)
|
131 |
+
self.real_A_eyer = input['eyer_A'].to(self.device)
|
132 |
+
self.real_A_nose = input['nose_A'].to(self.device)
|
133 |
+
self.real_A_mouth = input['mouth_A'].to(self.device)
|
134 |
+
self.center = input['center']
|
135 |
+
if self.opt.soft_border:
|
136 |
+
self.softel = input['soft_eyel_mask'].to(self.device)
|
137 |
+
self.softer = input['soft_eyer_mask'].to(self.device)
|
138 |
+
self.softno = input['soft_nose_mask'].to(self.device)
|
139 |
+
self.softmo = input['soft_mouth_mask'].to(self.device)
|
140 |
+
if self.opt.compactmask:
|
141 |
+
self.cmask = input['cmask'].to(self.device)
|
142 |
+
self.cmask1 = self.cmask*2-1#[0,1]->[-1,1]
|
143 |
+
self.cmaskel = input['cmaskel'].to(self.device)
|
144 |
+
self.cmask1el = self.cmaskel*2-1
|
145 |
+
self.cmasker = input['cmasker'].to(self.device)
|
146 |
+
self.cmask1er = self.cmasker*2-1
|
147 |
+
self.cmaskmo = input['cmaskmo'].to(self.device)
|
148 |
+
self.cmask1mo = self.cmaskmo*2-1
|
149 |
+
self.real_A_hair = input['hair_A'].to(self.device)
|
150 |
+
self.mask = input['mask'].to(self.device) # mask for non-eyes,nose,mouth
|
151 |
+
self.mask2 = input['mask2'].to(self.device) # mask for non-bg
|
152 |
+
self.real_A_bg = input['bg_A'].to(self.device)
|
153 |
+
|
154 |
+
def getonehot(self,outputs,classes):
|
155 |
+
[maxv,index] = torch.max(outputs,1)
|
156 |
+
y = torch.unsqueeze(index,1)
|
157 |
+
onehot = torch.FloatTensor(self.batch_size,classes).to(self.device)
|
158 |
+
onehot.zero_()
|
159 |
+
onehot.scatter_(1,y,1)
|
160 |
+
return onehot
|
161 |
+
|
162 |
+
def forward(self):
|
163 |
+
if not self.opt.use_local:
|
164 |
+
self.fake_B = self.netG(self.real_A)
|
165 |
+
else:
|
166 |
+
self.fake_B0 = self.netG(self.real_A)
|
167 |
+
# EYES, MOUTH
|
168 |
+
outputs1 = self.netCLm(self.real_A_mouth)
|
169 |
+
onehot1 = self.getonehot(outputs1,2)
|
170 |
+
|
171 |
+
if not self.opt.others_ae:
|
172 |
+
fake_B_eyel = self.netGLEyel(self.real_A_eyel)
|
173 |
+
fake_B_eyer = self.netGLEyer(self.real_A_eyer)
|
174 |
+
fake_B_mouth = self.netGLMouth(self.real_A_mouth)
|
175 |
+
else: # use AE that only contains compact region, need cmask!
|
176 |
+
self.fake_B_eyel1 = self.netGLEyel(self.real_A_eyel)
|
177 |
+
self.fake_B_eyer1 = self.netGLEyer(self.real_A_eyer)
|
178 |
+
self.fake_B_mouth1 = self.netGLMouth(self.real_A_mouth)
|
179 |
+
self.fake_B_eyel2,_ = self.netAEel(self.fake_B_eyel1)
|
180 |
+
self.fake_B_eyer2,_ = self.netAEer(self.fake_B_eyer1)
|
181 |
+
# USE 2 AEs
|
182 |
+
self.fake_B_mouth2 = torch.FloatTensor(self.batch_size,self.opt.output_nc,self.MOUTH_H,self.MOUTH_W).to(self.device)
|
183 |
+
for i in range(self.batch_size):
|
184 |
+
if onehot1[i][0] == 1:
|
185 |
+
self.fake_B_mouth2[i],_ = self.netAEmowhite(self.fake_B_mouth1[i].unsqueeze(0))
|
186 |
+
#print('AEmowhite')
|
187 |
+
elif onehot1[i][1] == 1:
|
188 |
+
self.fake_B_mouth2[i],_ = self.netAEmoblack(self.fake_B_mouth1[i].unsqueeze(0))
|
189 |
+
#print('AEmoblack')
|
190 |
+
fake_B_eyel = self.add_with_mask(self.fake_B_eyel2,self.fake_B_eyel1,self.cmaskel)
|
191 |
+
fake_B_eyer = self.add_with_mask(self.fake_B_eyer2,self.fake_B_eyer1,self.cmasker)
|
192 |
+
fake_B_mouth = self.add_with_mask(self.fake_B_mouth2,self.fake_B_mouth1,self.cmaskmo)
|
193 |
+
# NOSE
|
194 |
+
if not self.opt.nose_ae:
|
195 |
+
fake_B_nose = self.netGLNose(self.real_A_nose)
|
196 |
+
else: # use AE that only contains compact region, need cmask!
|
197 |
+
self.fake_B_nose1 = self.netGLNose(self.real_A_nose)
|
198 |
+
self.fake_B_nose2,_ = self.netAE(self.fake_B_nose1)
|
199 |
+
fake_B_nose = self.add_with_mask(self.fake_B_nose2,self.fake_B_nose1,self.cmask)
|
200 |
+
|
201 |
+
# HAIR, BG AND PARTCOMBINE
|
202 |
+
outputs2 = self.netCLh(self.real_A_hair)
|
203 |
+
onehot2 = self.getonehot(outputs2,3)
|
204 |
+
|
205 |
+
fake_B_hair = self.netGLHair(self.real_A_hair,onehot2)
|
206 |
+
fake_B_bg = self.netGLBG(self.real_A_bg)
|
207 |
+
self.fake_B_hair = self.masked(fake_B_hair,self.mask*self.mask2)
|
208 |
+
self.fake_B_bg = self.masked(fake_B_bg,self.inverse_mask(self.mask2))
|
209 |
+
if not self.opt.compactmask:
|
210 |
+
self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op)
|
211 |
+
else:
|
212 |
+
self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op,self.opt.region_enm,self.cmaskel,self.cmasker,self.cmask,self.cmaskmo)
|
213 |
+
|
214 |
+
self.fake_B = self.netGCombine(torch.cat([self.fake_B0,self.fake_B1],1))
|
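For reference, a minimal standalone sketch (not part of the repository) of what the `getonehot` helper above computes: classifier logits of shape `(batch, classes)` become a one-hot tensor marking the argmax class, which is then used to condition the hair/style generator.

```python
import torch

def getonehot(outputs, classes):
    """Convert (batch, classes) logits to a one-hot tensor of the argmax class."""
    _, index = torch.max(outputs, 1)              # argmax class per sample
    onehot = torch.zeros(outputs.size(0), classes)
    onehot.scatter_(1, index.unsqueeze(1), 1)     # put a 1 at each argmax position
    return onehot

logits = torch.tensor([[0.2, 1.5, -0.3],
                       [2.0, -1.0, 0.5]])
print(getonehot(logits, 3))
# tensor([[0., 1., 0.],
#         [1., 0., 0.]])
```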
APDrawingGAN2/options/__init__.py
ADDED
File without changes
|
APDrawingGAN2/options/base_options.py
ADDED
@@ -0,0 +1,192 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from util import util
|
4 |
+
import torch
|
5 |
+
import models
|
6 |
+
import data
|
7 |
+
|
8 |
+
|
9 |
+
class BaseOptions():
|
10 |
+
def __init__(self):
|
11 |
+
self.initialized = False
|
12 |
+
|
13 |
+
def initialize(self, parser):
|
14 |
+
parser.add_argument('--dataroot', type=str, default='', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
|
15 |
+
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
|
16 |
+
parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size')
|
17 |
+
parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
|
18 |
+
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
|
19 |
+
parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels')
|
20 |
+
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
|
21 |
+
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
|
22 |
+
parser.add_argument('--netD', type=str, default='basic', help='selects model to use for netD')
|
23 |
+
parser.add_argument('--netG', type=str, default='unet_256', help='selects model to use for netG')
|
24 |
+
parser.add_argument('--nnG', type=int, default=9, help='specify nblock for resnet_nblocks, ndown for unet for unet_ndown')
|
25 |
+
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
|
26 |
+
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
|
27 |
+
parser.add_argument('--gpu_ids_p', type=str, default='0', help='gpu ids for pretrained auxiliary models: e.g. 0 0,1,2, 0,2. use -1 for CPU')
|
28 |
+
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
|
29 |
+
parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
|
30 |
+
parser.add_argument('--model', type=str, default='apdrawing',
|
31 |
+
help='chooses which model to use. cycle_gan, pix2pix, test, autoencoder')
|
32 |
+
parser.add_argument('--use_local', action='store_true', help='use local part network')
|
33 |
+
parser.add_argument('--lm_dir', type=str, default='dataset/landmark/', help='path to facial landmarks')
|
34 |
+
parser.add_argument('--nose_ae', action='store_true', help='use nose autoencoder')
|
35 |
+
parser.add_argument('--others_ae', action='store_true', help='use autoencoder for eyes and mouth too')
|
36 |
+
parser.add_argument('--nose_ae_net', type=str, default='autoencoderfc', help='net for nose autoencoder [autoencoder | autoencoderfc]')
|
37 |
+
parser.add_argument('--comb_op', type=int, default=1, help='use min-pooling(1) or max-pooling(0) for overlapping regions')
|
38 |
+
parser.add_argument('--hair_local', action='store_true', help='add hair part')
|
39 |
+
parser.add_argument('--bg_local', action='store_true', help='use background mask to separate background')
|
40 |
+
parser.add_argument('--bg_dir', default='dataset/mask/bg/', type=str, help='choose bg_dir')
|
41 |
+
parser.add_argument('--region_enm', type=int, default=0, help='region type for eyes nose mouth: 0 for rectangle, 1 for compact mask in rectangle, 2 for mask no rectangle (1,2 must have compactmask, 0 use compactmask for AE)')
|
42 |
+
parser.add_argument('--soft_border', type=int, default=0, help='use mask with soft border')
|
43 |
+
parser.add_argument('--EYE_H', type=int, default=40, help='EYE_H')
|
44 |
+
parser.add_argument('--EYE_W', type=int, default=56, help='EYE_W')
|
45 |
+
parser.add_argument('--NOSE_H', type=int, default=48, help='NOSE_H')
|
46 |
+
parser.add_argument('--NOSE_W', type=int, default=48, help='NOSE_W')
|
47 |
+
parser.add_argument('--MOUTH_H', type=int, default=40, help='MOUTH_H')
|
48 |
+
parser.add_argument('--MOUTH_W', type=int, default=64, help='MOUTH_W')
|
49 |
+
parser.add_argument('--average_pos', action='store_true', help='use avg pos in partCombiner')
|
50 |
+
parser.add_argument('--combiner_type', type=str, default='combiner', help='choose combiner type')
|
51 |
+
parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
|
52 |
+
parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
|
53 |
+
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
54 |
+
parser.add_argument('--auxiliary_root', type=str, default='auxiliary', help='auxiliary model folder')
|
55 |
+
parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
|
56 |
+
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
|
57 |
+
parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
|
58 |
+
parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
|
59 |
+
parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
|
60 |
+
parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
|
61 |
+
parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
|
62 |
+
parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
|
63 |
+
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
|
64 |
+
parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
|
65 |
+
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
|
66 |
+
parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
|
67 |
+
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
|
68 |
+
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
|
69 |
+
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}')
|
70 |
+
# compact mask
|
71 |
+
parser.add_argument('--compactmask', action='store_true', help='use compact mask as input and apply to loss')# "when you calculate the (ae) loss, you should also restrict to nose pixels"
|
72 |
+
parser.add_argument('--cmask_dir', type=str, default='dataset/mask/', help='compact mask directory')
|
73 |
+
parser.add_argument('--ae_latentno', type=int, default=1024 ,help='latent space dim for pretrained NOSE AEwithfc')
|
74 |
+
parser.add_argument('--ae_latentmo', type=int, default=1024 ,help='latent space dim for pretrained MOUTH AEwithfc')
|
75 |
+
parser.add_argument('--ae_latenteye', type=int, default=1024 ,help='latent space dim for pretrained EYEL/EYER AEwithfc')
|
76 |
+
parser.add_argument('--ae_small', type=int, default=0 ,help='use latent dim smaller than default 1024 in 4 AEs')
|
77 |
+
# below for autoencoder
|
78 |
+
parser.add_argument('--ae_latent', type=int, default=1024 ,help='latent space dim for autoencoderfc')
|
79 |
+
parser.add_argument('--ae_multiple', type=float, default=2 ,help='filter number change in ae encoder')
|
80 |
+
parser.add_argument('--ae_h', type=int, default=96 ,help='ae input h')
|
81 |
+
parser.add_argument('--ae_w', type=int, default=96 ,help='ae input w')
|
82 |
+
parser.add_argument('--ae_region', type=str, default='nose' ,help='autoencoder for which region')
|
83 |
+
parser.add_argument('--no_ae', action='store_true', help='no ae')
|
84 |
+
self.initialized = True
|
85 |
+
return parser
|
86 |
+
|
87 |
+
def gather_options(self):
|
88 |
+
# initialize parser with basic options
|
89 |
+
if not self.initialized:
|
90 |
+
parser = argparse.ArgumentParser(
|
91 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
92 |
+
parser = self.initialize(parser)
|
93 |
+
|
94 |
+
# get the basic options
|
95 |
+
opt, _ = parser.parse_known_args()
|
96 |
+
|
97 |
+
# modify model-related parser options
|
98 |
+
model_name = opt.model
|
99 |
+
model_option_setter = models.get_option_setter(model_name)
|
100 |
+
parser = model_option_setter(parser, self.isTrain)
|
101 |
+
opt, _ = parser.parse_known_args() # parse again with the new defaults
|
102 |
+
|
103 |
+
# modify dataset-related parser options
|
104 |
+
dataset_name = opt.dataset_mode
|
105 |
+
dataset_option_setter = data.get_option_setter(dataset_name)
|
106 |
+
parser = dataset_option_setter(parser, self.isTrain)
|
107 |
+
|
108 |
+
self.parser = parser
|
109 |
+
|
110 |
+
return parser.parse_args()
|
111 |
+
|
112 |
+
def print_options(self, opt):
|
113 |
+
message = ''
|
114 |
+
message += '----------------- Options ---------------\n'
|
115 |
+
for k, v in sorted(vars(opt).items()):
|
116 |
+
comment = ''
|
117 |
+
default = self.parser.get_default(k)
|
118 |
+
if v != default:
|
119 |
+
comment = '\t[default: %s]' % str(default)
|
120 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
121 |
+
message += '----------------- End -------------------'
|
122 |
+
print(message)
|
123 |
+
|
124 |
+
# save to the disk
|
125 |
+
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
126 |
+
util.mkdirs(expr_dir)
|
127 |
+
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
|
128 |
+
with open(file_name, 'wt') as opt_file:
|
129 |
+
opt_file.write(message)
|
130 |
+
opt_file.write('\n')
|
131 |
+
|
132 |
+
def parse(self, print=True):
|
133 |
+
|
134 |
+
opt = self.gather_options()
|
135 |
+
if opt.use_local:
|
136 |
+
opt.loadSize = opt.fineSize
|
137 |
+
if opt.region_enm in [1,2]:
|
138 |
+
opt.compactmask = True
|
139 |
+
if opt.nose_ae or opt.others_ae:
|
140 |
+
opt.compactmask = True
|
141 |
+
if opt.ae_latentno < 1024 and opt.ae_latentmo < 1024 and opt.ae_latenteye < 1024:
|
142 |
+
opt.ae_small = 1
|
143 |
+
opt.isTrain = self.isTrain # train or test
|
144 |
+
|
145 |
+
# process opt.suffix
|
146 |
+
if opt.suffix:
|
147 |
+
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
|
148 |
+
opt.name = opt.name + suffix
|
149 |
+
|
150 |
+
if self.isTrain and opt.pretrain:
|
151 |
+
opt.nose_ae = False
|
152 |
+
opt.others_ae = False
|
153 |
+
opt.compactmask = False
|
154 |
+
opt.chamfer_loss = False
|
155 |
+
if not self.isTrain and opt.pretrain:
|
156 |
+
opt.nose_ae = False
|
157 |
+
opt.others_ae = False
|
158 |
+
opt.compactmask = False
|
159 |
+
if opt.no_ae:
|
160 |
+
opt.nose_ae = False
|
161 |
+
opt.others_ae = False
|
162 |
+
opt.compactmask = False
|
163 |
+
if self.isTrain and opt.no_dtremap:
|
164 |
+
opt.dt_nonlinear = ''
|
165 |
+
opt.lambda_chamfer = 0.1
|
166 |
+
opt.lambda_chamfer2 = 0.1
|
167 |
+
if self.isTrain and opt.no_dt:
|
168 |
+
opt.chamfer_loss = False
|
169 |
+
|
170 |
+
if print:
|
171 |
+
self.print_options(opt)
|
172 |
+
|
173 |
+
# set gpu ids
|
174 |
+
str_ids = opt.gpu_ids.split(',')
|
175 |
+
opt.gpu_ids = []
|
176 |
+
for str_id in str_ids:
|
177 |
+
id = int(str_id)
|
178 |
+
if id >= 0:
|
179 |
+
opt.gpu_ids.append(id)
|
180 |
+
if len(opt.gpu_ids) > 0:
|
181 |
+
torch.cuda.set_device(opt.gpu_ids[0])
|
182 |
+
|
183 |
+
# set gpu ids
|
184 |
+
str_ids = opt.gpu_ids_p.split(',')
|
185 |
+
opt.gpu_ids_p = []
|
186 |
+
for str_id in str_ids:
|
187 |
+
id = int(str_id)
|
188 |
+
if id >= 0:
|
189 |
+
opt.gpu_ids_p.append(id)
|
190 |
+
|
191 |
+
self.opt = opt
|
192 |
+
return self.opt
|
APDrawingGAN2/options/test_options.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
from .base_options import BaseOptions
|
2 |
+
|
3 |
+
|
4 |
+
class TestOptions(BaseOptions):
|
5 |
+
def initialize(self, parser):
|
6 |
+
parser = BaseOptions.initialize(self, parser)
|
7 |
+
parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
|
8 |
+
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
|
9 |
+
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
|
10 |
+
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
|
11 |
+
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
|
12 |
+
parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
|
13 |
+
parser.add_argument('--test_continuity_loss', action='store_true', help='get continuity value in test')
|
14 |
+
parser.add_argument('--netG_line', type=str, default='unet_512', help='selects model to use for netG_line')
|
15 |
+
parser.add_argument('--save2', action='store_true', help='only save real_A and fake_B')
|
16 |
+
parser.add_argument('--imagefolder', type=str, default='images', help='subfolder to save images')
|
17 |
+
parser.add_argument('--pretrain', action='store_true', help='pretrain stage, no dt loss, no ae')
|
18 |
+
|
19 |
+
parser.set_defaults(model='test')
|
20 |
+
# To avoid cropping, the loadSize should be the same as fineSize
|
21 |
+
parser.set_defaults(loadSize=parser.get_default('fineSize'))
|
22 |
+
self.isTrain = False
|
23 |
+
return parser
|
APDrawingGAN2/options/train_options.py
ADDED
@@ -0,0 +1,62 @@
1 |
+
from .base_options import BaseOptions
|
2 |
+
|
3 |
+
|
4 |
+
class TrainOptions(BaseOptions):
|
5 |
+
def initialize(self, parser):
|
6 |
+
parser = BaseOptions.initialize(self, parser)
|
7 |
+
parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
|
8 |
+
parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
|
9 |
+
parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
|
10 |
+
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
|
11 |
+
parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
|
12 |
+
parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
|
13 |
+
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
|
14 |
+
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
|
15 |
+
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
|
16 |
+
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
|
17 |
+
parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
|
18 |
+
parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
|
19 |
+
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
|
20 |
+
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
|
21 |
+
parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
|
22 |
+
parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
|
23 |
+
parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
|
24 |
+
parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
|
25 |
+
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
|
26 |
+
# ============================================loss=========================================================
|
27 |
+
# chamfer loss
|
28 |
+
parser.add_argument('--chamfer_loss', action='store_true', help='use chamfer loss')
|
29 |
+
parser.add_argument('--chamfer_2way', action='store_true', help='use chamfer loss 2 way')
|
30 |
+
parser.add_argument('--chamfer_only_line', action='store_true', help='use chamfer only on lines')
|
31 |
+
parser.add_argument('--lambda_chamfer', type=float, default=0.1, help='weight for chamfer loss')
|
32 |
+
parser.add_argument('--lambda_chamfer2', type=float, default=0.1, help='weight for chamfer loss2')
|
33 |
+
parser.add_argument('--dt_nonlinear', type=str, default='', help='nonlinear remap on dt [atan | sigmoid | tanh]')
|
34 |
+
parser.add_argument('--dt_xmax', type=float, default=10, help='first multiply dt to range [0,xmax], then use atan/sigmoid/tanh etc, to have more nonlinearity (not much nonlinearity in range [0,1])')
|
35 |
+
# line continuity loss
|
36 |
+
parser.add_argument('--continuity_loss', action='store_true', help='use line continuity loss')
|
37 |
+
parser.add_argument('--lambda_continuity', type=float, default=10.0, help='weight for continuity loss')
|
38 |
+
parser.add_argument('--emphasis_conti_face', action='store_true', help='constrain conti loss to pixels in original lines (avoid applying it to background etc)')
|
39 |
+
parser.add_argument('--facemask_dir', type=str, default='dataset/mask/face/', help='mask folder to constrain conti loss to pixels in original lines')
|
40 |
+
# =====================================auxiliary net structure===============================================
|
41 |
+
# dt & line net structure
|
42 |
+
parser.add_argument('--netG_dt', type=str, default='unet_512', help='selects model to use for netG_dt, for chamfer loss')
|
43 |
+
parser.add_argument('--netG_line', type=str, default='unet_512', help='selects model to use for netG_line, for chamfer loss')
|
44 |
+
# multiple discriminators
|
45 |
+
parser.add_argument('--discriminator_local', action='store_true', help='use six different local discriminators for 6 local regions')
|
46 |
+
parser.add_argument('--gan_loss_strategy', type=int, default=2, help='specify how to calculate gan loss for g, 1: average global and local discriminators; 2: not change global discriminator weight, 0.25 for local')
|
47 |
+
parser.add_argument('--addw_eye', type=float, default=1.0, help='additional weight for eye region')
|
48 |
+
parser.add_argument('--addw_nose', type=float, default=1.0, help='additional weight for nose region')
|
49 |
+
parser.add_argument('--addw_mouth', type=float, default=1.0, help='additional weight for mouth region')
|
50 |
+
parser.add_argument('--addw_hair', type=float, default=1.0, help='additional weight for hair region')
|
51 |
+
parser.add_argument('--addw_bg', type=float, default=1.0, help='additional weight for bg region')
|
52 |
+
# ==========================================ablation========================================================
|
53 |
+
parser.add_argument('--no_l1_loss', action='store_true', help='no l1 loss')
|
54 |
+
parser.add_argument('--no_G_local_loss', action='store_true', help='not using local transfer loss for local generator output')
|
55 |
+
parser.add_argument('--no_dtremap', action='store_true', help='no dt remap')
|
56 |
+
parser.add_argument('--no_dt', action='store_true', help='no dt')
|
57 |
+
|
58 |
+
parser.add_argument('--pretrain', action='store_true', help='pretrain stage, no dt loss, no ae')
|
59 |
+
|
60 |
+
|
61 |
+
self.isTrain = True
|
62 |
+
return parser
|
APDrawingGAN2/preprocess/combine_A_and_B.py
ADDED
@@ -0,0 +1,48 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
parser = argparse.ArgumentParser('create image pairs')
|
7 |
+
parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
|
8 |
+
parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
|
9 |
+
parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
|
10 |
+
parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000)
|
11 |
+
parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true')
|
12 |
+
args = parser.parse_args()
|
13 |
+
|
14 |
+
for arg in vars(args):
|
15 |
+
print('[%s] = ' % arg, getattr(args, arg))
|
16 |
+
|
17 |
+
splits = os.listdir(args.fold_A)
|
18 |
+
|
19 |
+
for sp in splits:
|
20 |
+
img_fold_A = os.path.join(args.fold_A, sp)
|
21 |
+
img_fold_B = os.path.join(args.fold_B, sp)
|
22 |
+
img_list = os.listdir(img_fold_A)
|
23 |
+
if args.use_AB:
|
24 |
+
img_list = [img_path for img_path in img_list if '_A.' in img_path]
|
25 |
+
|
26 |
+
num_imgs = min(args.num_imgs, len(img_list))
|
27 |
+
print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
|
28 |
+
img_fold_AB = os.path.join(args.fold_AB, sp)
|
29 |
+
if not os.path.isdir(img_fold_AB):
|
30 |
+
os.makedirs(img_fold_AB)
|
31 |
+
print('split = %s, number of images = %d' % (sp, num_imgs))
|
32 |
+
for n in range(num_imgs):
|
33 |
+
name_A = img_list[n]
|
34 |
+
path_A = os.path.join(img_fold_A, name_A)
|
35 |
+
if args.use_AB:
|
36 |
+
name_B = name_A.replace('_A.', '_B.')
|
37 |
+
else:
|
38 |
+
name_B = name_A
|
39 |
+
path_B = os.path.join(img_fold_B, name_B)
|
40 |
+
if os.path.isfile(path_A) and os.path.isfile(path_B):
|
41 |
+
name_AB = name_A
|
42 |
+
if args.use_AB:
|
43 |
+
name_AB = name_AB.replace('_A.', '.') # remove _A
|
44 |
+
path_AB = os.path.join(img_fold_AB, name_AB)
|
45 |
+
im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
|
46 |
+
im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
|
47 |
+
im_AB = np.concatenate([im_A, im_B], 1)
|
48 |
+
cv2.imwrite(path_AB, im_AB)
|
APDrawingGAN2/preprocess/example/img_1701.jpg
ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned.png
ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned.txt
ADDED
@@ -0,0 +1,5 @@
1 |
+
194 248
|
2 |
+
314 249
|
3 |
+
261 312
|
4 |
+
209 368
|
5 |
+
302 371
|
APDrawingGAN2/preprocess/example/img_1701_aligned_68lm.txt
ADDED
@@ -0,0 +1,68 @@
1 |
+
120 261
|
2 |
+
124 294
|
3 |
+
129 326
|
4 |
+
133 358
|
5 |
+
142 388
|
6 |
+
162 412
|
7 |
+
190 430
|
8 |
+
220 445
|
9 |
+
253 449
|
10 |
+
287 447
|
11 |
+
317 432
|
12 |
+
344 411
|
13 |
+
362 385
|
14 |
+
370 354
|
15 |
+
375 322
|
16 |
+
382 291
|
17 |
+
385 258
|
18 |
+
142 225
|
19 |
+
161 209
|
20 |
+
188 204
|
21 |
+
215 208
|
22 |
+
242 218
|
23 |
+
269 218
|
24 |
+
296 208
|
25 |
+
324 206
|
26 |
+
351 213
|
27 |
+
369 231
|
28 |
+
256 244
|
29 |
+
256 264
|
30 |
+
256 284
|
31 |
+
256 305
|
32 |
+
232 324
|
33 |
+
244 328
|
34 |
+
256 332
|
35 |
+
267 329
|
36 |
+
277 325
|
37 |
+
172 252
|
38 |
+
186 243
|
39 |
+
203 243
|
40 |
+
218 253
|
41 |
+
203 257
|
42 |
+
186 257
|
43 |
+
290 254
|
44 |
+
305 244
|
45 |
+
322 246
|
46 |
+
336 255
|
47 |
+
322 260
|
48 |
+
305 259
|
49 |
+
210 368
|
50 |
+
229 358
|
51 |
+
245 352
|
52 |
+
256 354
|
53 |
+
267 352
|
54 |
+
283 358
|
55 |
+
300 368
|
56 |
+
284 382
|
57 |
+
268 388
|
58 |
+
255 389
|
59 |
+
244 388
|
60 |
+
228 381
|
61 |
+
220 368
|
62 |
+
245 363
|
63 |
+
256 364
|
64 |
+
267 364
|
65 |
+
290 368
|
66 |
+
267 370
|
67 |
+
255 372
|
68 |
+
244 371
|
APDrawingGAN2/preprocess/example/img_1701_aligned_bgmask.png
ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_eyelmask.png
ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_eyermask.png
ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_facemask.png
ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_mouthmask.png
ADDED
APDrawingGAN2/preprocess/example/img_1701_aligned_nosemask.png
ADDED
APDrawingGAN2/preprocess/example/img_1701_facial5point.mat
ADDED
Binary file (230 Bytes).
APDrawingGAN2/preprocess/face_align_512.m
ADDED
@@ -0,0 +1,55 @@
1 |
+
function [trans_img,trans_facial5point]=face_align_512(impath,facial5point,savedir)
|
2 |
+
% align the faces by similarity transformation.
|
3 |
+
% using 5 facial landmarks: 2 eyes, nose, 2 mouth corners.
|
4 |
+
% impath: path to image
|
5 |
+
% facial5point: 5x2 size, 5 facial landmark positions, detected by MTCNN
|
6 |
+
% savedir: savedir for cropped image and transformed facial landmarks
|
7 |
+
|
8 |
+
%% alignment settings
|
9 |
+
imgSize = [512,512];
|
10 |
+
coord5point = [180,230;
|
11 |
+
300,230;
|
12 |
+
240,301;
|
13 |
+
186,365.6;
|
14 |
+
294,365.6];%480x480
|
15 |
+
coord5point = (coord5point-240)/560 * 512 + 256;
|
16 |
+
|
17 |
+
%% face alignment
|
18 |
+
|
19 |
+
% load and align, resize image to imgSize
|
20 |
+
img = imread(impath);
|
21 |
+
facial5point = double(facial5point);
|
22 |
+
transf = cp2tform(facial5point, coord5point, 'similarity');
|
23 |
+
trans_img = imtransform(img, transf, 'XData', [1 imgSize(2)],...
|
24 |
+
'YData', [1 imgSize(1)],...
|
25 |
+
'Size', imgSize,...
|
26 |
+
'FillValues', [255;255;255]);
|
27 |
+
trans_facial5point = round(tformfwd(transf,facial5point));
|
28 |
+
|
29 |
+
|
30 |
+
%% save results
|
31 |
+
if ~exist(savedir,'dir')
|
32 |
+
mkdir(savedir)
|
33 |
+
end
|
34 |
+
[~,name,~] = fileparts(impath);
|
35 |
+
% save trans_img
|
36 |
+
imwrite(trans_img, fullfile(savedir,[name,'_aligned.png']));
|
37 |
+
fprintf('write aligned image to %s\n',fullfile(savedir,[name,'_aligned.png']));
|
38 |
+
% save trans_facial5point
|
39 |
+
write_5pt(fullfile(savedir, [name, '_aligned.txt']), trans_facial5point);
|
40 |
+
fprintf('write transformed facial landmark to %s\n',fullfile(savedir,[name,'_aligned.txt']));
|
41 |
+
|
42 |
+
%% show results
|
43 |
+
imshow(trans_img); hold on;
|
44 |
+
plot(trans_facial5point(:,1),trans_facial5point(:,2),'b');
|
45 |
+
plot(trans_facial5point(:,1),trans_facial5point(:,2),'r+');
|
46 |
+
|
47 |
+
end
|
48 |
+
|
49 |
+
function [] = write_5pt(fn, trans_pt)
|
50 |
+
fid = fopen(fn, 'w');
|
51 |
+
for i = 1:5
|
52 |
+
fprintf(fid, '%d %d\n', trans_pt(i,1), trans_pt(i,2));%will be read as np.int32
|
53 |
+
end
|
54 |
+
fclose(fid);
|
55 |
+
end
|
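For readers without MATLAB, the following is a rough Python/OpenCV sketch of the same idea: fit a similarity transform from the 5 detected landmarks to the canonical 5-point template above, then warp the photo to 512x512. It is only an approximate substitute for the author's `cp2tform`/`imtransform` pipeline, so small numerical differences are expected.

```python
import cv2
import numpy as np

def face_align_512(img_path, facial5point, save_path):
    # canonical 5-point template, same numbers as in face_align_512.m
    coord5point = np.array([[180, 230], [300, 230], [240, 301],
                            [186, 365.6], [294, 365.6]], dtype=np.float32)
    coord5point = (coord5point - 240) / 560 * 512 + 256

    src = np.asarray(facial5point, dtype=np.float32)
    # similarity transform (rotation + uniform scale + translation) fitted to the 5 points
    M, _ = cv2.estimateAffinePartial2D(src, coord5point, method=cv2.LMEDS)

    img = cv2.imread(img_path)
    aligned = cv2.warpAffine(img, M, (512, 512),
                             borderValue=(255, 255, 255))   # fill background with white
    trans5 = cv2.transform(src[None, :, :], M)[0]            # transformed landmarks
    cv2.imwrite(save_path, aligned)
    return aligned, np.round(trans5).astype(int)
```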
APDrawingGAN2/preprocess/get_partmask.py
ADDED
@@ -0,0 +1,152 @@
1 |
+
import cv2
|
2 |
+
import os, glob, csv, shutil
|
3 |
+
import numpy as np
|
4 |
+
import dlib
|
5 |
+
import math
|
6 |
+
from shapely.geometry import Point
|
7 |
+
from shapely.geometry import Polygon
|
8 |
+
import sys
|
9 |
+
|
10 |
+
detector = dlib.get_frontal_face_detector()
|
11 |
+
predictor = dlib.shape_predictor('../checkpoints/shape_predictor_68_face_landmarks.dat')
|
12 |
+
|
13 |
+
def getfeats(featpath):
|
14 |
+
trans_points = np.empty([68,2],dtype=np.int64)
|
15 |
+
with open(featpath, 'r') as csvfile:
|
16 |
+
reader = csv.reader(csvfile, delimiter=' ')
|
17 |
+
for ind,row in enumerate(reader):
|
18 |
+
trans_points[ind,:] = row
|
19 |
+
return trans_points
|
20 |
+
|
21 |
+
def getinternal(lm1,lm2):
|
22 |
+
lminternal = []
|
23 |
+
if abs(lm1[1]-lm2[1]) > abs(lm1[0]-lm2[0]):
|
24 |
+
if lm1[1] > lm2[1]:
|
25 |
+
tmp = lm1
|
26 |
+
lm1 = lm2
|
27 |
+
lm2 = tmp
|
28 |
+
for y in range(lm1[1]+1,lm2[1]):
|
29 |
+
x = int(round(float(y-lm1[1])/(lm2[1]-lm1[1])*(lm2[0]-lm1[0])+lm1[0]))
|
30 |
+
lminternal.append((x,y))
|
31 |
+
else:
|
32 |
+
if lm1[0] > lm2[0]:
|
33 |
+
tmp = lm1
|
34 |
+
lm1 = lm2
|
35 |
+
lm2 = tmp
|
36 |
+
for x in range(lm1[0]+1,lm2[0]):
|
37 |
+
y = int(round(float(x-lm1[0])/(lm2[0]-lm1[0])*(lm2[1]-lm1[1])+lm1[1]))
|
38 |
+
lminternal.append((x,y))
|
39 |
+
return lminternal
|
40 |
+
|
41 |
+
def mulcross(p,x_1,x):#p-x_1,x-x_1
|
42 |
+
vp = [p[0]-x_1[0],p[1]-x_1[1]]
|
43 |
+
vq = [x[0]-x_1[0],x[1]-x_1[1]]
|
44 |
+
return vp[0]*vq[1]-vp[1]*vq[0]
|
45 |
+
|
46 |
+
def shape_to_np(shape, dtype="int"):
|
47 |
+
# initialize the list of (x, y)-coordinates
|
48 |
+
coords = np.zeros((shape.num_parts, 2), dtype=dtype)
|
49 |
+
# loop over all facial landmarks and convert them
|
50 |
+
# to a 2-tuple of (x, y)-coordinates
|
51 |
+
for i in range(0, shape.num_parts):
|
52 |
+
coords[i] = (shape.part(i).x, shape.part(i).y)
|
53 |
+
# return the list of (x, y)-coordinates
|
54 |
+
return coords
|
55 |
+
|
56 |
+
def get_68lm(imgfile,savepath):
|
57 |
+
image = cv2.imread(imgfile)
|
58 |
+
rgbImg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
59 |
+
rects = detector(rgbImg, 1)
|
60 |
+
for (i, rect) in enumerate(rects):
|
61 |
+
landmarks = predictor(rgbImg, rect)
|
62 |
+
landmarks = shape_to_np(landmarks)
|
63 |
+
f = open(savepath,'w')
|
64 |
+
for i in range(len(landmarks)):
|
65 |
+
lm = landmarks[i]
|
66 |
+
print(lm[0], lm[1], file=f)
|
67 |
+
f.close()
|
68 |
+
|
69 |
+
def get_partmask(imgfile,part,lmpath,savefile):
|
70 |
+
img = cv2.imread(imgfile)
|
71 |
+
mask = np.zeros(img.shape, np.uint8)
|
72 |
+
lms = getfeats(lmpath)
|
73 |
+
|
74 |
+
if os.path.exists(savefile):
|
75 |
+
return
|
76 |
+
|
77 |
+
if part == 'nose':
|
78 |
+
# 27,31....,35 -> up, left, right, lower5 -- eight points
|
79 |
+
up = [int(round(1.2*lms[27][0]-0.2*lms[33][0])),int(round(1.2*lms[27][1]-0.2*lms[33][1]))]
|
80 |
+
lower5 = [[0,0]]*5
|
81 |
+
for i in range(31,36):
|
82 |
+
lower5[i-31] = [int(round(1.1*lms[i][0]-0.1*lms[27][0])),int(round(1.1*lms[i][1]-0.1*lms[27][1]))]
|
83 |
+
ratio = 2.5
|
84 |
+
left = [int(round(ratio*lower5[0][0]-(ratio-1)*lower5[1][0])),int(round(ratio*lower5[0][1]-(ratio-1)*lower5[1][1]))]
|
85 |
+
right = [int(round(ratio*lower5[4][0]-(ratio-1)*lower5[3][0])),int(round(ratio*lower5[4][1]-(ratio-1)*lower5[3][1]))]
|
86 |
+
loop = [up,left,lower5[0],lower5[1],lower5[2],lower5[3],lower5[4],right]
|
87 |
+
elif part == 'eyel':
|
88 |
+
height = max(lms[41][1]-lms[37][1],lms[40][1]-lms[38][1])
|
89 |
+
width = lms[39][0]-lms[36][0]
|
90 |
+
ratio = 0.1
|
91 |
+
gap = int(math.ceil(width*ratio))
|
92 |
+
ratio2 = 0.6
|
93 |
+
gaph = int(math.ceil(height*ratio2))
|
94 |
+
ratio3 = 1.5
|
95 |
+
gaph2 = int(math.ceil(height*ratio3))
|
96 |
+
upper = [[lms[17][0]-2*gap,lms[17][1]],[lms[17][0]-2*gap,lms[17][1]-gaph],[lms[18][0],lms[18][1]-gaph],[lms[19][0],lms[19][1]-gaph],[lms[20][0],lms[20][1]-gaph],[lms[21][0]+gap*2,lms[21][1]-gaph]]
|
97 |
+
lower = [[lms[39][0]+gap,lms[40][1]+gaph2],[lms[40][0],lms[40][1]+gaph2],[lms[41][0],lms[41][1]+gaph2],[lms[36][0]-2*gap,lms[41][1]+gaph2]]
|
98 |
+
loop = upper + lower
|
99 |
+
loop.reverse()
|
100 |
+
elif part == 'eyer':
|
101 |
+
height = max(lms[47][1]-lms[43][1],lms[46][1]-lms[44][1])
|
102 |
+
width = lms[45][0]-lms[42][0]
|
103 |
+
ratio = 0.1
|
104 |
+
gap = int(math.ceil(width*ratio))
|
105 |
+
ratio2 = 0.6
|
106 |
+
gaph = int(math.ceil(height*ratio2))
|
107 |
+
ratio3 = 1.5
|
108 |
+
gaph2 = int(math.ceil(height*ratio3))
|
109 |
+
upper = [[lms[22][0]-2*gap,lms[22][1]],[lms[22][0]-2*gap,lms[22][1]-gaph],[lms[23][0],lms[23][1]-gaph],[lms[24][0],lms[24][1]-gaph],[lms[25][0],lms[25][1]-gaph],[lms[26][0]+gap*2,lms[26][1]-gaph]]
|
110 |
+
lower = [[lms[45][0]+2*gap,lms[46][1]+gaph2],[lms[46][0],lms[46][1]+gaph2],[lms[47][0],lms[47][1]+gaph2],[lms[42][0]-gap,lms[42][1]+gaph2]]
|
111 |
+
loop = upper + lower
|
112 |
+
loop.reverse()
|
113 |
+
elif part == 'mouth':
|
114 |
+
height = lms[62][1]-lms[51][1]
|
115 |
+
width = lms[54][0]-lms[48][0]
|
116 |
+
ratio = 1
|
117 |
+
ratio2 = 0.2#0.1
|
118 |
+
gaph = int(math.ceil(ratio*height))
|
119 |
+
gapw = int(math.ceil(ratio2*width))
|
120 |
+
left = [(lms[48][0]-gapw,lms[48][1])]
|
121 |
+
upper = [(lms[i][0], lms[i][1]-gaph) for i in range(48,55)]
|
122 |
+
right = [(lms[54][0]+gapw,lms[54][1])]
|
123 |
+
lower = [(lms[i][0], lms[i][1]+gaph) for i in list(range(54,60))+[48]]
|
124 |
+
loop = left + upper + right + lower
|
125 |
+
loop.reverse()
|
126 |
+
pl = Polygon(loop)
|
127 |
+
|
128 |
+
for i in range(mask.shape[0]):
|
129 |
+
for j in range(mask.shape[1]):
|
130 |
+
if part != 'mouth' and part != 'jaw':
|
131 |
+
p = [j,i]
|
132 |
+
flag = 1
|
133 |
+
for k in range(len(loop)):
|
134 |
+
if mulcross(p,loop[k],loop[(k+1)%len(loop)]) < 0:#y downside... >0 represents counter-clockwise, <0 clockwise
|
135 |
+
flag = 0
|
136 |
+
break
|
137 |
+
else:
|
138 |
+
p = Point(j,i)
|
139 |
+
flag = pl.contains(p)
|
140 |
+
if flag:
|
141 |
+
mask[i,j] = [255,255,255]
|
142 |
+
if not os.path.exists(os.path.dirname(savefile)):
|
143 |
+
os.mkdir(os.path.dirname(savefile))
|
144 |
+
cv2.imwrite(savefile,mask)
|
145 |
+
|
146 |
+
if __name__ == '__main__':
|
147 |
+
imgfile = 'example/img_1701_aligned.png'
|
148 |
+
lmfile = 'example/img_1701_aligned_68lm.txt'
|
149 |
+
get_68lm(imgfile,lmfile)
|
150 |
+
for part in ['eyel','eyer','nose','mouth']:
|
151 |
+
savepath = 'example/img_1701_aligned_'+part+'mask.png'
|
152 |
+
get_partmask(imgfile,part,lmfile,savepath)
|
APDrawingGAN2/preprocess/readme.md
ADDED
@@ -0,0 +1,71 @@
1 |
+
## Preprocessing steps
|
2 |
+
|
3 |
+
Both training and testing images need:
|
4 |
+
|
5 |
+
- align to 512x512
|
6 |
+
- facial landmarks
|
7 |
+
- masks for eyes, nose, mouth, background
|
8 |
+
|
9 |
+
Training images additionally need:
|
10 |
+
|
11 |
+
- mask for face region
|
12 |
+
|
13 |
+
|
14 |
+
### 1. Align, resize, crop images to 512x512, and get facial landmarks
|
15 |
+
|
16 |
+
All training and testing images used by our model are aligned using facial landmarks, and the landmarks after alignment are needed by our code.
|
17 |
+
|
18 |
+
- First, 5 facial landmarks need to be detected for each face photo (we detect them using [MTCNN](https://github.com/kpzhang93/MTCNN_face_detection_alignment) (MTCNNv1)).
|
19 |
+
|
20 |
+
- Then, we provide a MATLAB function in `face_align_512.m` to align, resize and crop face photos (and corresponding drawings) to 512x512. Call this function in MATLAB to align an image to 512x512.
|
21 |
+
For example, for `img_1701.jpg` in the `example` dir, the 5 detected facial landmarks are saved in `example/img_1701_facial5point.mat`. Call the following in MATLAB:
|
22 |
+
```matlab
|
23 |
+
load('example/img_1701_facial5point.mat');
|
24 |
+
[trans_img,trans_facial5point]=face_align_512('example/img_1701.jpg',facial5point,'example');
|
25 |
+
```
|
26 |
+
|
27 |
+
This will align the image and output the aligned image plus the transformed facial landmarks (in txt format) to the `example` folder.
|
28 |
+
See `face_align_512.m` for more instructions.
|
29 |
+
|
30 |
+
The saved transformed facial landmarks need to be copied to `dataset/landmark/` and must have the **same filename** as the aligned face photos (e.g. `dataset/data/test_single/31.png` should have the landmark file `dataset/landmark/31.txt`).
|
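A minimal sketch (assuming numpy is available) of reading one of these transformed landmark files, which hold five `x y` integer pairs in the order used by `face_align_512.m` (two eyes, nose, two mouth corners):

```python
import numpy as np

# e.g. dataset/landmark/31.txt, as written by face_align_512.m
landmarks = np.loadtxt('dataset/landmark/31.txt', dtype=np.int32)
assert landmarks.shape == (5, 2)   # 2 eyes, nose, 2 mouth corners
print(landmarks)
```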
31 |
+
|
32 |
+
### 2. Prepare background masks
|
33 |
+
|
34 |
+
In our work, the background mask is segmented by the method in
|
35 |
+
"Automatic Portrait Segmentation for Image Stylization"
|
36 |
+
Xiaoyong Shen, Aaron Hertzmann, Jiaya Jia, Sylvain Paris, Brian Price, Eli Shechtman, Ian Sachs. Computer Graphics Forum, 35(2)(Proc. Eurographics), 2016.
|
37 |
+
|
38 |
+
We use the code at http://xiaoyongshen.me/webpage_portrait/index.html to detect background masks for the aligned face photos.
|
39 |
+
An example background mask is shown in `example/img_1701_aligned_bgmask.png`.
|
40 |
+
|
41 |
+
The background masks need to be copied to `dataset/mask/bg/` and must have the **same filename** as the aligned face photos (e.g. `dataset/data/test_single/31.png` should have the background mask `dataset/mask/bg/31.png`).
|
42 |
+
|
43 |
+
### 3. Prepare eyes/nose/mouth masks
|
44 |
+
|
45 |
+
We use dlib to extract 68 landmarks for aligned face photos, and use these landmarks to get masks for local regions.
|
46 |
+
See the example in `get_partmask.py`: the eye, nose and mouth masks for `example/img_1701_aligned.png` are `example/img_1701_aligned_[part]mask.png`, where part is one of [eyel,eyer,nose,mouth].
|
47 |
+
|
48 |
+
The part masks need to be copied to `dataset/mask/[part]/` and must have the **same filename** as the aligned face photos.
|
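For a folder of aligned photos, a small driver sketch using the `get_68lm` and `get_partmask` functions defined in `get_partmask.py` (run it from the `preprocess` directory so the dlib shape predictor path in that file resolves; the directory names below are placeholders to adapt):

```python
import os, glob
from get_partmask import get_68lm, get_partmask  # functions from this repo

aligned_dir = '../dataset/data/test_single'   # aligned 512x512 photos (adapt)
lm68_dir = '../dataset/landmark68'            # hypothetical folder for 68-point landmark files
os.makedirs(lm68_dir, exist_ok=True)

for imgfile in glob.glob(os.path.join(aligned_dir, '*.png')):
    name = os.path.splitext(os.path.basename(imgfile))[0]
    lmfile = os.path.join(lm68_dir, name + '.txt')
    get_68lm(imgfile, lmfile)                 # detect and save 68 dlib landmarks
    for part in ['eyel', 'eyer', 'nose', 'mouth']:
        savefile = os.path.join('../dataset/mask', part, name + '.png')
        os.makedirs(os.path.dirname(savefile), exist_ok=True)
        get_partmask(imgfile, part, lmfile, savefile)  # write the binary part mask
```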
49 |
+
|
50 |
+
### 4. (For training) Prepare face masks
|
51 |
+
|
52 |
+
We use the face parsing net in https://github.com/cientgu/Mask_Guided_Portrait_Editing to detect the face region.
|
53 |
+
The face parsing net labels each face image into 11 classes: 0 is background, 10 is hair, and 1~9 are face regions.
|
54 |
+
An example face mask is shown in `example/img_1701_aligned_facemask.png`.
|
55 |
+
|
56 |
+
The face masks need to be copied to `dataset/mask/face/` and must have the **same filename** as the aligned face photos.
|
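The parsing net itself is external, but once its per-pixel label map is available, turning it into a binary face mask is straightforward. A minimal sketch, assuming the parsing result is stored as a single-channel image whose pixel values are the class indices 0-10 (the filenames here are placeholders):

```python
import cv2
import numpy as np

labels = cv2.imread('parsing_result_31.png', cv2.IMREAD_GRAYSCALE)            # hypothetical parsing output
facemask = np.where((labels >= 1) & (labels <= 9), 255, 0).astype(np.uint8)   # classes 1~9 = face region
cv2.imwrite('dataset/mask/face/31.png', facemask)
```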
57 |
+
|
58 |
+
### 5. (For training) Combine A and B
|
59 |
+
|
60 |
+
We provide a Python script to generate training data in the form of pairs of images {A,B}, i.e. pairs {face photo, drawing}. The script concatenates each pair of images horizontally into a single image, from which we learn to translate A to B:
|
61 |
+
|
62 |
+
Create a folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `test`, etc. In `/path/to/data/A/train`, put training face photos. In `/path/to/data/B/train`, put the corresponding artist drawings. Repeat the same for `test`.
|
63 |
+
|
64 |
+
Corresponding images in a pair {A,B} must both be aligned, of size 512x512, and have the same filename, e.g., `/path/to/data/A/train/1.png` is considered to correspond to `/path/to/data/B/train/1.png`.
|
65 |
+
|
66 |
+
Once the data is formatted this way, call:
|
67 |
+
```bash
|
68 |
+
python preprocess/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
|
69 |
+
```
|
70 |
+
|
71 |
+
This will combine each pair of images (A,B) into a single image file, ready for training.
|
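Before training or testing, it can save time to check that every aligned photo has the auxiliary files the data loaders expect. A small, unofficial sanity-check sketch (directory names follow the defaults in `base_options.py` and the steps above; adapt the photo folder):

```python
import os, glob

photo_dir = 'dataset/data/test_single'   # adapt to your folder of aligned photos
parts = ['eyel', 'eyer', 'nose', 'mouth']

for imgfile in glob.glob(os.path.join(photo_dir, '*.png')):
    name = os.path.splitext(os.path.basename(imgfile))[0]
    required = [os.path.join('dataset/landmark', name + '.txt'),   # 5-point landmarks (step 1)
                os.path.join('dataset/mask/bg', name + '.png')]    # background mask (step 2)
    required += [os.path.join('dataset/mask', p, name + '.png') for p in parts]  # part masks (step 3)
    missing = [f for f in required if not os.path.isfile(f)]
    if missing:
        print(name, 'is missing:', ', '.join(missing))
```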
APDrawingGAN2/readme.md
ADDED
@@ -0,0 +1,105 @@
1 |
+
|
2 |
+
# APDrawingGAN++
|
3 |
+
|
4 |
+
We provide PyTorch implementations for our TPAMI paper "Line Drawings for Face Portraits from Photos using Global and Local Structure based GANs".
|
5 |
+
It is a journal extension of our previous CVPR 2019 work [APDrawingGAN](https://github.com/yiranran/APDrawingGAN).
|
6 |
+
|
7 |
+
This project generates artistic portrait drawings from face photos using a GAN-based model.
|
8 |
+
You may find useful information in [preprocessing steps](preprocess/readme.md) and [training/testing tips](docs/tips.md).
|
9 |
+
|
10 |
+
[[Jittor implementation]](https://github.com/yiranran/APDrawingGAN2-Jittor)
|
11 |
+
|
12 |
+
## Our Proposed Framework
|
13 |
+
|
14 |
+
<img src = 'imgs/architecture-pami.jpg'>
|
15 |
+
|
16 |
+
## Sample Results
|
17 |
+
Up: input, Down: output
|
18 |
+
<p>
|
19 |
+
<img src='imgs/sample/140_large-img_1696_real_A.png' width="16%"/>
|
20 |
+
<img src='imgs/sample/140_large-img_1615_real_A.png' width="16%"/>
|
21 |
+
<img src='imgs/sample/140_large-img_1684_real_A.png' width="16%"/>
|
22 |
+
<img src='imgs/sample/140_large-img_1616_real_A.png' width="16%"/>
|
23 |
+
<img src='imgs/sample/140_large-img_1673_real_A.png' width="16%"/>
|
24 |
+
<img src='imgs/sample/140_large-img_1701_real_A.png' width="16%"/>
|
25 |
+
</p>
|
26 |
+
<p>
|
27 |
+
<img src='imgs/sample/140_large-img_1696_fake_B.png' width="16%"/>
|
28 |
+
<img src='imgs/sample/140_large-img_1615_fake_B.png' width="16%"/>
|
29 |
+
<img src='imgs/sample/140_large-img_1684_fake_B.png' width="16%"/>
|
30 |
+
<img src='imgs/sample/140_large-img_1616_fake_B.png' width="16%"/>
|
31 |
+
<img src='imgs/sample/140_large-img_1673_fake_B.png' width="16%"/>
|
32 |
+
<img src='imgs/sample/140_large-img_1701_fake_B.png' width="16%"/>
|
33 |
+
</p>
|
34 |
+
|
35 |
+
## Citation
|
36 |
+
If you use this code for your research, please cite our paper.
|
37 |
+
```
|
38 |
+
@article{YiXLLR20,
|
39 |
+
title = {Line Drawings for Face Portraits from Photos using Global and Local Structure based {GAN}s},
|
40 |
+
author = {Yi, Ran and Xia, Mengfei and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L},
|
41 |
+
journal = {{IEEE} Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
|
42 |
+
doi = {10.1109/TPAMI.2020.2987931},
|
43 |
+
year = {2020}
|
44 |
+
}
|
45 |
+
```
|
46 |
+
|
47 |
+
## Prerequisites
|
48 |
+
- Linux or macOS
|
49 |
+
- Python 2 or 3
|
50 |
+
- CPU or NVIDIA GPU + CUDA CuDNN
|
51 |
+
|
52 |
+
|
53 |
+
## Getting Started
|
54 |
+
### 1.Installation
|
55 |
+
```bash
|
56 |
+
pip install -r requirements.txt
|
57 |
+
```
|
58 |
+
|
59 |
+
### 2.Quick Start (Apply a Pre-trained Model)
|
60 |
+
- Download APDrawing dataset from [BaiduYun](https://pan.baidu.com/s/1cN5gEYJ2tnE9WboLA79Z5g)(extract code:0zuv) or [YandexDrive](https://yadi.sk/d/4vWhi8-ZQj_nRw), and extract to `dataset`.
|
61 |
+
|
62 |
+
- Download pre-trained models and auxiliary nets from [BaiduYun](https://pan.baidu.com/s/1nrtCHQmgcwbSGxWuAVzWhA)(extract code:imqp) or [YandexDrive](https://yadi.sk/d/DS4271lbEPhGVQ), and extract to `checkpoints`.
|
63 |
+
|
64 |
+
- Generate artistic portrait drawings for example photos in `dataset/test_single` using
|
65 |
+
``` bash
|
66 |
+
python test.py --dataroot dataset/test_single --name apdrawinggan++_author --model test --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
|
67 |
+
```
|
68 |
+
The test results will be saved to an HTML file here: `./results/apdrawinggan++_author/test_150/index-single.html`.
|
69 |
+
|
70 |
+
- If you want to test on your own data, please first align your pictures and prepare your data's facial landmarks and masks according to the tutorial in [preprocessing steps](preprocess/readme.md), then change the `--dataroot` flag above to your directory of aligned photos.
|
71 |
+
|
72 |
+
### 3.Train
|
73 |
+
- Run `python -m visdom.server`
|
74 |
+
- Train a model (with pre-training as initialization):
|
75 |
+
first copy the "pre2" models into the checkpoints dir of the current experiment, e.g. `checkpoints/apdrawinggan++_1`.
|
76 |
+
```bash
|
77 |
+
mkdir checkpoints/apdrawinggan++_1/
|
78 |
+
cp checkpoints/pre2/*.pt checkpoints/apdrawinggan++_1/
|
79 |
+
python train.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_1 --model apdrawingpp_style --use_resnet --netG resnet_9blocks --continue_train --continuity_loss --lambda_continuity 40.0 --gpu_ids 0 --gpu_ids_p 1 --display_env apdrawinggan++_1 --niter 200 --niter_decay 0 --lr 0.0001 --batch_size 1 --emphasis_conti_face --auxiliary_root auxiliaryeye2o
|
80 |
+
```
|
81 |
+
- To view training results and loss plots, click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/apdrawinggan++_1/web/index.html`
|
82 |
+
|
83 |
+
### 4.Test
|
84 |
+
- To test the model on test set:
|
85 |
+
```bash
|
86 |
+
python test.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_author --model apdrawingpp_style --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-apd70
|
87 |
+
```
|
88 |
+
The test results will be saved to an HTML file: `./results/apdrawinggan++_author/test_150/index-apd70.html`.
|
89 |
+
|
90 |
+
- To test the model on images without paired ground truth, proceed as in step 2 (Quick Start: Apply a Pre-trained Model) above.
|
91 |
+
|
92 |
+
You can find these scripts in the `script` directory.
|
93 |
+
|
94 |
+
|
95 |
+
## [Preprocessing Steps](preprocess/readme.md)
|
96 |
+
Preprocessing steps for your own data (either for testing or training).
|
97 |
+
|
98 |
+
|
99 |
+
## [Training/Test Tips](docs/tips.md)
|
100 |
+
Best practice for training and testing your models.
|
101 |
+
|
102 |
+
You can contact email [email protected] for any questions.
|
103 |
+
|
104 |
+
## Acknowledgments
|
105 |
+
Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
|
APDrawingGAN2/requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch==1.1.0
torchvision==0.4.0
dominate==2.4.0
visdom==0.1.8.9
scipy==1.1.0
numpy==1.16.4
Pillow==4.3.0
opencv-python==4.1.0.25
dlib==19.18.0
shapely==1.7.0
APDrawingGAN2/script/test.sh
ADDED
@@ -0,0 +1,2 @@
set -ex
python test.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_author --model apdrawingpp_style --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-apd70
APDrawingGAN2/script/test_single.sh
ADDED
@@ -0,0 +1,2 @@
set -ex
python test.py --dataroot dataset/test_single --name apdrawinggan++_author --model test --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
APDrawingGAN2/script/train.sh
ADDED
@@ -0,0 +1,3 @@
set -ex
python train.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_1 --model apdrawingpp_style --use_resnet --netG resnet_9blocks --continue_train --continuity_loss --lambda_continuity 40.0 --gpu_ids 0 --gpu_ids_p 1 --display_env apdrawinggan++_1 --niter 200 --niter_decay 0 --lr 0.0001 --batch_size 1 --emphasis_conti_face --auxiliary_root auxiliaryeye2o
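These three scripts simply wrap the test and train commands from the readme's Getting Started section, so once the dataset and checkpoints are in place they can be run directly from the APDrawingGAN2 directory, for example:
```bash
bash script/test_single.sh
```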
APDrawingGAN2/test.py
ADDED
@@ -0,0 +1,69 @@
import os
from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import save_images
from util import html


if __name__ == '__main__':
    opt = TestOptions().parse()
    opt.num_threads = 1    # test code only supports num_threads = 1
    opt.batch_size = 1     # test code only supports batch_size = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True     # no flip
    opt.display_id = -1    # no visdom display
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    model.setup(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch), reflesh=0, folder=opt.imagefolder)
    if opt.test_continuity_loss:
        file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity.txt')
        file_name1 = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity-r.txt')
        if os.path.exists(file_name):
            os.remove(file_name)
        if os.path.exists(file_name1):
            os.remove(file_name1)
    # test
    #model.eval()
    for i, data in enumerate(dataset):
        if i >= opt.how_many:  # test code only supports batch_size = 1, how_many means how many test images to run
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()  # in test the loadSize is set to the same as fineSize
        img_path = model.get_image_paths()
        #if i % 5 == 0:
        #    print('processing (%04d)-th image... %s' % (i, img_path))
        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)

    webpage.save()
    if opt.model == 'regressor':
        print(model.cnt)
        print(model.value/model.cnt)
        print(model.minval)
        print(model.avg/model.cnt)
        print(model.max)
        html = os.path.join(web_dir, 'cindex'+opt.imagefolder[6:]+'.html')
        f = open(html, 'w')
        print('<table border="1" style=\"text-align:center;\">', file=f, end='')
        print('<tr>', file=f, end='')
        print('<td>image name</td>', file=f, end='')
        print('<td>realA</td>', file=f, end='')
        print('<td>realB</td>', file=f, end='')
        print('<td>fakeB</td>', file=f, end='')
        print('</tr>', file=f, end='')
        for info in model.info:
            basen = os.path.basename(info[0])[:-4]
            print('<tr>', file=f, end='')
            print('<td>%s</td>' % basen, file=f, end='')
            print('<td><img src=\"%s/%s_real_A.png\" style=\"width:44px\"></td>' % (opt.imagefolder, basen), file=f, end='')
            print('<td>%.4f</td>' % info[1], file=f, end='')
            print('<td>%.4f</td>' % info[2], file=f, end='')
            print('</tr>', file=f, end='')
        print('</table>', file=f, end='')
        f.close()
APDrawingGAN2/train.py
ADDED
@@ -0,0 +1,67 @@
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
    start = time.time()
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    model.setup(opt)
    visualizer = Visualizer(opt)
    total_steps = 0
    model.save_networks2(opt.which_epoch)

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
                #print('display', total_steps)

            if total_steps % opt.print_freq == 0:  # print freq 100
                losses = model.get_current_losses()
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                #model.save_networks('latest')
                model.save_networks2('latest')

            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            #model.save_networks('latest')
            #model.save_networks(epoch)
            model.save_networks2('latest')
            model.save_networks2(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()

    print('Total Time Taken: %d sec' % (time.time() - start))
APDrawingGAN2/util/__init__.py
ADDED
File without changes
APDrawingGAN2/util/get_data.py
ADDED
@@ -0,0 +1,115 @@
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename


class GetData(object):
    """
    Download CycleGAN or Pix2Pix Data.

    Args:
        technique : str
            One of: 'cyclegan' or 'pix2pix'.
        verbose : bool
            If True, print additional information.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.
    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        self.url = url_dict.get(technique.lower())
        self._verbose = verbose

    def _print(self, text):
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        soup = BeautifulSoup(r.text, 'lxml')
        options = [h.text for h in soup.find_all('a', href=True)
                   if h.text.endswith(('.zip', 'tar.gz'))]
        return options

    def _present_options(self):
        r = requests.get(self.url)
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        if not isdir(save_path):
            os.makedirs(save_path)

        base = basename(dataset_url)
        temp_save_path = join(save_path, base)

        with open(temp_save_path, "wb") as f:
            r = requests.get(dataset_url)
            f.write(r.content)

        if base.endswith('.tar.gz'):
            obj = tarfile.open(temp_save_path)
        elif base.endswith('.zip'):
            obj = ZipFile(temp_save_path, 'r')
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))

        self._print("Unpacking Data...")
        obj.extractall(save_path)
        obj.close()
        os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """
        Download a dataset.

        Args:
            save_path : str
                A directory to save the data to.
            dataset : str, optional
                A specific dataset to download.
                Note: this must include the file extension.
                If None, options will be presented for you
                to choose from.

        Returns:
            save_path_full : str
                The absolute path to the downloaded data.
        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        save_path_full = join(save_path, selected_dataset.split('.')[0])

        if isdir(save_path_full):
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)
APDrawingGAN2/util/html.py
ADDED
@@ -0,0 +1,68 @@
import dominate
from dominate.tags import *
import os


class HTML:
    def __init__(self, web_dir, title, reflesh=0, folder='images'):
        self.title = title
        self.web_dir = web_dir
        #self.img_dir = os.path.join(self.web_dir, 'images')
        self.img_dir = os.path.join(self.web_dir, folder)
        self.folder = folder
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)
        # print(self.img_dir)

        self.doc = dominate.document(title=title)
        if reflesh > 0:
            with self.doc.head:
                meta(http_equiv="reflesh", content=str(reflesh))

    def get_image_dir(self):
        return self.img_dir

    def add_header(self, str):
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400):
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                #img(style="width:%dpx" % width, src=os.path.join('images', im))
                                img(style="width:%dpx" % width, src=os.path.join(self.folder, im))
                            br()
                            p(txt)

    def save(self):
        #html_file = '%s/index.html' % self.web_dir
        html_file = '%s/index%s.html' % (self.web_dir, self.folder[6:])
        f = open(html_file, 'wt')
        f.write(self.doc.render())
        f.close()


if __name__ == '__main__':
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims = []
    txts = []
    links = []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()
APDrawingGAN2/util/image_pool.py
ADDED
@@ -0,0 +1,32 @@
import random
import torch


class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)
        return return_images
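The pool buffers up to `pool_size` previously generated images and, for each incoming fake, returns a stored one half of the time (swapping the new image in), so the discriminator also sees the generator's older outputs. A minimal usage sketch (the pool size and tensor shape below are illustrative, not values defined by this repo's options):
```python
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)           # illustrative buffer size
fake_B = torch.randn(1, 3, 512, 512)     # stand-in for a freshly generated image batch
fake_for_D = pool.query(fake_B)          # mix of current and historical fakes for the discriminator
```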
APDrawingGAN2/util/util.py
ADDED
@@ -0,0 +1,60 @@
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os


# Converts a Tensor into an image array (numpy)
# |imtype|: the desired type of the converted numpy array
def tensor2im(input_image, imtype=np.uint8):
    if isinstance(input_image, torch.Tensor):
        image_tensor = input_image.data
    else:
        return input_image
    image_numpy = image_tensor[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)


def diagnose_network(net, name='network'):
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)


def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def print_numpy(x, val=True, shp=False):
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)
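`tensor2im` assumes the network output lives in [-1, 1] and converts only the first image of the batch to an HxWx3 uint8 array; a small sketch of the round trip (the random tensor is a stand-in for a real generator output):
```python
import torch
from util import util

fake_B = torch.rand(1, 1, 512, 512) * 2 - 1   # stand-in for a generator output in [-1, 1]
im = util.tensor2im(fake_B)                   # HxWx3 uint8 array in [0, 255]
util.save_image(im, 'fake_B.png')
```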
APDrawingGAN2/util/visualizer.py
ADDED
@@ -0,0 +1,171 @@
import numpy as np
import os
import ntpath
import time
from . import util
from . import html
from scipy.misc import imresize


# save image to the disk
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)  # tensor to numpy array [-1,1]->[0,1]->[0,255]
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        h, w, _ = im.shape
        if aspect_ratio > 1.0:
            im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
        if aspect_ratio < 1.0:
            im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
        util.save_image(im, save_path)

        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)


class Visualizer():
    def __init__(self, opt):
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.opt = opt
        self.saved = False
        if self.display_id > 0:
            import visdom
            self.ncols = opt.display_ncols
            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True)

        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        self.saved = False

    def throw_visdom_connection_error(self):
        print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n')
        exit(1)

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, save_result):
        if self.display_id > 0:  # show images in the browser
            ncols = self.ncols
            if ncols > 0:
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                # pane col = image row
                try:
                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                    padding=2, opts=dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html, win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except ConnectionError:
                    self.throw_visdom_connection_error()

            else:
                idx = 1
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                                   win=self.display_id + idx)
                    idx += 1

        if self.use_html and (save_result or not self.saved):  # save images to a html file
            self.saved = True
            for label, image in visuals.items():
                image_numpy = util.tensor2im(image)
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)
            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []

                for label, image_numpy in visuals.items():
                    # convert each entry's own tensor (the original referenced the loop-carried `image` from above)
                    image_numpy = util.tensor2im(image_numpy)
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    def save_current_results1(self, visuals, epoch, epoch_iter):
        if not os.path.exists(self.img_dir+'/detailed'):
            os.mkdir(self.img_dir+'/detailed')
        for label, image in visuals.items():
            image_numpy = util.tensor2im(image)
            img_path = os.path.join(self.img_dir, 'detailed', 'epoch%.3d_%.3d_%s.png' % (epoch, epoch_iter, label))
            util.save_image(image_numpy, img_path)

    # losses: dictionary of error labels and values
    def plot_current_losses(self, epoch, counter_ratio, opt, losses):
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
                Y=np.array(self.plot_data['Y']),
                opts={
                    'title': self.name + ' loss over time',
                    'legend': self.plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id)
        except ConnectionError:
            self.throw_visdom_connection_error()

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, i, losses, t, t_data):
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
        for k, v in losses.items():
            message += '%s: %.6f ' % (k, v)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)
README.md
CHANGED
@@ -1,4 +1,5 @@
 ---
+python_version: 3.7
 title: Apdrawing
 emoji: 💻
 colorFrom: indigo
app.py
ADDED
@@ -0,0 +1,210 @@
#!/usr/bin/env python

from __future__ import annotations
import argparse
import functools
import os
import pathlib
import sys
from typing import Callable
import uuid

sys.path.insert(0, 'APDrawingGAN2')

import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image

from io import BytesIO
import shutil

from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model

from util import html

import ntpath
from util import util


ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'
TITLE = 'yiranran/APDrawingGAN2'
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.

"""
ARTICLE = """

"""


MODEL_REPO = 'hylee/apdrawing_model'


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args()


def load_checkpoint():
    dir = 'checkpoint'
    checkpoint_path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                                      'checkpoints.zip',
                                                      force_filename='checkpoints.zip')
    print(checkpoint_path)
    shutil.unpack_archive(checkpoint_path, extract_dir=dir)

    print(os.listdir(dir+'/checkpoints'))

    return dir+'/checkpoints'


# save image to the disk
def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    imgs = []

    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)  # tensor to numpy array [-1,1]->[0,1]->[0,255]
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        h, w, _ = im.shape
        # PIL.Image.resize expects a (width, height) tuple and operates on the converted image itself
        if aspect_ratio > 1.0:
            im = np.array(PIL.Image.fromarray(im).resize((int(w * aspect_ratio), h), PIL.Image.BICUBIC))
        if aspect_ratio < 1.0:
            im = np.array(PIL.Image.fromarray(im).resize((w, int(h / aspect_ratio)), PIL.Image.BICUBIC))
        util.save_image(im, save_path)
        imgs.append(save_path)

    return imgs


SAFEHASH = [x for x in "0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ"]
def compress_UUID():
    '''
    Following http://www.ietf.org/rfc/rfc1738.txt, generate a short string from a UUID
    by re-encoding it over a larger character set.
    Character set: [0-9a-zA-Z\-_], 64 characters in total.
    Length: (32-2)/3*2 = 20.
    Note: collisions are practically impossible (2^120 possible codes).
    :return: String
    '''
    row = str(uuid.uuid4()).replace('-', '')
    safe_code = ''
    for i in range(10):
        enbin = "%012d" % int(bin(int(row[i * 3] + row[i * 3 + 1] + row[i * 3 + 2], 16))[2:], 10)
        safe_code += (SAFEHASH[int(enbin[0:6], 2)] + SAFEHASH[int(enbin[6:12], 2)])
    safe_code = safe_code.replace('-', '')
    return safe_code


def run(
    image,
    model,
    opt,
) -> tuple[PIL.Image.Image]:

    dataroot = 'images/'+compress_UUID()
    opt.dataroot = os.path.join(dataroot, 'src/')
    os.makedirs(opt.dataroot, exist_ok=True)
    opt.results_dir = os.path.join(dataroot, 'results/')
    os.makedirs(opt.results_dir, exist_ok=True)

    shutil.copy(image.name, opt.dataroot)

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    imgs = [image.name]
    # test
    # model.eval()
    for i, data in enumerate(dataset):
        if i >= opt.how_many:  # test code only supports batch_size = 1, how_many means how many test images to run
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()  # in test the loadSize is set to the same as fineSize
        img_path = model.get_image_paths()
        # if i % 5 == 0:
        #     print('processing (%04d)-th image... %s' % (i, img_path))
        imgs = save_images2(opt.results_dir, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)

    print(imgs)
    return PIL.Image.open(imgs[0])


def main():
    gr.close_all()

    args = parse_args()

    checkpoint_dir = load_checkpoint()

    opt = TestOptions().parse()
    opt.num_threads = 1   # test code only supports num_threads = 1
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True    # no flip
    opt.display_id = -1   # no visdom display

    '''
    python test.py --dataroot dataset/test_single --name apdrawinggan++_author --model test --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
    '''
    opt.dataroot = 'dataset/test_single'
    opt.name = 'apdrawinggan++_author'
    opt.model = 'test'
    opt.use_resnet = True
    opt.netG = 'resnet_9blocks'
    opt.which_epoch = 150
    opt.how_many = 1000
    opt.gpu_ids = -1
    opt.gpu_ids_p = -1
    opt.imagefolder = 'images-single'

    opt.checkpoints_dir = checkpoint_dir

    model = create_model(opt)
    model.setup(opt)

    func = functools.partial(run, model=model, opt=opt)
    func = functools.update_wrapper(func, run)

    gr.Interface(
        func,
        [
            gr.inputs.Image(type='file', label='Input Image'),
        ],
        [
            gr.outputs.Image(
                type='pil',
                label='Result'),
        ],
        #examples=examples,
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()
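Assuming the `hylee/apdrawing_model` checkpoint repo is reachable from your machine, the demo can also be launched locally with the flags defined in `parse_args`, for example:
```bash
pip install -r requirements.txt
python app.py --port 7860 --share
```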
packages.txt
ADDED
@@ -0,0 +1,2 @@
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch>=0.4.0
torchvision>=0.2.1
dominate>=2.3.1
visdom>=0.1.8.3
scipy>=1.1.0
numpy>=1.14.1
Pillow>=5.0.0
opencv-python>=3.4.2