Spaces:
Build error
Build error
PKUWilliamYang
commited on
Commit
·
4e3dd77
1
Parent(s):
3b98894
Upload 7 files
Browse files- scripts/align_all_parallel.py +215 -0
- scripts/calc_id_loss_parallel.py +119 -0
- scripts/calc_losses_on_images.py +84 -0
- scripts/generate_sketch_data.py +62 -0
- scripts/inference.py +136 -0
- scripts/style_mixing.py +101 -0
- scripts/train.py +32 -0
scripts/align_all_parallel.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
|
3 |
+
author: lzhbrian (https://lzhbrian.me)
|
4 |
+
date: 2020.1.5
|
5 |
+
note: code is heavily borrowed from
|
6 |
+
https://github.com/NVlabs/ffhq-dataset
|
7 |
+
http://dlib.net/face_landmark_detection.py.html
|
8 |
+
|
9 |
+
requirements:
|
10 |
+
apt install cmake
|
11 |
+
conda install Pillow numpy scipy
|
12 |
+
pip install dlib
|
13 |
+
# download face landmark model from:
|
14 |
+
# http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
|
15 |
+
"""
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
import time
|
18 |
+
import numpy as np
|
19 |
+
import PIL
|
20 |
+
import PIL.Image
|
21 |
+
import os
|
22 |
+
import scipy
|
23 |
+
import scipy.ndimage
|
24 |
+
import dlib
|
25 |
+
import multiprocessing as mp
|
26 |
+
import math
|
27 |
+
|
28 |
+
from configs.paths_config import model_paths
|
29 |
+
SHAPE_PREDICTOR_PATH = model_paths["shape_predictor"]
|
30 |
+
|
31 |
+
|
32 |
+
def get_landmark(filepath, predictor):
|
33 |
+
"""get landmark with dlib
|
34 |
+
:return: np.array shape=(68, 2)
|
35 |
+
"""
|
36 |
+
detector = dlib.get_frontal_face_detector()
|
37 |
+
if type(filepath) == str:
|
38 |
+
img = dlib.load_rgb_image(filepath)
|
39 |
+
else:
|
40 |
+
img = filepath
|
41 |
+
dets = detector(img, 1)
|
42 |
+
|
43 |
+
if len(dets) == 0:
|
44 |
+
print('Error: no face detected! If you are sure there are faces in your input, you may rerun the code or change the image several times until the face is detected. Sometimes the detector is unstable.')
|
45 |
+
return None
|
46 |
+
|
47 |
+
shape = None
|
48 |
+
for k, d in enumerate(dets):
|
49 |
+
shape = predictor(img, d)
|
50 |
+
|
51 |
+
t = list(shape.parts())
|
52 |
+
a = []
|
53 |
+
for tt in t:
|
54 |
+
a.append([tt.x, tt.y])
|
55 |
+
lm = np.array(a)
|
56 |
+
return lm
|
57 |
+
|
58 |
+
|
59 |
+
def align_face(filepath, predictor):
|
60 |
+
"""
|
61 |
+
:param filepath: str
|
62 |
+
:return: PIL Image
|
63 |
+
"""
|
64 |
+
|
65 |
+
lm = get_landmark(filepath, predictor)
|
66 |
+
if lm is None:
|
67 |
+
return None
|
68 |
+
|
69 |
+
lm_chin = lm[0: 17] # left-right
|
70 |
+
lm_eyebrow_left = lm[17: 22] # left-right
|
71 |
+
lm_eyebrow_right = lm[22: 27] # left-right
|
72 |
+
lm_nose = lm[27: 31] # top-down
|
73 |
+
lm_nostrils = lm[31: 36] # top-down
|
74 |
+
lm_eye_left = lm[36: 42] # left-clockwise
|
75 |
+
lm_eye_right = lm[42: 48] # left-clockwise
|
76 |
+
lm_mouth_outer = lm[48: 60] # left-clockwise
|
77 |
+
lm_mouth_inner = lm[60: 68] # left-clockwise
|
78 |
+
|
79 |
+
# Calculate auxiliary vectors.
|
80 |
+
eye_left = np.mean(lm_eye_left, axis=0)
|
81 |
+
eye_right = np.mean(lm_eye_right, axis=0)
|
82 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
83 |
+
eye_to_eye = eye_right - eye_left
|
84 |
+
mouth_left = lm_mouth_outer[0]
|
85 |
+
mouth_right = lm_mouth_outer[6]
|
86 |
+
mouth_avg = (mouth_left + mouth_right) * 0.5
|
87 |
+
eye_to_mouth = mouth_avg - eye_avg
|
88 |
+
|
89 |
+
# Choose oriented crop rectangle.
|
90 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
91 |
+
x /= np.hypot(*x)
|
92 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
|
93 |
+
y = np.flipud(x) * [-1, 1]
|
94 |
+
c = eye_avg + eye_to_mouth * 0.1
|
95 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
96 |
+
qsize = np.hypot(*x) * 2
|
97 |
+
|
98 |
+
# read image
|
99 |
+
if type(filepath) == str:
|
100 |
+
img = PIL.Image.open(filepath)
|
101 |
+
else:
|
102 |
+
img = PIL.Image.fromarray(filepath)
|
103 |
+
|
104 |
+
output_size = 256
|
105 |
+
transform_size = 256
|
106 |
+
enable_padding = True
|
107 |
+
|
108 |
+
# Shrink.
|
109 |
+
shrink = int(np.floor(qsize / output_size * 0.5))
|
110 |
+
if shrink > 1:
|
111 |
+
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
|
112 |
+
img = img.resize(rsize, PIL.Image.ANTIALIAS)
|
113 |
+
quad /= shrink
|
114 |
+
qsize /= shrink
|
115 |
+
|
116 |
+
# Crop.
|
117 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
118 |
+
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
119 |
+
int(np.ceil(max(quad[:, 1]))))
|
120 |
+
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
|
121 |
+
min(crop[3] + border, img.size[1]))
|
122 |
+
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
|
123 |
+
img = img.crop(crop)
|
124 |
+
quad -= crop[0:2]
|
125 |
+
|
126 |
+
# Pad.
|
127 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
128 |
+
int(np.ceil(max(quad[:, 1]))))
|
129 |
+
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
|
130 |
+
max(pad[3] - img.size[1] + border, 0))
|
131 |
+
if enable_padding and max(pad) > border - 4:
|
132 |
+
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
|
133 |
+
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
|
134 |
+
h, w, _ = img.shape
|
135 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
136 |
+
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
|
137 |
+
1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
|
138 |
+
blur = qsize * 0.02
|
139 |
+
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
140 |
+
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
|
141 |
+
img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
|
142 |
+
quad += pad[:2]
|
143 |
+
|
144 |
+
# Transform.
|
145 |
+
img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
|
146 |
+
if output_size < transform_size:
|
147 |
+
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
|
148 |
+
|
149 |
+
# Save aligned image.
|
150 |
+
return img
|
151 |
+
|
152 |
+
|
153 |
+
def chunks(lst, n):
|
154 |
+
"""Yield successive n-sized chunks from lst."""
|
155 |
+
for i in range(0, len(lst), n):
|
156 |
+
yield lst[i:i + n]
|
157 |
+
|
158 |
+
|
159 |
+
def extract_on_paths(file_paths):
|
160 |
+
predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
|
161 |
+
pid = mp.current_process().name
|
162 |
+
print('\t{} is starting to extract on #{} images'.format(pid, len(file_paths)))
|
163 |
+
tot_count = len(file_paths)
|
164 |
+
count = 0
|
165 |
+
for file_path, res_path in file_paths:
|
166 |
+
count += 1
|
167 |
+
if count % 100 == 0:
|
168 |
+
print('{} done with {}/{}'.format(pid, count, tot_count))
|
169 |
+
try:
|
170 |
+
res = align_face(file_path, predictor)
|
171 |
+
res = res.convert('RGB')
|
172 |
+
os.makedirs(os.path.dirname(res_path), exist_ok=True)
|
173 |
+
res.save(res_path)
|
174 |
+
except Exception:
|
175 |
+
continue
|
176 |
+
print('\tDone!')
|
177 |
+
|
178 |
+
|
179 |
+
def parse_args():
|
180 |
+
parser = ArgumentParser(add_help=False)
|
181 |
+
parser.add_argument('--num_threads', type=int, default=1)
|
182 |
+
parser.add_argument('--root_path', type=str, default='')
|
183 |
+
args = parser.parse_args()
|
184 |
+
return args
|
185 |
+
|
186 |
+
|
187 |
+
def run(args):
|
188 |
+
root_path = args.root_path
|
189 |
+
out_crops_path = root_path + '_crops'
|
190 |
+
if not os.path.exists(out_crops_path):
|
191 |
+
os.makedirs(out_crops_path, exist_ok=True)
|
192 |
+
|
193 |
+
file_paths = []
|
194 |
+
for root, dirs, files in os.walk(root_path):
|
195 |
+
for file in files:
|
196 |
+
file_path = os.path.join(root, file)
|
197 |
+
fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path))
|
198 |
+
res_path = '{}.jpg'.format(os.path.splitext(fname)[0])
|
199 |
+
if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path):
|
200 |
+
continue
|
201 |
+
file_paths.append((file_path, res_path))
|
202 |
+
|
203 |
+
file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads))))
|
204 |
+
print(len(file_chunks))
|
205 |
+
pool = mp.Pool(args.num_threads)
|
206 |
+
print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
|
207 |
+
tic = time.time()
|
208 |
+
pool.map(extract_on_paths, file_chunks)
|
209 |
+
toc = time.time()
|
210 |
+
print('Mischief managed in {}s'.format(toc - tic))
|
211 |
+
|
212 |
+
|
213 |
+
if __name__ == '__main__':
|
214 |
+
args = parse_args()
|
215 |
+
run(args)
|
scripts/calc_id_loss_parallel.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import ArgumentParser
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import sys
|
7 |
+
from PIL import Image
|
8 |
+
import multiprocessing as mp
|
9 |
+
import math
|
10 |
+
import torch
|
11 |
+
import torchvision.transforms as trans
|
12 |
+
|
13 |
+
sys.path.append(".")
|
14 |
+
sys.path.append("..")
|
15 |
+
|
16 |
+
from models.mtcnn.mtcnn import MTCNN
|
17 |
+
from models.encoders.model_irse import IR_101
|
18 |
+
from configs.paths_config import model_paths
|
19 |
+
CIRCULAR_FACE_PATH = model_paths['circular_face']
|
20 |
+
|
21 |
+
|
22 |
+
def chunks(lst, n):
|
23 |
+
"""Yield successive n-sized chunks from lst."""
|
24 |
+
for i in range(0, len(lst), n):
|
25 |
+
yield lst[i:i + n]
|
26 |
+
|
27 |
+
|
28 |
+
def extract_on_paths(file_paths):
|
29 |
+
facenet = IR_101(input_size=112)
|
30 |
+
facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH))
|
31 |
+
facenet.cuda()
|
32 |
+
facenet.eval()
|
33 |
+
mtcnn = MTCNN()
|
34 |
+
id_transform = trans.Compose([
|
35 |
+
trans.ToTensor(),
|
36 |
+
trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
37 |
+
])
|
38 |
+
|
39 |
+
pid = mp.current_process().name
|
40 |
+
print('\t{} is starting to extract on {} images'.format(pid, len(file_paths)))
|
41 |
+
tot_count = len(file_paths)
|
42 |
+
count = 0
|
43 |
+
|
44 |
+
scores_dict = {}
|
45 |
+
for res_path, gt_path in file_paths:
|
46 |
+
count += 1
|
47 |
+
if count % 100 == 0:
|
48 |
+
print('{} done with {}/{}'.format(pid, count, tot_count))
|
49 |
+
if True:
|
50 |
+
input_im = Image.open(res_path)
|
51 |
+
input_im, _ = mtcnn.align(input_im)
|
52 |
+
if input_im is None:
|
53 |
+
print('{} skipping {}'.format(pid, res_path))
|
54 |
+
continue
|
55 |
+
|
56 |
+
input_id = facenet(id_transform(input_im).unsqueeze(0).cuda())[0]
|
57 |
+
|
58 |
+
result_im = Image.open(gt_path)
|
59 |
+
result_im, _ = mtcnn.align(result_im)
|
60 |
+
if result_im is None:
|
61 |
+
print('{} skipping {}'.format(pid, gt_path))
|
62 |
+
continue
|
63 |
+
|
64 |
+
result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0]
|
65 |
+
score = float(input_id.dot(result_id))
|
66 |
+
scores_dict[os.path.basename(gt_path)] = score
|
67 |
+
|
68 |
+
return scores_dict
|
69 |
+
|
70 |
+
|
71 |
+
def parse_args():
|
72 |
+
parser = ArgumentParser(add_help=False)
|
73 |
+
parser.add_argument('--num_threads', type=int, default=4)
|
74 |
+
parser.add_argument('--data_path', type=str, default='results')
|
75 |
+
parser.add_argument('--gt_path', type=str, default='gt_images')
|
76 |
+
args = parser.parse_args()
|
77 |
+
return args
|
78 |
+
|
79 |
+
|
80 |
+
def run(args):
|
81 |
+
file_paths = []
|
82 |
+
for f in os.listdir(args.data_path):
|
83 |
+
image_path = os.path.join(args.data_path, f)
|
84 |
+
gt_path = os.path.join(args.gt_path, f)
|
85 |
+
if f.endswith(".jpg") or f.endswith('.png'):
|
86 |
+
file_paths.append([image_path, gt_path.replace('.png','.jpg')])
|
87 |
+
|
88 |
+
file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads))))
|
89 |
+
pool = mp.Pool(args.num_threads)
|
90 |
+
print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
|
91 |
+
|
92 |
+
tic = time.time()
|
93 |
+
results = pool.map(extract_on_paths, file_chunks)
|
94 |
+
scores_dict = {}
|
95 |
+
for d in results:
|
96 |
+
scores_dict.update(d)
|
97 |
+
|
98 |
+
all_scores = list(scores_dict.values())
|
99 |
+
mean = np.mean(all_scores)
|
100 |
+
std = np.std(all_scores)
|
101 |
+
result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std)
|
102 |
+
print(result_str)
|
103 |
+
|
104 |
+
out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
|
105 |
+
if not os.path.exists(out_path):
|
106 |
+
os.makedirs(out_path)
|
107 |
+
|
108 |
+
with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f:
|
109 |
+
f.write(result_str)
|
110 |
+
with open(os.path.join(out_path, 'scores_id.json'), 'w') as f:
|
111 |
+
json.dump(scores_dict, f)
|
112 |
+
|
113 |
+
toc = time.time()
|
114 |
+
print('Mischief managed in {}s'.format(toc - tic))
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == '__main__':
|
118 |
+
args = parse_args()
|
119 |
+
run(args)
|
scripts/calc_losses_on_images.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import ArgumentParser
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import sys
|
5 |
+
from tqdm import tqdm
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
|
11 |
+
sys.path.append(".")
|
12 |
+
sys.path.append("..")
|
13 |
+
|
14 |
+
from criteria.lpips.lpips import LPIPS
|
15 |
+
from datasets.gt_res_dataset import GTResDataset
|
16 |
+
|
17 |
+
|
18 |
+
def parse_args():
|
19 |
+
parser = ArgumentParser(add_help=False)
|
20 |
+
parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2'])
|
21 |
+
parser.add_argument('--data_path', type=str, default='results')
|
22 |
+
parser.add_argument('--gt_path', type=str, default='gt_images')
|
23 |
+
parser.add_argument('--workers', type=int, default=4)
|
24 |
+
parser.add_argument('--batch_size', type=int, default=4)
|
25 |
+
args = parser.parse_args()
|
26 |
+
return args
|
27 |
+
|
28 |
+
|
29 |
+
def run(args):
|
30 |
+
|
31 |
+
transform = transforms.Compose([transforms.Resize((256, 256)),
|
32 |
+
transforms.ToTensor(),
|
33 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
34 |
+
|
35 |
+
print('Loading dataset')
|
36 |
+
dataset = GTResDataset(root_path=args.data_path,
|
37 |
+
gt_dir=args.gt_path,
|
38 |
+
transform=transform)
|
39 |
+
|
40 |
+
dataloader = DataLoader(dataset,
|
41 |
+
batch_size=args.batch_size,
|
42 |
+
shuffle=False,
|
43 |
+
num_workers=int(args.workers),
|
44 |
+
drop_last=True)
|
45 |
+
|
46 |
+
if args.mode == 'lpips':
|
47 |
+
loss_func = LPIPS(net_type='alex')
|
48 |
+
elif args.mode == 'l2':
|
49 |
+
loss_func = torch.nn.MSELoss()
|
50 |
+
else:
|
51 |
+
raise Exception('Not a valid mode!')
|
52 |
+
loss_func.cuda()
|
53 |
+
|
54 |
+
global_i = 0
|
55 |
+
scores_dict = {}
|
56 |
+
all_scores = []
|
57 |
+
for result_batch, gt_batch in tqdm(dataloader):
|
58 |
+
for i in range(args.batch_size):
|
59 |
+
loss = float(loss_func(result_batch[i:i+1].cuda(), gt_batch[i:i+1].cuda()))
|
60 |
+
all_scores.append(loss)
|
61 |
+
im_path = dataset.pairs[global_i][0]
|
62 |
+
scores_dict[os.path.basename(im_path)] = loss
|
63 |
+
global_i += 1
|
64 |
+
|
65 |
+
all_scores = list(scores_dict.values())
|
66 |
+
mean = np.mean(all_scores)
|
67 |
+
std = np.std(all_scores)
|
68 |
+
result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
|
69 |
+
print('Finished with ', args.data_path)
|
70 |
+
print(result_str)
|
71 |
+
|
72 |
+
out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
|
73 |
+
if not os.path.exists(out_path):
|
74 |
+
os.makedirs(out_path)
|
75 |
+
|
76 |
+
with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f:
|
77 |
+
f.write(result_str)
|
78 |
+
with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f:
|
79 |
+
json.dump(scores_dict, f)
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == '__main__':
|
83 |
+
args = parse_args()
|
84 |
+
run(args)
|
scripts/generate_sketch_data.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision import transforms
|
2 |
+
from torchvision.utils import save_image
|
3 |
+
from torch.utils.serialization import load_lua
|
4 |
+
import os
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
"""
|
9 |
+
NOTE!: Must have torch==0.4.1 and torchvision==0.2.1
|
10 |
+
The sketch simplification model (sketch_gan.t7) from Simo Serra et al. can be downloaded from their official implementation:
|
11 |
+
https://github.com/bobbens/sketch_simplification
|
12 |
+
"""
|
13 |
+
|
14 |
+
|
15 |
+
def sobel(img):
|
16 |
+
opImgx = cv2.Sobel(img, cv2.CV_8U, 0, 1, ksize=3)
|
17 |
+
opImgy = cv2.Sobel(img, cv2.CV_8U, 1, 0, ksize=3)
|
18 |
+
return cv2.bitwise_or(opImgx, opImgy)
|
19 |
+
|
20 |
+
|
21 |
+
def sketch(frame):
|
22 |
+
frame = cv2.GaussianBlur(frame, (3, 3), 0)
|
23 |
+
invImg = 255 - frame
|
24 |
+
edgImg0 = sobel(frame)
|
25 |
+
edgImg1 = sobel(invImg)
|
26 |
+
edgImg = cv2.addWeighted(edgImg0, 0.75, edgImg1, 0.75, 0)
|
27 |
+
opImg = 255 - edgImg
|
28 |
+
return opImg
|
29 |
+
|
30 |
+
|
31 |
+
def get_sketch_image(image_path):
|
32 |
+
original = cv2.imread(image_path)
|
33 |
+
original = cv2.cvtColor(original, cv2.COLOR_BGR2GRAY)
|
34 |
+
sketch_image = sketch(original)
|
35 |
+
return sketch_image[:, :, np.newaxis]
|
36 |
+
|
37 |
+
|
38 |
+
use_cuda = True
|
39 |
+
|
40 |
+
cache = load_lua("/path/to/sketch_gan.t7")
|
41 |
+
model = cache.model
|
42 |
+
immean = cache.mean
|
43 |
+
imstd = cache.std
|
44 |
+
model.evaluate()
|
45 |
+
|
46 |
+
data_path = "/path/to/data/imgs"
|
47 |
+
images = [os.path.join(data_path, f) for f in os.listdir(data_path)]
|
48 |
+
|
49 |
+
output_dir = "/path/to/data/edges"
|
50 |
+
if not os.path.exists(output_dir):
|
51 |
+
os.makedirs(output_dir)
|
52 |
+
|
53 |
+
for idx, image_path in enumerate(images):
|
54 |
+
if idx % 50 == 0:
|
55 |
+
print("{} out of {}".format(idx, len(images)))
|
56 |
+
data = get_sketch_image(image_path)
|
57 |
+
data = ((transforms.ToTensor()(data) - immean) / imstd).unsqueeze(0)
|
58 |
+
if use_cuda:
|
59 |
+
pred = model.cuda().forward(data.cuda()).float()
|
60 |
+
else:
|
61 |
+
pred = model.forward(data)
|
62 |
+
save_image(pred[0], os.path.join(output_dir, "{}_edges.jpg".format(image_path.split("/")[-1].split('.')[0])))
|
scripts/inference.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from argparse import Namespace
|
3 |
+
|
4 |
+
from tqdm import tqdm
|
5 |
+
import time
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
import sys
|
11 |
+
|
12 |
+
sys.path.append(".")
|
13 |
+
sys.path.append("..")
|
14 |
+
|
15 |
+
from configs import data_configs
|
16 |
+
from datasets.inference_dataset import InferenceDataset
|
17 |
+
from utils.common import tensor2im, log_input_image
|
18 |
+
from options.test_options import TestOptions
|
19 |
+
from models.psp import pSp
|
20 |
+
|
21 |
+
|
22 |
+
def run():
|
23 |
+
test_opts = TestOptions().parse()
|
24 |
+
|
25 |
+
if test_opts.resize_factors is not None:
|
26 |
+
assert len(
|
27 |
+
test_opts.resize_factors.split(',')) == 1, "When running inference, provide a single downsampling factor!"
|
28 |
+
out_path_results = os.path.join(test_opts.exp_dir, 'inference_results',
|
29 |
+
'downsampling_{}'.format(test_opts.resize_factors))
|
30 |
+
out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled',
|
31 |
+
'downsampling_{}'.format(test_opts.resize_factors))
|
32 |
+
else:
|
33 |
+
out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
|
34 |
+
out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')
|
35 |
+
|
36 |
+
os.makedirs(out_path_results, exist_ok=True)
|
37 |
+
os.makedirs(out_path_coupled, exist_ok=True)
|
38 |
+
|
39 |
+
# update test options with options used during training
|
40 |
+
ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
|
41 |
+
opts = ckpt['opts']
|
42 |
+
opts.update(vars(test_opts))
|
43 |
+
if 'learn_in_w' not in opts:
|
44 |
+
opts['learn_in_w'] = False
|
45 |
+
if 'output_size' not in opts:
|
46 |
+
opts['output_size'] = 1024
|
47 |
+
opts = Namespace(**opts)
|
48 |
+
|
49 |
+
net = pSp(opts)
|
50 |
+
net.eval()
|
51 |
+
net.cuda()
|
52 |
+
|
53 |
+
print('Loading dataset for {}'.format(opts.dataset_type))
|
54 |
+
dataset_args = data_configs.DATASETS[opts.dataset_type]
|
55 |
+
transforms_dict = dataset_args['transforms'](opts).get_transforms()
|
56 |
+
dataset = InferenceDataset(root=opts.data_path,
|
57 |
+
transform=transforms_dict['transform_inference'],
|
58 |
+
opts=opts)
|
59 |
+
dataloader = DataLoader(dataset,
|
60 |
+
batch_size=opts.test_batch_size,
|
61 |
+
shuffle=False,
|
62 |
+
num_workers=int(opts.test_workers),
|
63 |
+
drop_last=True)
|
64 |
+
|
65 |
+
if opts.n_images is None:
|
66 |
+
opts.n_images = len(dataset)
|
67 |
+
|
68 |
+
global_i = 0
|
69 |
+
global_time = []
|
70 |
+
for input_batch in tqdm(dataloader):
|
71 |
+
if global_i >= opts.n_images:
|
72 |
+
break
|
73 |
+
with torch.no_grad():
|
74 |
+
input_cuda = input_batch.cuda().float()
|
75 |
+
tic = time.time()
|
76 |
+
result_batch = run_on_batch(input_cuda, net, opts)
|
77 |
+
toc = time.time()
|
78 |
+
global_time.append(toc - tic)
|
79 |
+
|
80 |
+
for i in range(opts.test_batch_size):
|
81 |
+
result = tensor2im(result_batch[i])
|
82 |
+
im_path = dataset.paths[global_i]
|
83 |
+
|
84 |
+
if opts.couple_outputs or global_i % 100 == 0:
|
85 |
+
input_im = log_input_image(input_batch[i], opts)
|
86 |
+
resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)
|
87 |
+
if opts.resize_factors is not None:
|
88 |
+
# for super resolution, save the original, down-sampled, and output
|
89 |
+
source = Image.open(im_path)
|
90 |
+
res = np.concatenate([np.array(source.resize(resize_amount)),
|
91 |
+
np.array(input_im.resize(resize_amount, resample=Image.NEAREST)),
|
92 |
+
np.array(result.resize(resize_amount))], axis=1)
|
93 |
+
else:
|
94 |
+
# otherwise, save the original and output
|
95 |
+
res = np.concatenate([np.array(input_im.resize(resize_amount)),
|
96 |
+
np.array(result.resize(resize_amount))], axis=1)
|
97 |
+
Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path)))
|
98 |
+
|
99 |
+
im_save_path = os.path.join(out_path_results, os.path.basename(im_path))
|
100 |
+
Image.fromarray(np.array(result)).save(im_save_path)
|
101 |
+
|
102 |
+
global_i += 1
|
103 |
+
|
104 |
+
stats_path = os.path.join(opts.exp_dir, 'stats.txt')
|
105 |
+
result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time))
|
106 |
+
print(result_str)
|
107 |
+
|
108 |
+
with open(stats_path, 'w') as f:
|
109 |
+
f.write(result_str)
|
110 |
+
|
111 |
+
|
112 |
+
def run_on_batch(inputs, net, opts):
|
113 |
+
if opts.latent_mask is None:
|
114 |
+
result_batch = net(inputs, randomize_noise=False, resize=opts.resize_outputs)
|
115 |
+
else:
|
116 |
+
latent_mask = [int(l) for l in opts.latent_mask.split(",")]
|
117 |
+
result_batch = []
|
118 |
+
for image_idx, input_image in enumerate(inputs):
|
119 |
+
# get latent vector to inject into our input image
|
120 |
+
vec_to_inject = np.random.randn(1, 512).astype('float32')
|
121 |
+
_, latent_to_inject = net(torch.from_numpy(vec_to_inject).to("cuda"),
|
122 |
+
input_code=True,
|
123 |
+
return_latents=True)
|
124 |
+
# get output image with injected style vector
|
125 |
+
res = net(input_image.unsqueeze(0).to("cuda").float(),
|
126 |
+
latent_mask=latent_mask,
|
127 |
+
inject_latent=latent_to_inject,
|
128 |
+
alpha=opts.mix_alpha,
|
129 |
+
resize=opts.resize_outputs)
|
130 |
+
result_batch.append(res)
|
131 |
+
result_batch = torch.cat(result_batch, dim=0)
|
132 |
+
return result_batch
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == '__main__':
|
136 |
+
run()
|
scripts/style_mixing.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from argparse import Namespace
|
3 |
+
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
import sys
|
10 |
+
|
11 |
+
sys.path.append(".")
|
12 |
+
sys.path.append("..")
|
13 |
+
|
14 |
+
from configs import data_configs
|
15 |
+
from datasets.inference_dataset import InferenceDataset
|
16 |
+
from utils.common import tensor2im, log_input_image
|
17 |
+
from options.test_options import TestOptions
|
18 |
+
from models.psp import pSp
|
19 |
+
|
20 |
+
|
21 |
+
def run():
|
22 |
+
test_opts = TestOptions().parse()
|
23 |
+
|
24 |
+
if test_opts.resize_factors is not None:
|
25 |
+
factors = test_opts.resize_factors.split(',')
|
26 |
+
assert len(factors) == 1, "When running inference, please provide a single downsampling factor!"
|
27 |
+
mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing',
|
28 |
+
'downsampling_{}'.format(test_opts.resize_factors))
|
29 |
+
else:
|
30 |
+
mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing')
|
31 |
+
os.makedirs(mixed_path_results, exist_ok=True)
|
32 |
+
|
33 |
+
# update test options with options used during training
|
34 |
+
ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
|
35 |
+
opts = ckpt['opts']
|
36 |
+
opts.update(vars(test_opts))
|
37 |
+
if 'learn_in_w' not in opts:
|
38 |
+
opts['learn_in_w'] = False
|
39 |
+
if 'output_size' not in opts:
|
40 |
+
opts['output_size'] = 1024
|
41 |
+
opts = Namespace(**opts)
|
42 |
+
|
43 |
+
net = pSp(opts)
|
44 |
+
net.eval()
|
45 |
+
net.cuda()
|
46 |
+
|
47 |
+
print('Loading dataset for {}'.format(opts.dataset_type))
|
48 |
+
dataset_args = data_configs.DATASETS[opts.dataset_type]
|
49 |
+
transforms_dict = dataset_args['transforms'](opts).get_transforms()
|
50 |
+
dataset = InferenceDataset(root=opts.data_path,
|
51 |
+
transform=transforms_dict['transform_inference'],
|
52 |
+
opts=opts)
|
53 |
+
dataloader = DataLoader(dataset,
|
54 |
+
batch_size=opts.test_batch_size,
|
55 |
+
shuffle=False,
|
56 |
+
num_workers=int(opts.test_workers),
|
57 |
+
drop_last=True)
|
58 |
+
|
59 |
+
latent_mask = [int(l) for l in opts.latent_mask.split(",")]
|
60 |
+
if opts.n_images is None:
|
61 |
+
opts.n_images = len(dataset)
|
62 |
+
|
63 |
+
global_i = 0
|
64 |
+
for input_batch in tqdm(dataloader):
|
65 |
+
if global_i >= opts.n_images:
|
66 |
+
break
|
67 |
+
with torch.no_grad():
|
68 |
+
input_batch = input_batch.cuda()
|
69 |
+
for image_idx, input_image in enumerate(input_batch):
|
70 |
+
# generate random vectors to inject into input image
|
71 |
+
vecs_to_inject = np.random.randn(opts.n_outputs_to_generate, 512).astype('float32')
|
72 |
+
multi_modal_outputs = []
|
73 |
+
for vec_to_inject in vecs_to_inject:
|
74 |
+
cur_vec = torch.from_numpy(vec_to_inject).unsqueeze(0).to("cuda")
|
75 |
+
# get latent vector to inject into our input image
|
76 |
+
_, latent_to_inject = net(cur_vec,
|
77 |
+
input_code=True,
|
78 |
+
return_latents=True)
|
79 |
+
# get output image with injected style vector
|
80 |
+
res = net(input_image.unsqueeze(0).to("cuda").float(),
|
81 |
+
latent_mask=latent_mask,
|
82 |
+
inject_latent=latent_to_inject,
|
83 |
+
alpha=opts.mix_alpha,
|
84 |
+
resize=opts.resize_outputs)
|
85 |
+
multi_modal_outputs.append(res[0])
|
86 |
+
|
87 |
+
# visualize multi modal outputs
|
88 |
+
input_im_path = dataset.paths[global_i]
|
89 |
+
image = input_batch[image_idx]
|
90 |
+
input_image = log_input_image(image, opts)
|
91 |
+
resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)
|
92 |
+
res = np.array(input_image.resize(resize_amount))
|
93 |
+
for output in multi_modal_outputs:
|
94 |
+
output = tensor2im(output)
|
95 |
+
res = np.concatenate([res, np.array(output.resize(resize_amount))], axis=1)
|
96 |
+
Image.fromarray(res).save(os.path.join(mixed_path_results, os.path.basename(input_im_path)))
|
97 |
+
global_i += 1
|
98 |
+
|
99 |
+
|
100 |
+
if __name__ == '__main__':
|
101 |
+
run()
|
scripts/train.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file runs the main training/val loop
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import sys
|
7 |
+
import pprint
|
8 |
+
|
9 |
+
sys.path.append(".")
|
10 |
+
sys.path.append("..")
|
11 |
+
|
12 |
+
from options.train_options import TrainOptions
|
13 |
+
from training.coach import Coach
|
14 |
+
|
15 |
+
|
16 |
+
def main():
|
17 |
+
opts = TrainOptions().parse()
|
18 |
+
if os.path.exists(opts.exp_dir):
|
19 |
+
raise Exception('Oops... {} already exists'.format(opts.exp_dir))
|
20 |
+
os.makedirs(opts.exp_dir)
|
21 |
+
|
22 |
+
opts_dict = vars(opts)
|
23 |
+
pprint.pprint(opts_dict)
|
24 |
+
with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
|
25 |
+
json.dump(opts_dict, f, indent=4, sort_keys=True)
|
26 |
+
|
27 |
+
coach = Coach(opts)
|
28 |
+
coach.train()
|
29 |
+
|
30 |
+
|
31 |
+
if __name__ == '__main__':
|
32 |
+
main()
|