complete image and video
- api.py +248 -0
- augmentation.py +345 -0
- config/vox-256.yaml +93 -0
- modules/__pycache__/dense_motion.cpython-38.pyc +0 -0
- modules/__pycache__/generator.cpython-38.pyc +0 -0
- modules/__pycache__/keypoint_detector.cpython-38.pyc +0 -0
- modules/__pycache__/nerf_verts_util.cpython-38.pyc +0 -0
- modules/__pycache__/util.cpython-38.pyc +0 -0
- modules/dense_motion.py +114 -0
- modules/generator.py +211 -0
- modules/keypoint_detector.py +76 -0
- modules/nerf_verts_util.py +227 -0
- modules/util.py +245 -0
- requirements.txt +93 -0
- sup-mat/driving.mp4 +0 -0
- sup-mat/driving.png +0 -0
- sup-mat/source.png +0 -0
- sup-mat/source_for_video.png +0 -0
- sync_batchnorm/__init__.py +12 -0
- sync_batchnorm/__pycache__/__init__.cpython-38.pyc +0 -0
- sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc +0 -0
- sync_batchnorm/__pycache__/comm.cpython-38.pyc +0 -0
- sync_batchnorm/__pycache__/replicate.cpython-38.pyc +0 -0
- sync_batchnorm/batchnorm.py +315 -0
- sync_batchnorm/comm.py +137 -0
- sync_batchnorm/replicate.py +94 -0
- sync_batchnorm/unittest.py +29 -0
api.py
ADDED
@@ -0,0 +1,248 @@
import os
import gradio as gr
import yaml
from argparse import ArgumentParser
from tqdm import tqdm

import numpy as np
import imageio
from skimage.transform import resize
from skimage import img_as_ubyte
from scipy.spatial import ConvexHull
import torch
from sync_batchnorm import DataParallelWithCallback
import face_alignment

from modules.generator import OcclusionAwareGenerator_SPADE
from modules.keypoint_detector import KPDetector


def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                 use_relative_movement=False, use_relative_jacobian=False):
    kp_new = {k: v for k, v in kp_driving.items()}

    if adapt_movement_scale:
        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
        kp_new['value'] = kp_driving['value'] * adapt_movement_scale  # for reenactment demo
    else:
        adapt_movement_scale = 1

    if use_relative_movement:
        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
        kp_value_diff *= adapt_movement_scale
        kp_new['value'] = kp_value_diff + kp_source['value']

        if use_relative_jacobian:
            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])

    return kp_new


def load_checkpoints(config_path, checkpoint_path, cpu=False):
    with open(config_path) as f:
        # config = yaml.load(f)
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator_SPADE(**config['model_params']['generator_params'],
                                              **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()

    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector


def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True,
                   cpu=False):
    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            source = source.cuda()
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        kp_source = kp_detector(source)
        kp_driving_initial = kp_detector(driving[:, :, 0])

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            kp_driving = kp_detector(driving_frame)
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)

            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions


def find_best_frame_func(source, driving, cpu=False):
    def normalize_kp_infunc(kp):
        kp = kp - kp.mean(axis=0, keepdims=True)
        area = ConvexHull(kp[:, :2]).volume
        area = np.sqrt(area)
        kp[:, :2] = kp[:, :2] / area
        return kp

    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cpu' if cpu else 'cuda')
    kp_source = fa.get_landmarks(255 * source)[0]
    kp_source = normalize_kp_infunc(kp_source)
    norm = float('inf')
    frame_num = 0
    for i, image in tqdm(enumerate(driving)):
        kp_driving = fa.get_landmarks(255 * image)[0]
        kp_driving = normalize_kp_infunc(kp_driving)
        new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
        if new_norm < norm:
            norm = new_norm
            frame_num = i
    return frame_num


def drive_im(source_image, driving_image, adapt_scale):
    source_image = resize(source_image, (256, 256))[..., :3]
    driving_image = [resize(driving_image, (256, 256))[..., :3]]

    prediction = make_animation(source_image, driving_image, generator, kp_detector, relative=False,
                                adapt_movement_scale=adapt_scale, cpu=cpu)
    return img_as_ubyte(prediction[0])


def drive_vi(source_image, driving_video, mode, find_best_frame, relative, adapt_scale):
    reader = imageio.get_reader(driving_video)
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    if mode == 'reconstruction':
        source_image = driving_video[0]

    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    if find_best_frame:
        i = find_best_frame_func(source_image, driving_video, cpu=cpu)
        print("Best frame: " + str(i))
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i + 1)][::-1]
        predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector,
                                             relative=relative, adapt_movement_scale=adapt_scale, cpu=cpu)
        predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector,
                                              relative=relative, adapt_movement_scale=adapt_scale, cpu=cpu)
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    else:
        predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=relative,
                                     adapt_movement_scale=adapt_scale, cpu=cpu)
    result_video_path = "result_video.mp4"
    imageio.mimsave(result_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)
    return result_video_path


os.environ['CUDA_VISIBLE_DEVICES'] = "3"
config = "config/vox-256.yaml"
checkpoint = "00000099-checkpoint.pth.tar"
mode = "reenactment"
relative = True
adapt_scale = True
find_best_frame = True
cpu = False  # decided by the deploying environment

description = "We propose a Face Neural Volume Rendering (FNeVR) network for more realistic face animation, by taking the merits of 2D motion warping on facial expression transformation and 3D volume rendering on high-quality image synthesis in a unified framework.<br>[Paper](https://arxiv.org/abs/2209.10340) and [Code](https://github.com/zengbohan0217/FNeVR)"
im_description = "We can animate a face portrait by a single image in this tab.<br>Please input the origin face and the driving face which provides pose and expression information, then we can obtain the virtual generated face.<br>We can select the \"adaptive scale\" parameter for better optical flow estimation using an adaptive movement scale based on the convex hull of keypoints."
vi_description = "We can animate a face portrait by a video in this tab.<br>Please input the origin face and the driving video which provides pose and expression information, then we can obtain the virtual generated video.<br>Please select the inference mode (reenactment for different identities and reconstruction for the same identity).<br>We can select the \"find best frame\" parameter to generate the video from the frame that is best aligned with the source image, select the \"relative motion\" parameter to use relative keypoint coordinates for preserving global object geometry, and select the \"adaptive scale\" parameter for better optical flow estimation using an adaptive movement scale based on the convex hull of keypoints."
acknowledgements = "This work was supported by “the Fundamental Research Funds for the Central Universities”, and the National Natural Science Foundation of China under Grant 62076016, Beijing Natural Science Foundation-Xiaomi Innovation Joint Fund L223024. Besides, we gratefully acknowledge the support of [MindSpore](https://www.mindspore.cn), CANN (Compute Architecture for Neural Networks) and Ascend AI processor used for this research.<br>Our FNeVR implementation is inspired by [FOMM](https://github.com/AliaksandrSiarohin/first-order-model) and [DECA](https://github.com/YadiraF/DECA). We appreciate the authors of these papers for making their codes available to the public."

generator, kp_detector = load_checkpoints(config_path=config, checkpoint_path=checkpoint, cpu=cpu)

# iface = gr.Interface(fn=drive_im,
#                      inputs=[gr.Image(label="Origin face"),
#                              gr.Image(label="Driving face"),
#                              gr.CheckboxGroup(label="adapt scale")],
#                      outputs=gr.Image(label="Generated face"), examples=[["sup-mat/source.png"], ["sup-mat/driving.png"]],
#                      title="Demonstration of FNeVR", description=description)

with gr.Blocks(title="Demonstration of FNeVR") as demo:
    gr.Markdown("# <center> Demonstration of FNeVR")
    gr.Markdown(description)

    with gr.Tab("Driving by image"):
        gr.Markdown(im_description)

        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Inputs")
                inp2 = gr.Image(label="Driving face")
                inp1 = gr.Image(label="Origin face")

                gr.Markdown("#### Parameter")
                inp3 = gr.Checkbox(value=True, label="adaptive scale")

                btn = gr.Button(value="Animate")
            with gr.Column():
                gr.Markdown("#### Output")
                outp = gr.Image(label="Generated face")

        gr.Examples([["sup-mat/driving.png", "sup-mat/source.png"]], [inp2, inp1])

        btn.click(fn=drive_im, inputs=[inp1, inp2, inp3], outputs=outp)
    with gr.Tab("Driving by video"):
        gr.Markdown(vi_description)

        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Inputs")
                inp2 = gr.Video(label="Driving video")
                inp1 = gr.Image(label="Origin face")

                gr.Markdown("#### Parameters")
                inp3 = gr.Radio(choices=['reenactment', 'reconstruction'], value="reenactment", label="mode (if \"reconstruction\" selected, origin face is the first frame of driving video)")
                inp4 = gr.Checkbox(value=True, label="find best frame (more time consumed)")
                inp5 = gr.Checkbox(value=True, label="relative motion")
                inp6 = gr.Checkbox(value=True, label="adaptive scale")

                btn = gr.Button(value="Animate")
            with gr.Column():
                gr.Markdown("#### Output")
                outp = gr.Video(label="Generated video")

        gr.Examples([["sup-mat/driving.mp4", "sup-mat/source_for_video.png"]], [inp2, inp1])

        btn.click(fn=drive_vi, inputs=[inp1, inp2, inp3, inp4, inp5, inp6], outputs=outp)

    with gr.Tab("Real time animation"):
        gr.Markdown("#### Real time animation")

    gr.Markdown("## Acknowledgements")
    gr.Markdown(acknowledgements)

demo.launch()
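The demo wires these helpers into Gradio, but the same inference path can be exercised headlessly. A minimal sketch, assuming it runs in api.py's namespace after load_checkpoints has populated generator, kp_detector and cpu, and that the bundled sup-mat images are present; the output filename is arbitrary:

import imageio
from skimage.transform import resize
from skimage import img_as_ubyte

# Load and normalise the bundled example images to the model's 256x256 input size.
src = resize(imageio.imread("sup-mat/source.png"), (256, 256))[..., :3]
drv = resize(imageio.imread("sup-mat/driving.png"), (256, 256))[..., :3]

# A single-frame "video": this mirrors what drive_im does in the image tab.
frames = make_animation(src, [drv], generator, kp_detector,
                        relative=False, adapt_movement_scale=True, cpu=cpu)
imageio.imwrite("headless_result.png", img_as_ubyte(frames[0]))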
augmentation.py
ADDED
@@ -0,0 +1,345 @@
"""
Code from https://github.com/hassony2/torch_videovision
"""

import numbers

import random
import numpy as np
import PIL

from skimage.transform import resize, rotate
from skimage.util import pad
import torchvision

import warnings

from skimage import img_as_ubyte, img_as_float


def crop_clip(clip, min_h, min_w, h, w):
    if isinstance(clip[0], np.ndarray):
        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]

    elif isinstance(clip[0], PIL.Image.Image):
        cropped = [
            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
        ]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image' +
                        'but got list of {0}'.format(type(clip[0])))
    return cropped


def pad_clip(clip, h, w):
    im_h, im_w = clip[0].shape[:2]
    pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
    pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)

    return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')


def resize_clip(clip, size, interpolation='bilinear'):
    if isinstance(clip[0], np.ndarray):
        if isinstance(size, numbers.Number):
            im_h, im_w, im_c = clip[0].shape
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]

        scaled = [
            resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
                   mode='constant', anti_aliasing=True) for img in clip
        ]
    elif isinstance(clip[0], PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = clip[0].size
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        if interpolation == 'bilinear':
            pil_inter = PIL.Image.NEAREST
        else:
            pil_inter = PIL.Image.BILINEAR
        scaled = [img.resize(size, pil_inter) for img in clip]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image' +
                        'but got list of {0}'.format(type(clip[0])))
    return scaled


def get_resize_sizes(im_h, im_w, size):
    if im_w < im_h:
        ow = size
        oh = int(size * im_h / im_w)
    else:
        oh = size
        ow = int(size * im_w / im_h)
    return oh, ow


class RandomFlip(object):
    def __init__(self, time_flip=False, horizontal_flip=False):
        self.time_flip = time_flip
        self.horizontal_flip = horizontal_flip

    def __call__(self, clip):
        if random.random() < 0.5 and self.time_flip:
            return clip[::-1]
        if random.random() < 0.5 and self.horizontal_flip:
            return [np.fliplr(img) for img in clip]

        return clip


class RandomResize(object):
    """Resizes a list of (H x W x C) numpy.ndarray to the final size
    The larger the original image is, the longer it takes to
    interpolate
    Args:
        interpolation (str): Can be one of 'nearest', 'bilinear'
        defaults to nearest
        size (tuple): (width, height)
    """

    def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
        self.ratio = ratio
        self.interpolation = interpolation

    def __call__(self, clip):
        scaling_factor = random.uniform(self.ratio[0], self.ratio[1])

        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size

        new_w = int(im_w * scaling_factor)
        new_h = int(im_h * scaling_factor)
        new_size = (new_w, new_h)
        resized = resize_clip(
            clip, new_size, interpolation=self.interpolation)

        return resized


class RandomCrop(object):
    """Extract random crop at the same location for a list of videos
    Args:
        size (sequence or int): Desired output size for the
        crop in format (h, w)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)

        self.size = size

    def __call__(self, clip):
        """
        Args:
            img (PIL.Image or numpy.ndarray): List of videos to be cropped
            in format (h, w, c) in numpy.ndarray
        Returns:
            PIL.Image or numpy.ndarray: Cropped list of videos
        """
        h, w = self.size
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))

        clip = pad_clip(clip, h, w)
        im_h, im_w = clip.shape[1:3]
        x1 = 0 if h == im_h else random.randint(0, im_w - w)
        y1 = 0 if w == im_w else random.randint(0, im_h - h)
        cropped = crop_clip(clip, y1, x1, h, w)

        return cropped


class RandomRotation(object):
    """Rotate entire clip randomly by a random angle within
    given bounds
    Args:
        degrees (sequence or int): Range of degrees to select from
        If degrees is a number instead of sequence like (min, max),
        the range of degrees, will be (-degrees, +degrees).
    """

    def __init__(self, degrees):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError('If degrees is a single number,'
                                 'must be positive')
            degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError('If degrees is a sequence,'
                                 'it must be of len 2.')

        self.degrees = degrees

    def __call__(self, clip):
        """
        Args:
            img (PIL.Image or numpy.ndarray): List of videos to be cropped
            in format (h, w, c) in numpy.ndarray
        Returns:
            PIL.Image or numpy.ndarray: Cropped list of videos
        """
        angle = random.uniform(self.degrees[0], self.degrees[1])
        if isinstance(clip[0], np.ndarray):
            rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
        elif isinstance(clip[0], PIL.Image.Image):
            rotated = [img.rotate(angle) for img in clip]
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))

        return rotated


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation and hue of the clip
    Args:
        brightness (float): How much to jitter brightness. brightness_factor
        is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
        is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
        is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
        [-hue, hue]. Should be >=0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def get_params(self, brightness, contrast, saturation, hue):
        if brightness > 0:
            brightness_factor = random.uniform(
                max(0, 1 - brightness), 1 + brightness)
        else:
            brightness_factor = None

        if contrast > 0:
            contrast_factor = random.uniform(
                max(0, 1 - contrast), 1 + contrast)
        else:
            contrast_factor = None

        if saturation > 0:
            saturation_factor = random.uniform(
                max(0, 1 - saturation), 1 + saturation)
        else:
            saturation_factor = None

        if hue > 0:
            hue_factor = random.uniform(-hue, hue)
        else:
            hue_factor = None
        return brightness_factor, contrast_factor, saturation_factor, hue_factor

    def __call__(self, clip):
        """
        Args:
            clip (list): list of PIL.Image
        Returns:
            list PIL.Image : list of transformed PIL.Image
        """
        if isinstance(clip[0], np.ndarray):
            brightness, contrast, saturation, hue = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)

            # Create img transform function sequence
            img_transforms = []
            if brightness is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
            if saturation is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
            if hue is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
            if contrast is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
            random.shuffle(img_transforms)
            img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
                                                                                                     img_as_float]

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                jittered_clip = []
                for img in clip:
                    jittered_img = img
                    for func in img_transforms:
                        jittered_img = func(jittered_img)
                    jittered_clip.append(jittered_img.astype('float32'))
        elif isinstance(clip[0], PIL.Image.Image):
            brightness, contrast, saturation, hue = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)

            # Create img transform function sequence
            img_transforms = []
            if brightness is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
            if saturation is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
            if hue is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
            if contrast is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
            random.shuffle(img_transforms)

            # Apply to all videos
            jittered_clip = []
            for img in clip:
                for func in img_transforms:
                    jittered_img = func(img)
                jittered_clip.append(jittered_img)

        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))
        return jittered_clip


class AllAugmentationTransform:
    def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
        self.transforms = []

        if flip_param is not None:
            self.transforms.append(RandomFlip(**flip_param))

        if rotation_param is not None:
            self.transforms.append(RandomRotation(**rotation_param))

        if resize_param is not None:
            self.transforms.append(RandomResize(**resize_param))

        if crop_param is not None:
            self.transforms.append(RandomCrop(**crop_param))

        if jitter_param is not None:
            self.transforms.append(ColorJitter(**jitter_param))

    def __call__(self, clip):
        for t in self.transforms:
            clip = t(clip)
        return clip
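Only the flip and colour-jitter branches of this pipeline are enabled by the training config that follows. A minimal sketch of how AllAugmentationTransform can be instantiated from those parameters and applied to a clip; the 4-frame random clip is purely illustrative:

import numpy as np
from augmentation import AllAugmentationTransform

# Mirrors augmentation_params in config/vox-256.yaml (flip + colour jitter only).
augmentation_params = {
    'flip_param': {'horizontal_flip': True, 'time_flip': True},
    'jitter_param': {'brightness': 0.1, 'contrast': 0.1, 'saturation': 0.1, 'hue': 0.1},
}
transform = AllAugmentationTransform(**augmentation_params)

clip = [np.random.rand(256, 256, 3).astype('float32') for _ in range(4)]  # list of H x W x C frames in [0, 1]
augmented = transform(clip)  # same-length list, possibly time/horizontally flipped and colour-jittered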
config/vox-256.yaml
ADDED
@@ -0,0 +1,93 @@
dataset_params:
  root_dir: /data/lh/repo/datasets/face-video-preprocessing/vox-png
  frame_shape: [256, 256, 3]
  id_sampling: True
  pairs_list: data/vox256.csv
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1


model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25
    num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
    render_params:
      simpled_channel_rgb: 128
      simpled_channel_sigma: 128
      floor_num: 8
      hidden_size: 128
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  num_epochs: 100
  num_repeats: 75
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  lr_face_editor: 2.0e-4
  # batch_size: 40
  batch_size: 20
  scales: [1, 0.5, 0.25, 0.125]
  # checkpoint_freq: 75
  checkpoint_freq: 10
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 1  # 0
    discriminator_gan: 1
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 10
    equivariance_jacobian: 10
    perceptual_l1: 5
    pose_edit: 1

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: True
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'
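Only model_params is consumed at demo time (see load_checkpoints in api.py); dataset_params and train_params describe the original training run, including a local dataset path. A small sketch of reading the inference-relevant part of the config:

import yaml

with open("config/vox-256.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

common = config['model_params']['common_params']          # num_kp, num_channels, estimate_jacobian
gen_params = config['model_params']['generator_params']   # includes dense_motion_params and render_params
print(common['num_kp'], gen_params['render_params']['floor_num'])  # 10 8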
modules/__pycache__/dense_motion.cpython-38.pyc
ADDED
Binary file (3.83 kB)
modules/__pycache__/generator.cpython-38.pyc
ADDED
Binary file (6.9 kB)
modules/__pycache__/keypoint_detector.cpython-38.pyc
ADDED
Binary file (2.48 kB)
modules/__pycache__/nerf_verts_util.cpython-38.pyc
ADDED
Binary file (7.04 kB)
modules/__pycache__/util.cpython-38.pyc
ADDED
Binary file (7.82 kB)
modules/dense_motion.py
ADDED
@@ -0,0 +1,114 @@
from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian


class DenseMotionNetwork(nn.Module):
    """
    Module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
    """

    def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
                 scale_factor=1, kp_variance=0.01):
        super(DenseMotionNetwork, self).__init__()
        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
                                   max_features=max_features, num_blocks=num_blocks)

        self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))

        if estimate_occlusion_map:
            self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
        else:
            self.occlusion = None

        self.num_kp = num_kp
        self.scale_factor = scale_factor
        self.kp_variance = kp_variance

        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)

    def create_heatmap_representations(self, source_image, kp_driving, kp_source):
        """
        Eq 6. in the paper H_k(z)
        """
        spatial_size = source_image.shape[2:]
        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
        heatmap = gaussian_driving - gaussian_source

        # adding background feature
        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
        heatmap = torch.cat([zeros, heatmap], dim=1)
        heatmap = heatmap.unsqueeze(2)
        return heatmap

    def create_sparse_motions(self, source_image, kp_driving, kp_source):
        """
        Eq 4. in the paper T_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
        if 'jacobian' in kp_driving:
            jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))

            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
            jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
            coordinate_grid = coordinate_grid.squeeze(-1)

        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)

        # adding background feature
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
        return sparse_motions

    def create_deformed_source_image(self, source_image, sparse_motions):
        """
        Eq 7. in the paper \hat{T}_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
        source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
        sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
        sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
        sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
        return sparse_deformed

    def forward(self, source_image, kp_driving, kp_source):
        if self.scale_factor != 1:
            source_image = self.down(source_image)

        bs, _, h, w = source_image.shape

        out_dict = dict()
        heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
        sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
        deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
        out_dict['sparse_deformed'] = deformed_source

        input = torch.cat([heatmap_representation, deformed_source], dim=2)
        input = input.view(bs, -1, h, w)

        prediction = self.hourglass(input)

        mask = self.mask(prediction)
        mask = F.softmax(mask, dim=1)
        out_dict['mask'] = mask
        mask = mask.unsqueeze(2)
        sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
        deformation = (sparse_motion * mask).sum(dim=1)
        deformation = deformation.permute(0, 2, 3, 1)

        out_dict['deformation'] = deformation

        # Sec. 3.2 in the paper
        if self.occlusion:
            occlusion_map = torch.sigmoid(self.occlusion(prediction))
            out_dict['occlusion_map'] = occlusion_map

        return out_dict
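The keypoints consumed here are plain dicts: 'value' holds normalised coordinates of shape (bs, num_kp, 2) and 'jacobian' (if present) local 2x2 affine maps of shape (bs, num_kp, 2, 2). A minimal shape-check sketch, assuming the dense_motion_params and common_params from config/vox-256.yaml and identity jacobians for the dummy keypoints:

import torch
from modules.dense_motion import DenseMotionNetwork

net = DenseMotionNetwork(block_expansion=64, num_blocks=5, max_features=1024,
                         num_kp=10, num_channels=3, estimate_occlusion_map=True, scale_factor=0.25)

def dummy_kp():
    return {'value': torch.rand(1, 10, 2) * 2 - 1,         # keypoint coordinates in [-1, 1]
            'jacobian': torch.eye(2).repeat(1, 10, 1, 1)}  # identity local affine per keypoint

with torch.no_grad():
    out = net(source_image=torch.randn(1, 3, 256, 256), kp_driving=dummy_kp(), kp_source=dummy_kp())
print(out['deformation'].shape, out['occlusion_map'].shape)  # (1, 64, 64, 2), (1, 1, 64, 64) at scale_factor 0.25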
modules/generator.py
ADDED
@@ -0,0 +1,211 @@
import torch
from torch import nn
import torch.nn.functional as F
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
from modules.dense_motion import DenseMotionNetwork
from modules.nerf_verts_util import RenderModel


class SPADE_layer(nn.Module):
    def __init__(self, norm_channel, label_channel):
        super(SPADE_layer, self).__init__()

        self.param_free_norm = nn.InstanceNorm2d(norm_channel, affine=False)
        hidden_channel = 128

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_channel, hidden_channel, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(hidden_channel, norm_channel, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(hidden_channel, norm_channel, kernel_size=3, padding=1)

    def forward(self, x, modulation_in):
        normalized = self.param_free_norm(x)
        modulation_in = F.interpolate(modulation_in, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(modulation_in)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)
        out = normalized * (1 + gamma) + beta
        return out


class SPADE_block(nn.Module):
    def __init__(self, norm_channel, label_channel, out_channel):
        super(SPADE_block, self).__init__()
        self.SPADE_0 = SPADE_layer(norm_channel, label_channel)
        self.relu_0 = nn.ReLU()
        self.conv_0 = nn.Conv2d(norm_channel, norm_channel, kernel_size=3, padding=1)
        self.SPADE_1 = SPADE_layer(norm_channel, label_channel)
        self.relu_1 = nn.ReLU()
        self.conv_1 = nn.Conv2d(norm_channel, out_channel, kernel_size=3, padding=1)

    def forward(self, x, modulation_in):
        out = self.SPADE_0(x, modulation_in)
        out = self.relu_0(out)
        out = self.conv_0(out)
        out = self.SPADE_1(out, modulation_in)
        out = self.relu_1(out)
        out = self.conv_1(out)
        return out


class SPADE_decoder(nn.Module):
    def __init__(self, in_channel, mid_channel):
        super(SPADE_decoder, self).__init__()
        self.in_channel = in_channel
        self.mid_channel = mid_channel
        self.seg_conv = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.SPADE_0 = SPADE_block(in_channel, mid_channel, in_channel // 4)
        self.up_0 = nn.UpsamplingBilinear2d(scale_factor=2)
        in_channel = in_channel // 4
        self.SPADE_1 = SPADE_block(in_channel, mid_channel, in_channel // 4)
        self.up_1 = nn.UpsamplingBilinear2d(scale_factor=2)
        in_channel = in_channel // 4
        self.SPADE_2 = SPADE_block(in_channel, mid_channel, in_channel)
        self.SPADE_3 = SPADE_block(in_channel, mid_channel, in_channel)
        self.final = nn.Sequential(
            nn.Conv2d(in_channel, 3, kernel_size=7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        seg = self.seg_conv(x)
        x = self.SPADE_0(x, seg)
        x = self.up_0(x)
        x = self.SPADE_1(x, seg)
        x = self.up_1(x)
        x = self.SPADE_2(x, seg)
        x = self.SPADE_3(x, seg)
        x = self.final(x)
        return x


def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(x, modulation_in):
    assert (x.size()[:2] == modulation_in.size()[:2])
    size = x.size()
    style_mean, style_std = calc_mean_std(modulation_in)
    content_mean, content_std = calc_mean_std(x)

    normalized_feat = (x - content_mean.expand(
        size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


class AdaIN_layer(nn.Module):
    def __init__(self, norm_channel, label_channel):
        super(AdaIN_layer, self).__init__()
        self.param_free_norm = nn.InstanceNorm2d(norm_channel, affine=False)

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_channel, norm_channel, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def forward(self, x, modulation_in):
        normalized = self.param_free_norm(x)
        modulation_in = self.mlp_shared(modulation_in)
        out = adaptive_instance_normalization(normalized, modulation_in)
        return out


class OcclusionAwareGenerator_SPADE(nn.Module):
    """
    Generator that, given a source image and keypoints, tries to transform the image according to the movement
    trajectories induced by the keypoints. The generator follows the Johnson architecture.
    """

    def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
                 num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, render_params=None,
                 estimate_jacobian=False):
        super(OcclusionAwareGenerator_SPADE, self).__init__()

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
                                                           estimate_occlusion_map=estimate_occlusion_map,
                                                           **dense_motion_params)
        else:
            self.dense_motion_network = None

        self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))

        down_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)

        in_features = min(max_features, block_expansion * (2 ** num_down_blocks))

        self.Render_model = RenderModel(in_channels=in_features, **render_params)
        self.decoder = SPADE_decoder(in_channel=in_features * 2, mid_channel=128)

        self.estimate_occlusion_map = estimate_occlusion_map
        self.num_channels = num_channels

    def deform_input(self, inp, deformation):
        _, h_old, w_old, _ = deformation.shape
        _, _, h, w = inp.shape
        if h_old != h or w_old != w:
            deformation = deformation.permute(0, 3, 1, 2)
            deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
            deformation = deformation.permute(0, 2, 3, 1)
        return F.grid_sample(inp, deformation)

    def forward(self, source_image, kp_driving, kp_source):
        # Encoding (downsampling) part
        out = self.first(source_image)
        for i in range(len(self.down_blocks)):
            out = self.down_blocks[i](out)

        # Transforming feature representation according to deformation and occlusion
        output_dict = {}
        if self.dense_motion_network is not None:
            dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
                                                     kp_source=kp_source)
            output_dict['mask'] = dense_motion['mask']
            output_dict['sparse_deformed'] = dense_motion['sparse_deformed']

            if 'occlusion_map' in dense_motion:
                occlusion_map = dense_motion['occlusion_map']
                output_dict['occlusion_map'] = occlusion_map
            else:
                occlusion_map = None
            deformation = dense_motion['deformation']
            out = self.deform_input(out, deformation)

            if occlusion_map is not None:
                if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
                    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
                out = out * occlusion_map

            output_dict["deformed"] = self.deform_input(source_image, deformation)

        # render part
        render_result = self.Render_model(feature=out)
        output_dict['render'] = render_result['mini_pred']
        output_dict['point_pred'] = render_result['point_pred']
        out = torch.cat((out, render_result['render']), dim=1)
        # out = self.merge_conv(out)

        # Decoding part
        out = self.decoder(out)

        output_dict["prediction"] = out

        return output_dict
modules/keypoint_detector.py
ADDED
@@ -0,0 +1,76 @@
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d


class KPDetector(nn.Module):
    """
    Detects keypoints. Returns keypoint positions and the jacobian near each keypoint.
    """

    def __init__(self, block_expansion, num_kp, num_channels, max_features,
                 num_blocks, temperature, estimate_jacobian=False, estimate_hessian=False,
                 scale_factor=1, single_jacobian_map=False, pad=0):
        super(KPDetector, self).__init__()

        self.predictor = Hourglass(block_expansion, in_features=num_channels,
                                   max_features=max_features, num_blocks=num_blocks)

        self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
                            padding=pad)

        if estimate_jacobian:
            self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
            self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
                                      out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
            self.jacobian.weight.data.zero_()
            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
        else:
            self.jacobian = None

        self.temperature = temperature
        self.scale_factor = scale_factor
        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)

    def gaussian2kp(self, heatmap):
        """
        Extract the mean from a heatmap
        """
        shape = heatmap.shape
        heatmap = heatmap.unsqueeze(-1)
        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
        value = (heatmap * grid).sum(dim=(2, 3))
        kp = {'value': value}

        return kp

    def forward(self, x):
        if self.scale_factor != 1:
            x = self.down(x)

        feature_map = self.predictor(x)
        prediction = self.kp(feature_map)

        final_shape = prediction.shape
        heatmap = prediction.view(final_shape[0], final_shape[1], -1)
        heatmap = F.softmax(heatmap / self.temperature, dim=2)
        heatmap = heatmap.view(*final_shape)

        out = self.gaussian2kp(heatmap)

        if self.jacobian is not None:
            jacobian_map = self.jacobian(feature_map)
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)

            jacobian = heatmap * jacobian_map
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1)
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
            out['jacobian'] = jacobian

        return out
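The detector returns the keypoint dict consumed by the dense-motion module above: per-keypoint 'value' coordinates and, when estimate_jacobian is set, per-keypoint 2x2 'jacobian' maps. A minimal sketch using the kp_detector_params and common_params from config/vox-256.yaml:

import torch
from modules.keypoint_detector import KPDetector

detector = KPDetector(block_expansion=32, num_kp=10, num_channels=3, max_features=1024,
                      num_blocks=5, temperature=0.1, estimate_jacobian=True, scale_factor=0.25)

with torch.no_grad():
    kp = detector(torch.randn(1, 3, 256, 256))  # dummy 256x256 RGB batch of size 1
print(kp['value'].shape, kp['jacobian'].shape)  # torch.Size([1, 10, 2]) torch.Size([1, 10, 2, 2])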
modules/nerf_verts_util.py
ADDED
@@ -0,0 +1,227 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
import einops
from modules.util import UpBlock2d, DownBlock2d


def make_coordinate_grid(spatial_size, type):
    d, h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)
    z = torch.arange(d).type(type)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)
    z = (2 * (z / (d - 1)) - 1)

    yy = y.view(1, -1, 1).repeat(d, 1, w)
    xx = x.view(1, 1, -1).repeat(d, h, 1)
    zz = z.view(-1, 1, 1).repeat(1, h, w)

    meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)

    return meshed


def kp2gaussian_3d(kp, spatial_size, kp_variance):
    """
    Transform a keypoint into a gaussian-like representation
    """
    # mean = kp['value']
    mean = kp

    coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
    number_of_leading_dimensions = len(mean.shape) - 1
    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
    coordinate_grid = coordinate_grid.view(*shape)
    repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
    coordinate_grid = coordinate_grid.repeat(*repeats)

    # Preprocess kp shape
    shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
    mean = mean.view(*shape)

    mean_sub = (coordinate_grid - mean)

    out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)

    return out


class ResBlock3d(nn.Module):
    """
    Res block, preserves spatial resolution.
    """

    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock3d, self).__init__()
        self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.norm1 = BatchNorm3d(in_features, affine=True)
        self.norm2 = BatchNorm3d(in_features, affine=True)

    def forward(self, x):
        out = self.norm1(x)
        out = F.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.relu(out)
        out = self.conv2(out)
        out += x
        return out


class rgb_predictor(nn.Module):
    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super(rgb_predictor, self).__init__()
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel, kernel_size=3, padding=1)

    def forward(self, feature):
        """
        Args:
            feature: warp feature: bs * c * h * w
        Returns:
            rgb: bs * h * w * floor_num * e
        """
        feature = self.down_conv(feature)
        feature = einops.rearrange(feature, 'b (c f) h w -> b c f h w', f=self.floor_num)
        feature = einops.rearrange(feature, 'b c f h w -> b h w f c')
        return feature


class sigma_predictor(nn.Module):
    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super(sigma_predictor, self).__init__()
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel, kernel_size=3, padding=1)

        self.res_conv3d = nn.Sequential(
            ResBlock3d(16, 3, 1),
            nn.BatchNorm3d(16),
            ResBlock3d(16, 3, 1),
            nn.BatchNorm3d(16),
            ResBlock3d(16, 3, 1),
            nn.BatchNorm3d(16)
        )

    def forward(self, feature):
        """
        Args:
            feature: bs * h * w * floor * c, the output of the rgb predictor
        Returns:
            sigma: bs * h * w * floor * encode
            point: bs * 5023 * 3
        """
        heatmap = self.down_conv(feature)
        heatmap = einops.rearrange(heatmap, "b (c f) h w -> b c f h w", f=self.floor_num)
        heatmap = self.res_conv3d(heatmap)
        sigma = einops.rearrange(heatmap, "b c f h w -> b h w f c")

        point_dict = {'sigma_map': heatmap}
        # point_pred = einops.rearrange(point_pred, 'b p n -> b n p')
        return sigma, point_dict


class MultiHeadNeRFModel(torch.nn.Module):

    def __init__(self, hidden_size=128, num_encoding_rgb=16, num_encoding_sigma=16):
        super(MultiHeadNeRFModel, self).__init__()
        # self.xyz_encoding_dims = 1 + 1 * 2 * num_encoding_functions + num_encoding_rgb
        self.xyz_encoding_dims = num_encoding_sigma
        self.viewdir_encoding_dims = num_encoding_rgb

        # Input layer (default: 16 -> 128)
        self.layer1 = torch.nn.Linear(self.xyz_encoding_dims, hidden_size)
        # Layer 2 (default: 128 -> 128)
        self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
        # Layer 3_1 (default: 128 -> 1): Predicts radiance ("sigma")
        self.layer3_1 = torch.nn.Linear(hidden_size, 1)
        # Layer 3_2 (default: 128 -> 32): Predicts a feature vector (used for color)
        self.layer3_2 = torch.nn.Linear(hidden_size, hidden_size // 4)
        self.layer3_3 = torch.nn.Linear(self.viewdir_encoding_dims, hidden_size)

        # Layer 4 (default: 32 + 128 -> 128)
        self.layer4 = torch.nn.Linear(
            hidden_size // 4 + hidden_size, hidden_size
        )
        # Layer 5 (default: 128 -> 128)
        self.layer5 = torch.nn.Linear(hidden_size, hidden_size)
        # Layer 6 (default: 128 -> 256): Predicts RGB color
        self.layer6 = torch.nn.Linear(hidden_size, 256)

        # Short hand for torch.nn.functional.relu
        self.relu = torch.nn.functional.relu

    def forward(self, rgb_in, sigma_in):
        """
        Args:
            rgb_in: rgb prediction from rgb_predictor
            sigma_in: sigma features from sigma_predictor
        Returns:
        """
        bs, h, w, floor_num, _ = rgb_in.size()
        # x = torch.cat((x, point3D), dim=-1)
        out = self.relu(self.layer1(sigma_in))
        out = self.relu(self.layer2(out))
        sigma = self.layer3_1(out)
        feat_sigma = self.relu(self.layer3_2(out))
        feat_rgb = self.relu(self.layer3_3(rgb_in))
        x = torch.cat((feat_sigma, feat_rgb), dim=-1)
        x = self.relu(self.layer4(x))
        x = self.relu(self.layer5(x))
        x = self.layer6(x)
        return x, sigma


def volume_render(rgb_pred, sigma_pred):
    """
    Args:
        rgb_pred: result of NeRF, [bs, h, w, floor, rgb_channel]
        sigma_pred: result of NeRF, [bs, h, w, floor, sigma_channel]
    Returns:

    """
    _, _, _, floor, _ = sigma_pred.size()
    c = 0
    T = 0
    for i in range(floor):
        sigma_mid = torch.nn.functional.relu(sigma_pred[:, :, :, i, :])
        T = T + (-sigma_mid)
        c = c + torch.exp(T) * (1 - torch.exp(-sigma_mid)) * rgb_pred[:, :, :, i, :]
|
197 |
+
c = einops.rearrange(c, 'b h w c -> b c h w')
|
198 |
+
return c
|
199 |
+
|
200 |
+
|
201 |
+
class RenderModel(nn.Module):
|
202 |
+
def __init__(self, in_channels, simpled_channel_rgb, simpled_channel_sigma, floor_num, hidden_size):
|
203 |
+
super(RenderModel, self).__init__()
|
204 |
+
self.rgb_predict = rgb_predictor(in_channels=in_channels, simpled_channel=simpled_channel_rgb,
|
205 |
+
floor_num=floor_num)
|
206 |
+
self.sigma_predict = sigma_predictor(in_channels=in_channels, simpled_channel=simpled_channel_sigma,
|
207 |
+
floor_num=floor_num)
|
208 |
+
num_encoding_rgb, num_encoding_sigma = simpled_channel_rgb // floor_num, simpled_channel_sigma // floor_num
|
209 |
+
self.nerf_module = MultiHeadNeRFModel(hidden_size=hidden_size, num_encoding_rgb=num_encoding_rgb,
|
210 |
+
num_encoding_sigma=num_encoding_sigma)
|
211 |
+
self.mini_decoder = nn.Sequential(
|
212 |
+
UpBlock2d(256, 64, kernel_size=3, padding=1),
|
213 |
+
nn.ReLU(),
|
214 |
+
UpBlock2d(64, 3, kernel_size=3, padding=1),
|
215 |
+
nn.Sigmoid()
|
216 |
+
)
|
217 |
+
|
218 |
+
def forward(self, feature):
|
219 |
+
rgb_in = self.rgb_predict(feature)
|
220 |
+
# sigma_in, point_dict = self.sigma_predict(feature.detach())
|
221 |
+
sigma_in, point_dict = self.sigma_predict(feature)
|
222 |
+
rgb_out, sigma_out = self.nerf_module(rgb_in, sigma_in)
|
223 |
+
render_result = volume_render(rgb_out, sigma_out)
|
224 |
+
render_result = torch.sigmoid(render_result)
|
225 |
+
mini_pred = self.mini_decoder(render_result)
|
226 |
+
out_dict = {'render': render_result, 'mini_pred': mini_pred, 'point_pred': point_dict}
|
227 |
+
return out_dict
|
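Taken together, RenderModel chains the two predictors, the NeRF-style MLP and volume_render into a single forward pass. The following is a minimal smoke test of that pass, a sketch only: the channel counts, spatial size and floor_num below are illustrative assumptions, not necessarily the trained configuration.

# Sketch: exercise RenderModel end to end with random inputs (illustrative sizes).
import torch
from modules.nerf_verts_util import RenderModel

model = RenderModel(in_channels=256, simpled_channel_rgb=128,
                    simpled_channel_sigma=128, floor_num=8, hidden_size=128)
feature = torch.randn(2, 256, 64, 64)   # warped feature map: bs * c * h * w
out = model(feature)
print(out['render'].shape)               # (2, 256, 64, 64), volume-rendered feature
print(out['mini_pred'].shape)            # (2, 3, 256, 256), low-resolution RGB preview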
modules/util.py
ADDED
@@ -0,0 +1,245 @@
from torch import nn

import torch.nn.functional as F
import torch

from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d


def kp2gaussian(kp, spatial_size, kp_variance):
    """
    Transform a keypoint into gaussian like representation
    """
    mean = kp['value']

    coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
    number_of_leading_dimensions = len(mean.shape) - 1
    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
    coordinate_grid = coordinate_grid.view(*shape)
    repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
    coordinate_grid = coordinate_grid.repeat(*repeats)

    # Preprocess kp shape
    shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
    mean = mean.view(*shape)

    mean_sub = (coordinate_grid - mean)

    out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)

    return out


def make_coordinate_grid(spatial_size, type):
    """
    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
    """
    h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)

    yy = y.view(-1, 1).repeat(1, w)
    xx = x.view(1, -1).repeat(h, 1)

    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)

    return meshed


class ResBlock2d(nn.Module):
    """
    Res block, preserve spatial resolution.
    """

    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.norm1 = BatchNorm2d(in_features, affine=True)
        self.norm2 = BatchNorm2d(in_features, affine=True)

    def forward(self, x):
        out = self.norm1(x)
        out = F.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.relu(out)
        out = self.conv2(out)
        out += x
        return out


class UpBlock2d(nn.Module):
    """
    Upsampling block for use in decoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(UpBlock2d, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = BatchNorm2d(out_features, affine=True)

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2)
        out = self.conv(out)
        out = self.norm(out)
        out = F.relu(out)
        return out


class DownBlock2d(nn.Module):
    """
    Downsampling block for use in encoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = BatchNorm2d(out_features, affine=True)
        self.pool = nn.AvgPool2d(kernel_size=(2, 2))

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        out = self.pool(out)
        return out


class SameBlock2d(nn.Module):
    """
    Simple block, preserve spatial resolution.
    """

    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
        super(SameBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, padding=padding, groups=groups)
        self.norm = BatchNorm2d(out_features, affine=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        return out


class Encoder(nn.Module):
    """
    Hourglass Encoder
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Encoder, self).__init__()

        down_blocks = []
        for i in range(num_blocks):
            down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
                                           min(max_features, block_expansion * (2 ** (i + 1))),
                                           kernel_size=3, padding=1))
        self.down_blocks = nn.ModuleList(down_blocks)

    def forward(self, x):
        outs = [x]
        for down_block in self.down_blocks:
            outs.append(down_block(outs[-1]))
        return outs


class Decoder(nn.Module):
    """
    Hourglass Decoder
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Decoder, self).__init__()

        up_blocks = []

        for i in range(num_blocks)[::-1]:
            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
            out_filters = min(max_features, block_expansion * (2 ** i))
            up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))

        self.up_blocks = nn.ModuleList(up_blocks)
        self.out_filters = block_expansion + in_features

    def forward(self, x):
        out = x.pop()
        for up_block in self.up_blocks:
            out = up_block(out)
            skip = x.pop()
            out = torch.cat([out, skip], dim=1)
        return out


class Hourglass(nn.Module):
    """
    Hourglass architecture.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Hourglass, self).__init__()
        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
        self.out_filters = self.decoder.out_filters

    def forward(self, x):
        return self.decoder(self.encoder(x))


class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.
    """
    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        sigma = (1 / scale - 1) / 2
        kernel_size = 2 * round(sigma * 4) + 1
        self.ka = kernel_size // 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka

        kernel_size = [kernel_size, kernel_size]
        sigma = [sigma, sigma]
        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale
        inv_scale = 1 / scale
        self.int_inv_scale = int(inv_scale)

    def forward(self, input):
        if self.scale == 1.0:
            return input

        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]

        return out
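These are the generic hourglass-style building blocks the generator and keypoint detector are assembled from. A quick shape check, with illustrative sizes only (the real channel counts are taken from the config, not from this sketch):

# Sketch: shape sanity check for the utility blocks (sizes are illustrative).
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, kp2gaussian

hourglass = Hourglass(block_expansion=32, in_features=44, num_blocks=3, max_features=256)
x = torch.randn(2, 44, 64, 64)
print(hourglass(x).shape)        # (2, 76, 64, 64): out_filters = block_expansion + in_features

down = AntiAliasInterpolation2d(channels=3, scale=0.25)
img = torch.rand(2, 3, 256, 256)
print(down(img).shape)           # (2, 3, 64, 64): Gaussian blur, then strided subsampling

kp = {'value': torch.zeros(2, 10, 2)}   # keypoints in [-1, 1] coordinates
heat = kp2gaussian(kp, spatial_size=(64, 64), kp_variance=0.01)
print(heat.shape)                # (2, 10, 64, 64): one Gaussian heatmap per keypoint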
requirements.txt
ADDED
@@ -0,0 +1,93 @@
absl-py==1.0.0
aiohttp==3.7.4.post0
astropy==5.0.4
astunparse==1.6.3
async-timeout==3.0.1
attrs==21.4.0
blessings==1.7
blinker==1.4
brotlipy==0.7.0
cachetools==5.0.0
certifi==2021.10.8
cffi==1.15.0
chardet==4.0.0
charset-normalizer==2.0.12
chumpy==0.70
colorama==0.4.4
cryptography==35.0.0
cycler==0.11.0
Cython==0.29.28
dill==0.3.5.1
dominate==2.6.0
einops==0.4.1
face-alignment==1.3.5
flatbuffers==2.0
fonttools==4.33.3
fvcore==0.1.5.post20210915
gast==0.5.3
google-auth==2.6.6
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.46.0
h5py==3.6.0
idna==3.3
imageio==2.19.1
imageio-ffmpeg==0.4.7
importlib-metadata==4.11.3
iopath==0.1.9
joblib==1.1.0
keras==2.8.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.2
libclang==14.0.1
llvmlite==0.38.0
matplotlib==3.5.2
mkl-fft==1.3.0
mkl-random==1.2.2
mkl-service==2.4.0
multidict==5.2.0
multiprocess==0.70.12.2
networkx==2.8
numba
nvidia-ml-py3==7.352.0
oauthlib==3.2.0
onnx==1.11.0
opencv-python==4.5.5.64
opt-einsum==3.3.0
packaging==21.3
pandas==1.4.2
Pillow==9.0.1
portalocker==2.4.0
protobuf==3.20.1
psutil==5.9.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pyerfa==2.0.0.1
PyJWT==2.3.0
pyOpenSSL==22.0.0
pyparsing==3.0.8
PySocks==1.7.1
python-dateutil==2.8.2
pytz==2022.1
pyu2f==0.1.5
PyWavelets==1.3.0
PyYAML==5.4.1
requests==2.27.1
requests-oauthlib==1.3.1
requests-toolbelt==0.9.1
rsa==4.8
scikit-image==0.17.2
scikit-learn==1.0.2
scipy==1.8.0
six==1.16.0
tabulate==0.8.9
threadpoolctl==3.1.0
tifffile==2022.5.4
tqdm==4.64.0
typing_extensions==4.1.1
urllib3==1.26.9
Werkzeug==2.1.2
wrapt==1.14.1
yacs==0.1.8
yarl==1.6.3
sup-mat/driving.mp4
ADDED
Binary file (101 kB)
sup-mat/driving.png
ADDED
Image file
sup-mat/source.png
ADDED
Image file
sup-mat/source_for_video.png
ADDED
Image file
sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback
sync_batchnorm/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (371 Bytes)
sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc
ADDED
Binary file (12.9 kB)
sync_batchnorm/__pycache__/comm.cpython-38.pyc
ADDED
Binary file (4.81 kB)
sync_batchnorm/__pycache__/replicate.cpython-38.pyc
ADDED
Binary file (3.46 kB)
sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,315 @@
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
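The three classes only differ in the input-rank check; the synchronization itself lives in _SynchronizedBatchNorm. A short sketch of the intended drop-in usage together with DataParallelWithCallback follows; the device ids are illustrative, and on CPU or a single GPU the layer simply falls back to plain F.batch_norm.

# Sketch: SynchronizedBatchNorm2d as a drop-in replacement for nn.BatchNorm2d.
import torch
from torch import nn
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    SynchronizedBatchNorm2d(16),
    nn.ReLU(),
)
x = torch.rand(4, 3, 64, 64)
if torch.cuda.device_count() > 1:
    # replicate() fires __data_parallel_replicate__ on every copy, wiring them to one
    # SyncMaster, so batch statistics are reduced over all GPUs instead of per device.
    net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])
    x = x.cuda()
print(net(x).shape)  # torch.Size([4, 16, 64, 64])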
sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
    call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
    passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback will be invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
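SyncMaster and SlavePipe are the thread-safe rendezvous used by the batch-norm replicas above: every slave posts its message, the master's callback computes one answer per copy, and each copy receives its answer through a FutureResult. A toy, self-contained illustration of that handshake using plain Python threads in place of GPU replicas (the summing callback and the values are purely illustrative):

# Sketch: SyncMaster handshake with threads standing in for device replicas.
import threading
from sync_batchnorm.comm import SyncMaster

# The callback sees [(identifier, msg), ...] and returns one (identifier, reply) per entry.
master = SyncMaster(lambda msgs: [(i, sum(v for _, v in msgs)) for i, _ in msgs])
pipes = [master.register_slave(i) for i in (1, 2)]   # two "slave" copies

results = {}

def slave(idx, pipe, value):
    results[idx] = pipe.run_slave(value)             # blocks until the master replies

threads = [threading.Thread(target=slave, args=(1, pipes[0], 2)),
           threading.Thread(target=slave, args=(2, pipes[1], 3))]
for t in threads:
    t.start()
results[0] = master.run_master(1)                    # master contributes 1
for t in threads:
    t.join()
print(results)                                        # every copy receives 1 + 2 + 3 = 6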
sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest

import numpy as np
from torch.autograd import Variable


def as_numpy(v):
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        npa, npb = as_numpy(a), as_numpy(b)
        self.assertTrue(
            np.allclose(npa, npb, atol=atol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
        )
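A hypothetical test using the helper base class above (not part of this commit, just a sketch of how it is meant to be used):

# Sketch: assertTensorClose compares two tensors up to the given absolute tolerance.
import torch
import unittest
from sync_batchnorm.unittest import TorchTestCase

class DummyCloseTest(TorchTestCase):
    def test_close(self):
        a = torch.rand(4, 3)
        self.assertTensorClose(a, a + 1e-5)   # well inside the default atol of 1e-3

if __name__ == '__main__':
    unittest.main()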