Spaces · Runtime error
AlekseyKorshuk committed
Commit f844f44 · Parent(s): 7ea1f4a
feat: updates
- LICENSE +21 -0
- app.py +100 -0
- assets/driving.mp4 +0 -0
- assets/driving_ted.mp4 +0 -0
- assets/source.png +0 -0
- assets/source_ted.png +0 -0
- augmentation.py +344 -0
- checkpoints/README.md +1 -0
- config/mgif-256.yaml +75 -0
- config/taichi-256.yaml +134 -0
- config/ted-384.yaml +73 -0
- config/vox-256.yaml +74 -0
- demo.ipynb +0 -0
- demo.py +176 -0
- frames_dataset.py +173 -0
- generated.mp4 +0 -0
- logger.py +212 -0
- modules/avd_network.py +65 -0
- modules/bg_motion_predictor.py +24 -0
- modules/dense_motion.py +164 -0
- modules/inpainting_network.py +130 -0
- modules/keypoint_detector.py +27 -0
- modules/model.py +182 -0
- modules/util.py +349 -0
- reconstruction.py +69 -0
- requirements.txt +26 -0
- run.py +89 -0
- train.py +94 -0
- train_avd.py +91 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 yoyo-nb
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
app.py
ADDED
@@ -0,0 +1,100 @@
+import torch
+import imageio
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+from skimage.transform import resize
+import warnings
+import os
+from demo import make_animation
+from skimage import img_as_ubyte
+from demo import load_checkpoints
+import gradio
+
+
+def inference(source_image_path='./assets/source.png', driving_video_path='./assets/driving.mp4', dataset_name="vox"):
+    # edit the config
+    device = torch.device('cpu')
+    # dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']
+    # source_image_path = './assets/source.png'
+    # driving_video_path = './assets/driving.mp4'
+    output_video_path = './generated.mp4'
+
+    pixel = 256  # for vox, taichi and mgif, the resolution is 256*256
+    if (dataset_name == 'ted'):  # for ted, the resolution is 384*384
+        pixel = 384
+    config_path = f'config/{dataset_name}-{pixel}.yaml'
+    checkpoint_path = f'checkpoints/{dataset_name}.pth.tar'
+    predict_mode = 'relative'  # ['standard', 'relative', 'avd']
+
+    warnings.filterwarnings("ignore")
+
+    source_image = imageio.imread(source_image_path)
+    reader = imageio.get_reader(driving_video_path)
+
+    source_image = resize(source_image, (pixel, pixel))[..., :3]
+
+    fps = reader.get_meta_data()['fps']
+    driving_video = []
+    try:
+        for im in reader:
+            driving_video.append(im)
+    except RuntimeError:
+        pass
+    reader.close()
+
+    driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]
+
+    # driving_video = driving_video[:10]
+
+    def display(source, driving, generated=None) -> animation.ArtistAnimation:
+        fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))
+
+        ims = []
+        for i in range(len(driving)):
+            cols = [source]
+            cols.append(driving[i])
+            if generated is not None:
+                cols.append(generated[i])
+            im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
+            plt.axis('off')
+            ims.append([im])
+
+        ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
+        # plt.show()
+        plt.close()
+        return ani
+
+    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path=config_path,
+                                                                                  checkpoint_path=checkpoint_path,
+                                                                                  device=device)
+
+    predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network,
+                                 avd_network, device=device, mode=predict_mode)
+
+    # save resulting video
+    imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)
+
+    ani = display(source_image, driving_video, predictions)
+    ani.save('animation.mp4', writer='imagemagick', fps=60)
+    return 'animation.mp4'
+
+
+demo = gradio.Interface(
+    fn=inference,
+    inputs=[
+        gradio.inputs.Image(type="filepath", label="Input image"),
+        gradio.inputs.Video(label="Input video"),
+        gradio.inputs.Dropdown(['vox', 'taichi', 'ted', 'mgif'], type="value", default="vox", label="Model",
+                               optional=False),
+
+    ],
+    outputs=["video"],
+    examples=[
+        ['./assets/source.png', './assets/driving.mp4', "vox"],
+        ['./assets/source_ted.png', './assets/driving_ted.mp4', "ted"],
+    ],
+)
+
+if __name__ == "__main__":
+    demo.launch()
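The Gradio interface above wraps a single inference function, so the Space can also be exercised without the web UI. A minimal sketch of a local driver, assuming the repository files and their dependencies are installed, that app.py imports cleanly in this environment, and that a checkpoints/vox.pth.tar checkpoint has been downloaded (this driver script is illustrative, not part of the commit):

# Illustrative local driver (not part of this commit): call inference() directly.
# Assumes assets/source.png, assets/driving.mp4 and checkpoints/vox.pth.tar exist.
from app import inference

result_path = inference(
    source_image_path='./assets/source.png',
    driving_video_path='./assets/driving.mp4',
    dataset_name='vox',  # one of 'vox', 'taichi', 'ted', 'mgif'
)
print('Animation written to', result_path)  # inference() returns 'animation.mp4'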
assets/driving.mp4
ADDED
Binary file (556 kB)
assets/driving_ted.mp4
ADDED
Binary file (206 kB)
assets/source.png
ADDED
assets/source_ted.png
ADDED
augmentation.py
ADDED
@@ -0,0 +1,344 @@
+"""
+Code from https://github.com/hassony2/torch_videovision
+"""
+
+import numbers
+
+import random
+import numpy as np
+import PIL
+
+from skimage.transform import resize, rotate
+import torchvision
+
+import warnings
+
+from skimage import img_as_ubyte, img_as_float
+
+
+def crop_clip(clip, min_h, min_w, h, w):
+    if isinstance(clip[0], np.ndarray):
+        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
+
+    elif isinstance(clip[0], PIL.Image.Image):
+        cropped = [
+            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
+        ]
+    else:
+        raise TypeError('Expected numpy.ndarray or PIL.Image' +
+                        ' but got list of {0}'.format(type(clip[0])))
+    return cropped
+
+
+def pad_clip(clip, h, w):
+    im_h, im_w = clip[0].shape[:2]
+    pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
+    pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
+
+    return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
+
+
+def resize_clip(clip, size, interpolation='bilinear'):
+    if isinstance(clip[0], np.ndarray):
+        if isinstance(size, numbers.Number):
+            im_h, im_w, im_c = clip[0].shape
+            # Min spatial dim already matches minimal size
+            if (im_w <= im_h and im_w == size) or (im_h <= im_w
+                                                   and im_h == size):
+                return clip
+            new_h, new_w = get_resize_sizes(im_h, im_w, size)
+            size = (new_w, new_h)
+        else:
+            size = size[1], size[0]
+
+        scaled = [
+            resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
+                   mode='constant', anti_aliasing=True) for img in clip
+        ]
+    elif isinstance(clip[0], PIL.Image.Image):
+        if isinstance(size, numbers.Number):
+            im_w, im_h = clip[0].size
+            # Min spatial dim already matches minimal size
+            if (im_w <= im_h and im_w == size) or (im_h <= im_w
+                                                   and im_h == size):
+                return clip
+            new_h, new_w = get_resize_sizes(im_h, im_w, size)
+            size = (new_w, new_h)
+        else:
+            size = size[1], size[0]
+        if interpolation == 'bilinear':
+            pil_inter = PIL.Image.NEAREST
+        else:
+            pil_inter = PIL.Image.BILINEAR
+        scaled = [img.resize(size, pil_inter) for img in clip]
+    else:
+        raise TypeError('Expected numpy.ndarray or PIL.Image' +
+                        ' but got list of {0}'.format(type(clip[0])))
+    return scaled
+
+
+def get_resize_sizes(im_h, im_w, size):
+    if im_w < im_h:
+        ow = size
+        oh = int(size * im_h / im_w)
+    else:
+        oh = size
+        ow = int(size * im_w / im_h)
+    return oh, ow
+
+
+class RandomFlip(object):
+    def __init__(self, time_flip=False, horizontal_flip=False):
+        self.time_flip = time_flip
+        self.horizontal_flip = horizontal_flip
+
+    def __call__(self, clip):
+        if random.random() < 0.5 and self.time_flip:
+            return clip[::-1]
+        if random.random() < 0.5 and self.horizontal_flip:
+            return [np.fliplr(img) for img in clip]
+
+        return clip
+
+
+class RandomResize(object):
+    """Resizes a list of (H x W x C) numpy.ndarray to the final size
+    The larger the original image is, the more times it takes to
+    interpolate
+    Args:
+        interpolation (str): Can be one of 'nearest', 'bilinear'
+        defaults to nearest
+        size (tuple): (width, height)
+    """
+
+    def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
+        self.ratio = ratio
+        self.interpolation = interpolation
+
+    def __call__(self, clip):
+        scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
+
+        if isinstance(clip[0], np.ndarray):
+            im_h, im_w, im_c = clip[0].shape
+        elif isinstance(clip[0], PIL.Image.Image):
+            im_w, im_h = clip[0].size
+
+        new_w = int(im_w * scaling_factor)
+        new_h = int(im_h * scaling_factor)
+        new_size = (new_w, new_h)
+        resized = resize_clip(
+            clip, new_size, interpolation=self.interpolation)
+
+        return resized
+
+
+class RandomCrop(object):
+    """Extract random crop at the same location for a list of videos
+    Args:
+        size (sequence or int): Desired output size for the
+        crop in format (h, w)
+    """
+
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            size = (size, size)
+
+        self.size = size
+
+    def __call__(self, clip):
+        """
+        Args:
+            img (PIL.Image or numpy.ndarray): List of videos to be cropped
+            in format (h, w, c) in numpy.ndarray
+        Returns:
+            PIL.Image or numpy.ndarray: Cropped list of videos
+        """
+        h, w = self.size
+        if isinstance(clip[0], np.ndarray):
+            im_h, im_w, im_c = clip[0].shape
+        elif isinstance(clip[0], PIL.Image.Image):
+            im_w, im_h = clip[0].size
+        else:
+            raise TypeError('Expected numpy.ndarray or PIL.Image' +
+                            ' but got list of {0}'.format(type(clip[0])))
+
+        clip = pad_clip(clip, h, w)
+        im_h, im_w = clip.shape[1:3]
+        x1 = 0 if h == im_h else random.randint(0, im_w - w)
+        y1 = 0 if w == im_w else random.randint(0, im_h - h)
+        cropped = crop_clip(clip, y1, x1, h, w)
+
+        return cropped
+
+
+class RandomRotation(object):
+    """Rotate entire clip randomly by a random angle within
+    given bounds
+    Args:
+        degrees (sequence or int): Range of degrees to select from
+        If degrees is a number instead of sequence like (min, max),
+        the range of degrees will be (-degrees, +degrees).
+    """
+
+    def __init__(self, degrees):
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError('If degrees is a single number,'
+                                 'must be positive')
+            degrees = (-degrees, degrees)
+        else:
+            if len(degrees) != 2:
+                raise ValueError('If degrees is a sequence,'
+                                 'it must be of len 2.')
+
+        self.degrees = degrees
+
+    def __call__(self, clip):
+        """
+        Args:
+            img (PIL.Image or numpy.ndarray): List of videos to be cropped
+            in format (h, w, c) in numpy.ndarray
+        Returns:
+            PIL.Image or numpy.ndarray: Cropped list of videos
+        """
+        angle = random.uniform(self.degrees[0], self.degrees[1])
+        if isinstance(clip[0], np.ndarray):
+            rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
+        elif isinstance(clip[0], PIL.Image.Image):
+            rotated = [img.rotate(angle) for img in clip]
+        else:
+            raise TypeError('Expected numpy.ndarray or PIL.Image' +
+                            ' but got list of {0}'.format(type(clip[0])))
+
+        return rotated
+
+
+class ColorJitter(object):
+    """Randomly change the brightness, contrast, saturation and hue of the clip
+    Args:
+        brightness (float): How much to jitter brightness. brightness_factor
+        is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+        contrast (float): How much to jitter contrast. contrast_factor
+        is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+        saturation (float): How much to jitter saturation. saturation_factor
+        is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
+        [-hue, hue]. Should be >=0 and <= 0.5.
+    """
+
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+        self.brightness = brightness
+        self.contrast = contrast
+        self.saturation = saturation
+        self.hue = hue
+
+    def get_params(self, brightness, contrast, saturation, hue):
+        if brightness > 0:
+            brightness_factor = random.uniform(
+                max(0, 1 - brightness), 1 + brightness)
+        else:
+            brightness_factor = None
+
+        if contrast > 0:
+            contrast_factor = random.uniform(
+                max(0, 1 - contrast), 1 + contrast)
+        else:
+            contrast_factor = None
+
+        if saturation > 0:
+            saturation_factor = random.uniform(
+                max(0, 1 - saturation), 1 + saturation)
+        else:
+            saturation_factor = None
+
+        if hue > 0:
+            hue_factor = random.uniform(-hue, hue)
+        else:
+            hue_factor = None
+        return brightness_factor, contrast_factor, saturation_factor, hue_factor
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (list): list of PIL.Image
+        Returns:
+            list PIL.Image : list of transformed PIL.Image
+        """
+        if isinstance(clip[0], np.ndarray):
+            brightness, contrast, saturation, hue = self.get_params(
+                self.brightness, self.contrast, self.saturation, self.hue)
+
+            # Create img transform function sequence
+            img_transforms = []
+            if brightness is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+            if saturation is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+            if hue is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+            if contrast is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+            random.shuffle(img_transforms)
+            img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
+                                                                                                     img_as_float]
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                jittered_clip = []
+                for img in clip:
+                    jittered_img = img
+                    for func in img_transforms:
+                        jittered_img = func(jittered_img)
+                    jittered_clip.append(jittered_img.astype('float32'))
+        elif isinstance(clip[0], PIL.Image.Image):
+            brightness, contrast, saturation, hue = self.get_params(
+                self.brightness, self.contrast, self.saturation, self.hue)
+
+            # Create img transform function sequence
+            img_transforms = []
+            if brightness is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
+            if saturation is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
+            if hue is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
+            if contrast is not None:
+                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
+            random.shuffle(img_transforms)
+
+            # Apply to all videos
+            jittered_clip = []
+            for img in clip:
+                for func in img_transforms:
+                    jittered_img = func(img)
+                jittered_clip.append(jittered_img)
+
+        else:
+            raise TypeError('Expected numpy.ndarray or PIL.Image' +
+                            ' but got list of {0}'.format(type(clip[0])))
+        return jittered_clip
+
+
+class AllAugmentationTransform:
+    def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
+        self.transforms = []
+
+        if flip_param is not None:
+            self.transforms.append(RandomFlip(**flip_param))
+
+        if rotation_param is not None:
+            self.transforms.append(RandomRotation(**rotation_param))
+
+        if resize_param is not None:
+            self.transforms.append(RandomResize(**resize_param))
+
+        if crop_param is not None:
+            self.transforms.append(RandomCrop(**crop_param))
+
+        if jitter_param is not None:
+            self.transforms.append(ColorJitter(**jitter_param))
+
+    def __call__(self, clip):
+        for t in self.transforms:
+            clip = t(clip)
+        return clip
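As a quick sanity check of the pipeline defined above, AllAugmentationTransform can be applied to a synthetic clip. A minimal sketch, with parameter values mirroring the augmentation_params of config/mgif-256.yaml below (the random frames are placeholders, not real data):

# Minimal sketch: run the augmentation pipeline on a synthetic two-frame clip.
# A clip is a list of float HxWxC frames; parameters mirror config/mgif-256.yaml.
import numpy as np
from augmentation import AllAugmentationTransform

transform = AllAugmentationTransform(
    flip_param={'horizontal_flip': True, 'time_flip': True},
    crop_param={'size': [256, 256]},
    resize_param={'ratio': [0.9, 1.1]},
    jitter_param={'hue': 0.5},
)

clip = [np.random.rand(256, 256, 3).astype('float32') for _ in range(2)]
augmented = transform(clip)
print(len(augmented), augmented[0].shape)  # 2 frames, each 256x256x3 after resize/crop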
checkpoints/README.md
ADDED
@@ -0,0 +1 @@
+# Checkpoints
config/mgif-256.yaml
ADDED
@@ -0,0 +1,75 @@
+dataset_params:
+  root_dir: ../moving-gif
+  frame_shape: null
+  id_sampling: False
+  augmentation_params:
+    flip_param:
+      horizontal_flip: True
+      time_flip: True
+    crop_param:
+      size: [256, 256]
+    resize_param:
+      ratio: [0.9, 1.1]
+    jitter_param:
+      hue: 0.5
+
+model_params:
+  common_params:
+    num_tps: 10
+    num_channels: 3
+    bg: False
+    multi_mask: True
+  generator_params:
+    block_expansion: 64
+    max_features: 512
+    num_down_blocks: 3
+  dense_motion_params:
+    block_expansion: 64
+    max_features: 1024
+    num_blocks: 5
+    scale_factor: 0.25
+  avd_network_params:
+    id_bottle_size: 128
+    pose_bottle_size: 128
+
+
+train_params:
+  num_epochs: 100
+  num_repeats: 50
+  epoch_milestones: [70, 90]
+  lr_generator: 2.0e-4
+  batch_size: 28
+  scales: [1, 0.5, 0.25, 0.125]
+  dataloader_workers: 12
+  checkpoint_freq: 50
+  dropout_epoch: 35
+  dropout_maxp: 0.5
+  dropout_startp: 0.2
+  dropout_inc_epoch: 10
+  bg_start: 0
+  transform_params:
+    sigma_affine: 0.05
+    sigma_tps: 0.005
+    points_tps: 5
+  loss_weights:
+    perceptual: [10, 10, 10, 10, 10]
+    equivariance_value: 10
+    warp_loss: 10
+    bg: 10
+
+train_avd_params:
+  num_epochs: 100
+  num_repeats: 50
+  batch_size: 256
+  dataloader_workers: 24
+  checkpoint_freq: 10
+  epoch_milestones: [70, 90]
+  lr: 1.0e-3
+  lambda_shift: 1
+  lambda_affine: 1
+  random_scale: 0.25
+
+visualizer_params:
+  kp_size: 5
+  draw_border: True
+  colormap: 'gist_rainbow'
config/taichi-256.yaml
ADDED
@@ -0,0 +1,134 @@
+# Dataset parameters
+# Each dataset should contain 2 folders: train and test
+# Each video can be represented as:
+# - an image of concatenated frames
+# - '.mp4' or '.gif'
+# - a folder with all frames from a specific video
+# In the case of TaiChi, the same (YouTube) video can be split into many parts (chunks). Each part has the following
+# format: (id)#other#info.mp4. For example, '12335#adsbf.mp4' has the id 12335. For TaiChi, the id is the YouTube
+# video id.
+dataset_params:
+  # Path to data; data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
+  root_dir: ../taichi
+  # Image shape, needed for the stacked .png format.
+  frame_shape: null
+  # In the case of TaiChi a single video can be split into many chunks, or there may be several videos for a single person.
+  # In this case an epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False).
+  # If the name of the video is '12335#adsbf.mp4', the id is assumed to be 12335.
+  id_sampling: True
+  # Augmentation parameters; see augmentation.py for all possible augmentations
+  augmentation_params:
+    flip_param:
+      horizontal_flip: True
+      time_flip: True
+    jitter_param:
+      brightness: 0.1
+      contrast: 0.1
+      saturation: 0.1
+      hue: 0.1
+
+# Defines model architecture
+model_params:
+  common_params:
+    # Number of TPS transformations
+    num_tps: 10
+    # Number of channels per image
+    num_channels: 3
+    # Whether to estimate affine background transformation
+    bg: True
+    # Whether to estimate the multi-resolution occlusion masks
+    multi_mask: True
+  generator_params:
+    # Number of features multiplier
+    block_expansion: 64
+    # Maximum allowed number of features
+    max_features: 512
+    # Number of downsampling and upsampling blocks.
+    num_down_blocks: 3
+  dense_motion_params:
+    # Number of features multiplier
+    block_expansion: 64
+    # Maximum allowed number of features
+    max_features: 1024
+    # Number of blocks in the U-Net.
+    num_blocks: 5
+    # Optical flow is predicted on smaller images for better performance,
+    # scale_factor=0.25 means that a 256x256 image will be resized to 64x64
+    scale_factor: 0.25
+  avd_network_params:
+    # Bottleneck for identity branch
+    id_bottle_size: 128
+    # Bottleneck for pose branch
+    pose_bottle_size: 128
+
+# Parameters of training
+train_params:
+  # Number of training epochs
+  num_epochs: 100
+  # For better i/o performance when the number of videos is small, the number of epochs can be multiplied by this number.
+  # Thus effectively with num_repeats=100 each epoch is 100 times larger.
+  num_repeats: 150
+  # Drop the learning rate by 10x after these epochs
+  epoch_milestones: [70, 90]
+  # Initial learning rate for all modules
+  lr_generator: 2.0e-4
+  batch_size: 28
+  # Scales for the perceptual pyramid loss. If scales = [1, 0.5, 0.25, 0.125] and the image resolution is 256x256,
+  # then the loss is computed at resolutions 256x256, 128x128, 64x64, 32x32.
+  scales: [1, 0.5, 0.25, 0.125]
+  # Dataset preprocessing CPU workers
+  dataloader_workers: 12
+  # Save checkpoints this frequently. If checkpoint_freq=50, a checkpoint is saved every 50 epochs.
+  checkpoint_freq: 50
+  # Parameters of dropout
+  # The dropout operation is used for the first dropout_epoch epochs of training
+  dropout_epoch: 35
+  # The probability P increases linearly from dropout_startp to dropout_maxp over dropout_inc_epoch epochs
+  dropout_maxp: 0.7
+  dropout_startp: 0.0
+  dropout_inc_epoch: 10
+  # Estimate the affine background transformation from the bg_start epoch.
+  bg_start: 0
+  # Parameters of random TPS transformation for equivariance loss
+  transform_params:
+    # Sigma for affine part
+    sigma_affine: 0.05
+    # Sigma for deformation part
+    sigma_tps: 0.005
+    # Number of points in the deformation grid
+    points_tps: 5
+  loss_weights:
+    # Weights for perceptual loss.
+    perceptual: [10, 10, 10, 10, 10]
+    # Weights for value equivariance.
+    equivariance_value: 10
+    # Weights for warp loss.
+    warp_loss: 10
+    # Weights for bg loss.
+    bg: 10
+
+# Parameters of training (animation-via-disentanglement)
+train_avd_params:
+  # Number of training epochs; visualization is produced after each epoch.
+  num_epochs: 100
+  # For better i/o performance when the number of videos is small, the number of epochs can be multiplied by this number.
+  # Thus effectively with num_repeats=100 each epoch is 100 times larger.
+  num_repeats: 150
+  # Batch size.
+  batch_size: 256
+  # Save checkpoints this frequently. If checkpoint_freq=50, a checkpoint is saved every 50 epochs.
+  checkpoint_freq: 10
+  # Dataset preprocessing CPU workers
+  dataloader_workers: 24
+  # Drop the learning rate by 10x after these epochs
+  epoch_milestones: [70, 90]
+  # Initial learning rate
+  lr: 1.0e-3
+  # Weights for equivariance loss.
+  lambda_shift: 1
+  random_scale: 0.25
+
+visualizer_params:
+  kp_size: 5
+  draw_border: True
+  colormap: 'gist_rainbow'
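The model_params sections of these configs are unpacked directly into the module constructors; this is exactly what load_checkpoints in demo.py (further down in this commit) does. A minimal sketch of reading one config and building the networks, assuming the modules/ package added in this commit and its dependencies are importable:

# Minimal sketch: each model_params sub-dict becomes constructor keyword arguments,
# mirroring load_checkpoints() in demo.py.
import yaml
from modules.inpainting_network import InpaintingNetwork
from modules.keypoint_detector import KPDetector
from modules.dense_motion import DenseMotionNetwork
from modules.avd_network import AVDNetwork

with open('config/taichi-256.yaml') as f:
    config = yaml.safe_load(f)

model_params = config['model_params']
common = model_params['common_params']

inpainting = InpaintingNetwork(**model_params['generator_params'], **common)
kp_detector = KPDetector(**common)
dense_motion = DenseMotionNetwork(**common, **model_params['dense_motion_params'])
avd_network = AVDNetwork(num_tps=common['num_tps'], **model_params['avd_network_params'])
print(common['num_tps'], 'TPS transforms,', common['num_channels'], 'channels')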
config/ted-384.yaml
ADDED
@@ -0,0 +1,73 @@
+dataset_params:
+  root_dir: ../TED384-v2
+  frame_shape: null
+  id_sampling: True
+  augmentation_params:
+    flip_param:
+      horizontal_flip: True
+      time_flip: True
+    jitter_param:
+      brightness: 0.1
+      contrast: 0.1
+      saturation: 0.1
+      hue: 0.1
+
+model_params:
+  common_params:
+    num_tps: 10
+    num_channels: 3
+    bg: True
+    multi_mask: True
+  generator_params:
+    block_expansion: 64
+    max_features: 512
+    num_down_blocks: 3
+  dense_motion_params:
+    block_expansion: 64
+    max_features: 1024
+    num_blocks: 5
+    scale_factor: 0.25
+  avd_network_params:
+    id_bottle_size: 128
+    pose_bottle_size: 128
+
+
+train_params:
+  num_epochs: 100
+  num_repeats: 150
+  epoch_milestones: [70, 90]
+  lr_generator: 2.0e-4
+  batch_size: 12
+  scales: [1, 0.5, 0.25, 0.125]
+  dataloader_workers: 6
+  checkpoint_freq: 50
+  dropout_epoch: 35
+  dropout_maxp: 0.5
+  dropout_startp: 0.0
+  dropout_inc_epoch: 10
+  bg_start: 0
+  transform_params:
+    sigma_affine: 0.05
+    sigma_tps: 0.005
+    points_tps: 5
+  loss_weights:
+    perceptual: [10, 10, 10, 10, 10]
+    equivariance_value: 10
+    warp_loss: 10
+    bg: 10
+
+train_avd_params:
+  num_epochs: 30
+  num_repeats: 500
+  batch_size: 256
+  dataloader_workers: 24
+  checkpoint_freq: 10
+  epoch_milestones: [20, 25]
+  lr: 1.0e-3
+  lambda_shift: 1
+  random_scale: 0.25
+
+visualizer_params:
+  kp_size: 5
+  draw_border: True
+  colormap: 'gist_rainbow'
config/vox-256.yaml
ADDED
@@ -0,0 +1,74 @@
+dataset_params:
+  root_dir: ../vox
+  frame_shape: null
+  id_sampling: True
+  augmentation_params:
+    flip_param:
+      horizontal_flip: True
+      time_flip: True
+    jitter_param:
+      brightness: 0.1
+      contrast: 0.1
+      saturation: 0.1
+      hue: 0.1
+
+
+model_params:
+  common_params:
+    num_tps: 10
+    num_channels: 3
+    bg: True
+    multi_mask: True
+  generator_params:
+    block_expansion: 64
+    max_features: 512
+    num_down_blocks: 3
+  dense_motion_params:
+    block_expansion: 64
+    max_features: 1024
+    num_blocks: 5
+    scale_factor: 0.25
+  avd_network_params:
+    id_bottle_size: 128
+    pose_bottle_size: 128
+
+
+train_params:
+  num_epochs: 100
+  num_repeats: 75
+  epoch_milestones: [70, 90]
+  lr_generator: 2.0e-4
+  batch_size: 28
+  scales: [1, 0.5, 0.25, 0.125]
+  dataloader_workers: 12
+  checkpoint_freq: 50
+  dropout_epoch: 35
+  dropout_maxp: 0.3
+  dropout_startp: 0.1
+  dropout_inc_epoch: 10
+  bg_start: 10
+  transform_params:
+    sigma_affine: 0.05
+    sigma_tps: 0.005
+    points_tps: 5
+  loss_weights:
+    perceptual: [10, 10, 10, 10, 10]
+    equivariance_value: 10
+    warp_loss: 10
+    bg: 10
+
+train_avd_params:
+  num_epochs: 200
+  num_repeats: 300
+  batch_size: 256
+  dataloader_workers: 24
+  checkpoint_freq: 50
+  epoch_milestones: [140, 180]
+  lr: 1.0e-3
+  lambda_shift: 1
+  random_scale: 0.25
+
+visualizer_params:
+  kp_size: 5
+  draw_border: True
+  colormap: 'gist_rainbow'
demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
demo.py
ADDED
@@ -0,0 +1,176 @@
+import matplotlib
+matplotlib.use('Agg')
+import sys
+import yaml
+from argparse import ArgumentParser
+from tqdm import tqdm
+from scipy.spatial import ConvexHull
+import numpy as np
+import imageio
+from skimage.transform import resize
+from skimage import img_as_ubyte
+import torch
+from modules.inpainting_network import InpaintingNetwork
+from modules.keypoint_detector import KPDetector
+from modules.dense_motion import DenseMotionNetwork
+from modules.avd_network import AVDNetwork
+
+if sys.version_info[0] < 3:
+    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
+
+def relative_kp(kp_source, kp_driving, kp_driving_initial):
+
+    source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume
+    driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume
+    adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+
+    kp_new = {k: v for k, v in kp_driving.items()}
+
+    kp_value_diff = (kp_driving['fg_kp'] - kp_driving_initial['fg_kp'])
+    kp_value_diff *= adapt_movement_scale
+    kp_new['fg_kp'] = kp_value_diff + kp_source['fg_kp']
+
+    return kp_new
+
+def load_checkpoints(config_path, checkpoint_path, device):
+    with open(config_path) as f:
+        config = yaml.load(f)
+
+    inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
+                                   **config['model_params']['common_params'])
+    kp_detector = KPDetector(**config['model_params']['common_params'])
+    dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
+                                              **config['model_params']['dense_motion_params'])
+    avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
+                             **config['model_params']['avd_network_params'])
+    kp_detector.to(device)
+    dense_motion_network.to(device)
+    inpainting.to(device)
+    avd_network.to(device)
+
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    inpainting.load_state_dict(checkpoint['inpainting_network'])
+    kp_detector.load_state_dict(checkpoint['kp_detector'])
+    dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
+    if 'avd_network' in checkpoint:
+        avd_network.load_state_dict(checkpoint['avd_network'])
+
+    inpainting.eval()
+    kp_detector.eval()
+    dense_motion_network.eval()
+    avd_network.eval()
+
+    return inpainting, kp_detector, dense_motion_network, avd_network
+
+
+def make_animation(source_image, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, device, mode='relative'):
+    assert mode in ['standard', 'relative', 'avd']
+    with torch.no_grad():
+        predictions = []
+        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
+        source = source.to(device)
+        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
+        kp_source = kp_detector(source)
+        kp_driving_initial = kp_detector(driving[:, :, 0])
+
+        for frame_idx in tqdm(range(driving.shape[2])):
+            driving_frame = driving[:, :, frame_idx]
+            driving_frame = driving_frame.to(device)
+            kp_driving = kp_detector(driving_frame)
+            if mode == 'standard':
+                kp_norm = kp_driving
+            elif mode == 'relative':
+                kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
+                                      kp_driving_initial=kp_driving_initial)
+            elif mode == 'avd':
+                kp_norm = avd_network(kp_source, kp_driving)
+            dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm,
+                                                kp_source=kp_source, bg_param=None,
+                                                dropout_flag=False)
+            out = inpainting_network(source, dense_motion)
+
+            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+    return predictions
+
+
+def find_best_frame(source, driving, cpu):
+    import face_alignment
+
+    def normalize_kp(kp):
+        kp = kp - kp.mean(axis=0, keepdims=True)
+        area = ConvexHull(kp[:, :2]).volume
+        area = np.sqrt(area)
+        kp[:, :2] = kp[:, :2] / area
+        return kp
+
+    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
+                                      device='cpu' if cpu else 'cuda')
+    kp_source = fa.get_landmarks(255 * source)[0]
+    kp_source = normalize_kp(kp_source)
+    norm = float('inf')
+    frame_num = 0
+    for i, image in tqdm(enumerate(driving)):
+        kp_driving = fa.get_landmarks(255 * image)[0]
+        kp_driving = normalize_kp(kp_driving)
+        new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
+        if new_norm < norm:
+            norm = new_norm
+            frame_num = i
+    return frame_num
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--config", required=True, help="path to config")
+    parser.add_argument("--checkpoint", default='checkpoints/vox.pth.tar', help="path to checkpoint to restore")
+
+    parser.add_argument("--source_image", default='./assets/source.png', help="path to source image")
+    parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video")
+    parser.add_argument("--result_video", default='./result.mp4', help="path to output")
+
+    parser.add_argument("--img_shape", default="256,256", type=lambda x: list(map(int, x.split(','))),
+                        help='Shape of image that the model was trained on.')
+
+    parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'], help="Animate mode: ['standard', 'relative', 'avd']; when using the relative mode to animate a face, '--find_best_frame' can give a better quality result")
+
+    parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
+                        help="Generate from the frame that is the most aligned with the source. (Only for faces, requires the face_alignment lib)")
+
+    parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
+
+    opt = parser.parse_args()
+
+    source_image = imageio.imread(opt.source_image)
+    reader = imageio.get_reader(opt.driving_video)
+    fps = reader.get_meta_data()['fps']
+    driving_video = []
+    try:
+        for im in reader:
+            driving_video.append(im)
+    except RuntimeError:
+        pass
+    reader.close()
+
+    if opt.cpu:
+        device = torch.device('cpu')
+    else:
+        device = torch.device('cuda')
+
+    source_image = resize(source_image, opt.img_shape)[..., :3]
+    driving_video = [resize(frame, opt.img_shape)[..., :3] for frame in driving_video]
+    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, device=device)
+
+    if opt.find_best_frame:
+        i = find_best_frame(source_image, driving_video, opt.cpu)
+        print("Best frame: " + str(i))
+        driving_forward = driving_video[i:]
+        driving_backward = driving_video[:(i+1)][::-1]
+        predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=opt.mode)
+        predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=opt.mode)
+        predictions = predictions_backward[::-1] + predictions_forward[1:]
+    else:
+        predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=opt.mode)
+
+    imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
+
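relative_kp above transfers motion by scaling the driving keypoint displacement with the square root of the ratio of convex-hull areas (source hull vs. initial driving hull) before adding it to the source keypoints. A toy check on made-up fg_kp tensors, assuming demo.py and the modules/ package from this commit are importable (the keypoint count here is arbitrary):

# Toy check of relative_kp on made-up data: the displacement of the driving
# keypoints is rescaled by the hull-area ratio and added to the source keypoints.
# Shapes follow the 'fg_kp' convention used above: (batch, num_keypoints, 2).
import torch
from demo import relative_kp

kp_source = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}         # keypoints in [-1, 1]
kp_driving_initial = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}
kp_driving = {'fg_kp': kp_driving_initial['fg_kp'] + 0.05}  # small uniform shift

kp_new = relative_kp(kp_source, kp_driving, kp_driving_initial)
print(kp_new['fg_kp'].shape)  # torch.Size([1, 50, 2])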
frames_dataset.py
ADDED
@@ -0,0 +1,173 @@
+import os
+from skimage import io, img_as_float32
+from skimage.color import gray2rgb
+from sklearn.model_selection import train_test_split
+from imageio import mimread
+from skimage.transform import resize
+import numpy as np
+from torch.utils.data import Dataset
+from augmentation import AllAugmentationTransform
+import glob
+from functools import partial
+
+
+def read_video(name, frame_shape):
+    """
+    Read video which can be:
+      - an image of concatenated frames
+      - '.mp4' and '.gif'
+      - folder with videos
+    """
+
+    if os.path.isdir(name):
+        frames = sorted(os.listdir(name))
+        num_frames = len(frames)
+        video_array = np.array(
+            [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
+    elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
+        image = io.imread(name)
+
+        if len(image.shape) == 2 or image.shape[2] == 1:
+            image = gray2rgb(image)
+
+        if image.shape[2] == 4:
+            image = image[..., :3]
+
+        image = img_as_float32(image)
+
+        video_array = np.moveaxis(image, 1, 0)
+
+        video_array = video_array.reshape((-1,) + frame_shape)
+        video_array = np.moveaxis(video_array, 1, 2)
+    elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
+        video = mimread(name)
+        if len(video[0].shape) == 2:
+            video = [gray2rgb(frame) for frame in video]
+        if frame_shape is not None:
+            video = np.array([resize(frame, frame_shape) for frame in video])
+        video = np.array(video)
+        if video.shape[-1] == 4:
+            video = video[..., :3]
+        video_array = img_as_float32(video)
+    else:
+        raise Exception("Unknown file extension %s" % name)
+
+    return video_array
+
+
+class FramesDataset(Dataset):
+    """
+    Dataset of videos, each video can be represented as:
+      - an image of concatenated frames
+      - '.mp4' or '.gif'
+      - folder with all frames
+    """
+
+    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
+                 random_seed=0, pairs_list=None, augmentation_params=None):
+        self.root_dir = root_dir
+        self.videos = os.listdir(root_dir)
+        self.frame_shape = frame_shape
+        print(self.frame_shape)
+        self.pairs_list = pairs_list
+        self.id_sampling = id_sampling
+
+        if os.path.exists(os.path.join(root_dir, 'train')):
+            assert os.path.exists(os.path.join(root_dir, 'test'))
+            print("Use predefined train-test split.")
+            if id_sampling:
+                train_videos = {os.path.basename(video).split('#')[0] for video in
+                                os.listdir(os.path.join(root_dir, 'train'))}
+                train_videos = list(train_videos)
+            else:
+                train_videos = os.listdir(os.path.join(root_dir, 'train'))
+            test_videos = os.listdir(os.path.join(root_dir, 'test'))
+            self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
+        else:
+            print("Use random train-test split.")
+            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
+
+        if is_train:
+            self.videos = train_videos
+        else:
+            self.videos = test_videos
+
+        self.is_train = is_train
+
+        if self.is_train:
+            self.transform = AllAugmentationTransform(**augmentation_params)
+        else:
+            self.transform = None
+
+    def __len__(self):
+        return len(self.videos)
+
+    def __getitem__(self, idx):
+
+        if self.is_train and self.id_sampling:
+            name = self.videos[idx]
+            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
+        else:
+            name = self.videos[idx]
+            path = os.path.join(self.root_dir, name)
+
+        video_name = os.path.basename(path)
+        if self.is_train and os.path.isdir(path):
+
+            frames = os.listdir(path)
+            num_frames = len(frames)
+            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
+
+            if self.frame_shape is not None:
+                resize_fn = partial(resize, output_shape=self.frame_shape)
+            else:
+                resize_fn = img_as_float32
+
+            if type(frames[0]) is bytes:
+                video_array = [resize_fn(io.imread(os.path.join(path, frames[idx].decode('utf-8')))) for idx in
+                               frame_idx]
+            else:
+                video_array = [resize_fn(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
+        else:
+
+            video_array = read_video(path, frame_shape=self.frame_shape)
+
+            num_frames = len(video_array)
+            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
+                num_frames)
+            video_array = video_array[frame_idx]
+
+
+        if self.transform is not None:
+            video_array = self.transform(video_array)
+
+        out = {}
+        if self.is_train:
+            source = np.array(video_array[0], dtype='float32')
+            driving = np.array(video_array[1], dtype='float32')
+
+            out['driving'] = driving.transpose((2, 0, 1))
+            out['source'] = source.transpose((2, 0, 1))
+        else:
+            video = np.array(video_array, dtype='float32')
+            out['video'] = video.transpose((3, 0, 1, 2))
+
+        out['name'] = video_name
+        return out
+
+
+class DatasetRepeater(Dataset):
+    """
+    Pass several times over the same dataset for better i/o performance
+    """
+
+    def __init__(self, dataset, num_repeats=100):
+        self.dataset = dataset
+        self.num_repeats = num_repeats
+
+    def __len__(self):
+        return self.num_repeats * self.dataset.__len__()
+
+    def __getitem__(self, idx):
+        return self.dataset[idx % self.dataset.__len__()]
+
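A dataset built from FramesDataset is typically wrapped in DatasetRepeater and a standard PyTorch DataLoader for training. A minimal sketch driven by one of the configs in this commit, assuming a dataset directory with train/ and test/ subfolders exists at the configured root_dir (the path is the config default, not guaranteed to exist here):

# Minimal sketch: build the training dataset from a config and fetch one batch.
# Assumes ../taichi exists with train/ and test/ subfolders as the config describes.
import yaml
from torch.utils.data import DataLoader
from frames_dataset import FramesDataset, DatasetRepeater

with open('config/taichi-256.yaml') as f:
    config = yaml.safe_load(f)

dataset = FramesDataset(is_train=True, **config['dataset_params'])
dataset = DatasetRepeater(dataset, num_repeats=config['train_params']['num_repeats'])
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch = next(iter(loader))
print(batch['source'].shape, batch['driving'].shape)  # each (4, 3, H, W)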
generated.mp4
ADDED
Binary file (11.3 kB)
logger.py
ADDED
@@ -0,0 +1,212 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import imageio
+
+import os
+from skimage.draw import circle
+
+import matplotlib.pyplot as plt
+import collections
+
+
+class Logger:
+    def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
+
+        self.loss_list = []
+        self.cpk_dir = log_dir
+        self.visualizations_dir = os.path.join(log_dir, 'train-vis')
+        if not os.path.exists(self.visualizations_dir):
+            os.makedirs(self.visualizations_dir)
+        self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
+        self.zfill_num = zfill_num
+        self.visualizer = Visualizer(**visualizer_params)
+        self.checkpoint_freq = checkpoint_freq
+        self.epoch = 0
+        self.best_loss = float('inf')
+        self.names = None
+
+    def log_scores(self, loss_names):
+        loss_mean = np.array(self.loss_list).mean(axis=0)
+
+        loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
+        loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
+
+        print(loss_string, file=self.log_file)
+        self.loss_list = []
+        self.log_file.flush()
+
+    def visualize_rec(self, inp, out):
+        image = self.visualizer.visualize(inp['driving'], inp['source'], out)
+        imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
+
+    def save_cpk(self, emergent=False):
+        cpk = {k: v.state_dict() for k, v in self.models.items()}
+        cpk['epoch'] = self.epoch
+        cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
+        if not (os.path.exists(cpk_path) and emergent):
+            torch.save(cpk, cpk_path)
+
+    @staticmethod
+    def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network=None, kp_detector=None,
+                 bg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None,
+                 optimizer_avd=None):
+        checkpoint = torch.load(checkpoint_path)
+        if inpainting_network is not None:
+            inpainting_network.load_state_dict(checkpoint['inpainting_network'])
+        if kp_detector is not None:
+            kp_detector.load_state_dict(checkpoint['kp_detector'])
+        if bg_predictor is not None and 'bg_predictor' in checkpoint:
+            bg_predictor.load_state_dict(checkpoint['bg_predictor'])
+        if dense_motion_network is not None:
+            dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
+        if avd_network is not None:
+            if 'avd_network' in checkpoint:
+                avd_network.load_state_dict(checkpoint['avd_network'])
+        if optimizer_bg_predictor is not None and 'optimizer_bg_predictor' in checkpoint:
+            optimizer_bg_predictor.load_state_dict(checkpoint['optimizer_bg_predictor'])
+        if optimizer is not None and 'optimizer' in checkpoint:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+        if optimizer_avd is not None:
+            if 'optimizer_avd' in checkpoint:
+                optimizer_avd.load_state_dict(checkpoint['optimizer_avd'])
+        epoch = -1
+        if 'epoch' in checkpoint:
+            epoch = checkpoint['epoch']
+        return epoch
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self):
+        if 'models' in self.__dict__:
+            self.save_cpk()
+        self.log_file.close()
+
+    def log_iter(self, losses):
+        losses = collections.OrderedDict(losses.items())
+        self.names = list(losses.keys())
+        self.loss_list.append(list(losses.values()))
+
+    def log_epoch(self, epoch, models, inp, out):
+        self.epoch = epoch
+        self.models = models
+        if (self.epoch + 1) % self.checkpoint_freq == 0:
+            self.save_cpk()
+        self.log_scores(self.names)
+        self.visualize_rec(inp, out)
+
+
+class Visualizer:
+    def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
+        self.kp_size = kp_size
+        self.draw_border = draw_border
+        self.colormap = plt.get_cmap(colormap)
+
+    def draw_image_with_kp(self, image, kp_array):
+        image = np.copy(image)
+        spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
+        kp_array = spatial_size * (kp_array + 1) / 2
+        num_kp = kp_array.shape[0]
+        for kp_ind, kp in enumerate(kp_array):
+            rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
+            image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
+        return image
+
+    def create_image_column_with_kp(self, images, kp):
+        image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
+        return self.create_image_column(image_array)
+
+    def create_image_column(self, images):
+        if self.draw_border:
+            images = np.copy(images)
+            images[:, :, [0, -1]] = (1, 1, 1)
+            images[:, :, [0, -1]] = (1, 1, 1)
+        return np.concatenate(list(images), axis=0)
+
+    def create_image_grid(self, *args):
+        out = []
+        for arg in args:
+            if type(arg) == tuple:
+                out.append(self.create_image_column_with_kp(arg[0], arg[1]))
+            else:
+                out.append(self.create_image_column(arg))
+        return np.concatenate(out, axis=1)
+
+    def visualize(self, driving, source, out):
+        images = []
+
+        # Source image with keypoints
+        source = source.data.cpu()
+        kp_source = out['kp_source']['fg_kp'].data.cpu().numpy()
+        source = np.transpose(source, [0, 2, 3, 1])
+        images.append((source, kp_source))
+
+        # Equivariance visualization
+        if 'transformed_frame' in out:
+            transformed = out['transformed_frame'].data.cpu().numpy()
+            transformed = np.transpose(transformed, [0, 2, 3, 1])
+            transformed_kp = out['transformed_kp']['fg_kp'].data.cpu().numpy()
+            images.append((transformed, transformed_kp))
+
+        # Driving image with keypoints
+        kp_driving = out['kp_driving']['fg_kp'].data.cpu().numpy()
+        driving = driving.data.cpu().numpy()
+        driving = np.transpose(driving, [0, 2, 3, 1])
+        images.append((driving, kp_driving))
+
+        # Deformed image
+        if 'deformed' in out:
+            deformed = out['deformed'].data.cpu().numpy()
+            deformed = np.transpose(deformed, [0, 2, 3, 1])
+            images.append(deformed)
+
+        # Result with and without keypoints
+        prediction = out['prediction'].data.cpu().numpy()
+        prediction = np.transpose(prediction, [0, 2, 3, 1])
+        if 'kp_norm' in out:
+            kp_norm = out['kp_norm']['fg_kp'].data.cpu().numpy()
+            images.append((prediction, kp_norm))
+        images.append(prediction)
+
+
+        ## Occlusion map
+        if 'occlusion_map' in out:
+            for i in range(len(out['occlusion_map'])):
+                occlusion_map = out['occlusion_map'][i].data.cpu().repeat(1, 3, 1, 1)
+                occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
+                occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
+                images.append(occlusion_map)
+
+        # Deformed images according to each individual transform
+        if 'deformed_source' in out:
+            full_mask = []
+            for i in range(out['deformed_source'].shape[1]):
+                image = out['deformed_source'][:, i].data.cpu()
+                # import ipdb;ipdb.set_trace()
+                image = F.interpolate(image, size=source.shape[1:3])
+                mask = out['contribution_maps'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
+                mask = F.interpolate(mask, size=source.shape[1:3])
+                image = np.transpose(image.numpy(), (0, 2, 3, 1))
+                mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
+
+                if i != 0:
+                    color = np.array(self.colormap((i - 1) / (out['deformed_source'].shape[1] - 1)))[:3]
+                else:
+                    color = np.array((0, 0, 0))
+
+                color = color.reshape((1, 1, 1, 3))
+
+                images.append(image)
+                if i != 0:
+                    images.append(mask * color)
+                else:
204 |
+
images.append(mask)
|
205 |
+
|
206 |
+
full_mask.append(mask * color)
|
207 |
+
|
208 |
+
images.append(sum(full_mask))
|
209 |
+
|
210 |
+
image = self.create_image_grid(*images)
|
211 |
+
image = (255 * image).astype(np.uint8)
|
212 |
+
return image
|
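As a quick sanity check for Logger.load_cpk, the sketch below restores the three inference-time networks from a checkpoint and prints the stored epoch. The checkpoint path and the hyper-parameters are placeholders (they roughly follow a vox-256-style config), not values confirmed by this commit.

import torch
from logger import Logger
from modules.inpainting_network import InpaintingNetwork
from modules.keypoint_detector import KPDetector
from modules.dense_motion import DenseMotionNetwork

# Hypothetical hyper-parameters in the spirit of config/vox-256.yaml.
kp_detector = KPDetector(num_tps=10)
dense_motion_network = DenseMotionNetwork(block_expansion=64, num_blocks=5, max_features=1024,
                                          num_tps=10, num_channels=3)
inpainting_network = InpaintingNetwork(num_channels=3, block_expansion=64,
                                       max_features=512, num_down_blocks=3)

# 'checkpoints/vox.pth.tar' is a placeholder path; the checkpoint must be loadable on the current device.
epoch = Logger.load_cpk('checkpoints/vox.pth.tar',
                        inpainting_network=inpainting_network,
                        kp_detector=kp_detector,
                        dense_motion_network=dense_motion_network)
print('checkpoint was saved at epoch', epoch)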
modules/avd_network.py
ADDED
@@ -0,0 +1,65 @@
import torch
from torch import nn


class AVDNetwork(nn.Module):
    """
    Animation via Disentanglement network
    """

    def __init__(self, num_tps, id_bottle_size=64, pose_bottle_size=64):
        super(AVDNetwork, self).__init__()
        input_size = 5*2 * num_tps
        self.num_tps = num_tps

        self.id_encoder = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, id_bottle_size)
        )

        self.pose_encoder = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, pose_bottle_size)
        )

        self.decoder = nn.Sequential(
            nn.Linear(pose_bottle_size + id_bottle_size, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, input_size)
        )

    def forward(self, kp_source, kp_random):

        bs = kp_source['fg_kp'].shape[0]

        pose_emb = self.pose_encoder(kp_random['fg_kp'].view(bs, -1))
        id_emb = self.id_encoder(kp_source['fg_kp'].view(bs, -1))

        rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))

        rec = {'fg_kp': rec.view(bs, self.num_tps*5, -1)}
        return rec
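A minimal smoke test for AVDNetwork, assuming 10 TPS transformations (and therefore 50 foreground keypoints). The random keypoint dictionaries stand in for real kp_detector outputs; eval() is called because the BatchNorm1d layers cannot normalise a single training sample.

import torch
from modules.avd_network import AVDNetwork

avd_network = AVDNetwork(num_tps=10, id_bottle_size=64, pose_bottle_size=64)
avd_network.eval()  # BatchNorm1d layers need eval mode for a batch of one

kp_source = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}   # identity keypoints (placeholder)
kp_random = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}   # pose keypoints (placeholder)
rec = avd_network(kp_source, kp_random)
print(rec['fg_kp'].shape)  # torch.Size([1, 50, 2])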
modules/bg_motion_predictor.py
ADDED
@@ -0,0 +1,24 @@
from torch import nn
import torch
from torchvision import models

class BGMotionPredictor(nn.Module):
    """
    Module for background estimation. Returns a single transformation, parametrized as a 3x3 matrix whose third row is [0 0 1].
    """

    def __init__(self):
        super(BGMotionPredictor, self).__init__()
        self.bg_encoder = models.resnet18(pretrained=False)
        self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_features = self.bg_encoder.fc.in_features
        self.bg_encoder.fc = nn.Linear(num_features, 6)
        self.bg_encoder.fc.weight.data.zero_()
        self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, source_image, driving_image):
        bs = source_image.shape[0]
        out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type())
        prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1))
        out[:, :2, :] = prediction.view(bs, 2, 3)
        return out
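A short shape check: because the final fully connected layer is initialised with zeroed weights and an identity-affine bias, a freshly constructed predictor returns exactly the identity 3x3 matrix for any input pair.

import torch
from modules.bg_motion_predictor import BGMotionPredictor

bg_predictor = BGMotionPredictor()
source = torch.rand(2, 3, 256, 256)
driving = torch.rand(2, 3, 256, 256)
bg_param = bg_predictor(source, driving)
print(bg_param.shape)   # torch.Size([2, 3, 3])
print(bg_param[0])      # identity at initialisation: zeroed fc weights + [1, 0, 0, 0, 1, 0] bias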
modules/dense_motion.py
ADDED
@@ -0,0 +1,164 @@
from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS
import math

class DenseMotionNetwork(nn.Module):
    """
    Module that estimates an optical flow and multi-resolution occlusion masks
    from K TPS transformations and an affine transformation.
    """

    def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels,
                 scale_factor=0.25, bg = False, multi_mask = True, kp_variance=0.01):
        super(DenseMotionNetwork, self).__init__()

        if scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, scale_factor)
        self.scale_factor = scale_factor
        self.multi_mask = multi_mask

        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps+1) + num_tps*5+1),
                                   max_features=max_features, num_blocks=num_blocks)

        hourglass_output_size = self.hourglass.out_channels
        self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3))

        if multi_mask:
            up = []
            self.up_nums = int(math.log(1/scale_factor, 2))
            self.occlusion_num = 4

            channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)]
            for i in range(self.up_nums):
                up.append(UpBlock2d(channel[i], channel[i]//2, kernel_size=3, padding=1))
            self.up = nn.ModuleList(up)

            channel = [hourglass_output_size[-i-1] for i in range(self.occlusion_num-self.up_nums)[::-1]]
            for i in range(self.up_nums):
                channel.append(hourglass_output_size[-1]//(2**(i+1)))
            occlusion = []

            for i in range(self.occlusion_num):
                occlusion.append(nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3)))
            self.occlusion = nn.ModuleList(occlusion)
        else:
            occlusion = [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))]
            self.occlusion = nn.ModuleList(occlusion)

        self.num_tps = num_tps
        self.bg = bg
        self.kp_variance = kp_variance


    def create_heatmap_representations(self, source_image, kp_driving, kp_source):

        spatial_size = source_image.shape[2:]
        gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
        gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
        heatmap = gaussian_driving - gaussian_source

        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device)
        heatmap = torch.cat([zeros, heatmap], dim=1)

        return heatmap

    def create_transformations(self, source_image, kp_driving, kp_source, bg_param):
        # K TPS transformations
        bs, _, h, w = source_image.shape
        kp_1 = kp_driving['fg_kp']
        kp_2 = kp_source['fg_kp']
        kp_1 = kp_1.view(bs, -1, 5, 2)
        kp_2 = kp_2.view(bs, -1, 5, 2)
        trans = TPS(mode = 'kp', bs = bs, kp_1 = kp_1, kp_2 = kp_2)
        driving_to_source = trans.transform_frame(source_image)

        identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device)
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)

        # affine background transformation
        if not (bg_param is None):
            identity_grid = to_homogeneous(identity_grid)
            identity_grid = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid.unsqueeze(-1)).squeeze(-1)
            identity_grid = from_homogeneous(identity_grid)

        transformations = torch.cat([identity_grid, driving_to_source], dim=1)
        return transformations

    def create_deformed_source_image(self, source_image, transformations):

        bs, _, h, w = source_image.shape
        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_tps + 1, 1, 1, 1, 1)
        source_repeat = source_repeat.view(bs * (self.num_tps + 1), -1, h, w)
        transformations = transformations.view((bs * (self.num_tps + 1), h, w, -1))
        deformed = F.grid_sample(source_repeat, transformations, align_corners=True)
        deformed = deformed.view((bs, self.num_tps+1, -1, h, w))
        return deformed

    def dropout_softmax(self, X, P):
        '''
        Dropout for TPS transformations. Eq(7) and Eq(8) in the paper.
        '''
        drop = (torch.rand(X.shape[0],X.shape[1]) < (1-P)).type(X.type()).to(X.device)
        drop[..., 0] = 1
        drop = drop.repeat(X.shape[2],X.shape[3],1,1).permute(2,3,0,1)

        maxx = X.max(1).values.unsqueeze_(1)
        X = X - maxx
        X_exp = X.exp()
        X[:,1:,...] /= (1-P)
        mask_bool =(drop == 0)
        X_exp = X_exp.masked_fill(mask_bool, 0)
        partition = X_exp.sum(dim=1, keepdim=True) + 1e-6
        return X_exp / partition

    def forward(self, source_image, kp_driving, kp_source, bg_param = None, dropout_flag=False, dropout_p = 0):
        if self.scale_factor != 1:
            source_image = self.down(source_image)

        bs, _, h, w = source_image.shape

        out_dict = dict()
        heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
        transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param)
        deformed_source = self.create_deformed_source_image(source_image, transformations)
        out_dict['deformed_source'] = deformed_source
        # out_dict['transformations'] = transformations
        deformed_source = deformed_source.view(bs,-1,h,w)
        input = torch.cat([heatmap_representation, deformed_source], dim=1)
        input = input.view(bs, -1, h, w)

        prediction = self.hourglass(input, mode = 1)

        contribution_maps = self.maps(prediction[-1])
        if(dropout_flag):
            contribution_maps = self.dropout_softmax(contribution_maps, dropout_p)
        else:
            contribution_maps = F.softmax(contribution_maps, dim=1)
        out_dict['contribution_maps'] = contribution_maps

        # Combine the K+1 transformations
        # Eq(6) in the paper
        contribution_maps = contribution_maps.unsqueeze(2)
        transformations = transformations.permute(0, 1, 4, 2, 3)
        deformation = (transformations * contribution_maps).sum(dim=1)
        deformation = deformation.permute(0, 2, 3, 1)

        out_dict['deformation'] = deformation # Optical Flow

        occlusion_map = []
        if self.multi_mask:
            for i in range(self.occlusion_num-self.up_nums):
                occlusion_map.append(torch.sigmoid(self.occlusion[i](prediction[self.up_nums-self.occlusion_num+i])))
            prediction = prediction[-1]
            for i in range(self.up_nums):
                prediction = self.up[i](prediction)
                occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction)))
        else:
            occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1])))

        out_dict['occlusion_map'] = occlusion_map # Multi-resolution Occlusion Masks
        return out_dict
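The sketch below wires a randomly initialised keypoint detector into the dense motion estimator and inspects the two outputs that the rest of the pipeline consumes: the combined optical flow (`deformation`, estimated on the 0.25-scaled grid) and the multi-resolution occlusion masks. The constructor values are placeholders in the spirit of the vox-256 config.

import torch
from modules.keypoint_detector import KPDetector
from modules.dense_motion import DenseMotionNetwork

kp_detector = KPDetector(num_tps=10)
dense_motion_network = DenseMotionNetwork(block_expansion=64, num_blocks=5, max_features=1024,
                                          num_tps=10, num_channels=3,
                                          scale_factor=0.25, bg=False, multi_mask=True)
source = torch.rand(1, 3, 256, 256)
driving = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    dense_motion = dense_motion_network(source_image=source,
                                        kp_driving=kp_detector(driving),
                                        kp_source=kp_detector(source))
print(dense_motion['deformation'].shape)                     # (1, 64, 64, 2) optical flow on the downscaled grid
print([m.shape[2:] for m in dense_motion['occlusion_map']])  # masks at increasing resolutions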
modules/inpainting_network.py
ADDED
@@ -0,0 +1,130 @@
import torch
from torch import nn
import torch.nn.functional as F
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
from modules.dense_motion import DenseMotionNetwork


class InpaintingNetwork(nn.Module):
    """
    Inpaint the missing regions and reconstruct the Driving image.
    """
    def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, **kwargs):
        super(InpaintingNetwork, self).__init__()

        self.num_down_blocks = num_down_blocks
        self.multi_mask = multi_mask
        self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))

        down_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)

        up_blocks = []
        in_features = [max_features, max_features, max_features//2]
        out_features = [max_features//2, max_features//4, max_features//8]
        for i in range(num_down_blocks):
            up_blocks.append(UpBlock2d(in_features[i], out_features[i], kernel_size=(3, 3), padding=(1, 1)))
        self.up_blocks = nn.ModuleList(up_blocks)

        resblock = []
        for i in range(num_down_blocks):
            resblock.append(ResBlock2d(in_features[i], kernel_size=(3, 3), padding=(1, 1)))
            resblock.append(ResBlock2d(in_features[i], kernel_size=(3, 3), padding=(1, 1)))
        self.resblock = nn.ModuleList(resblock)

        self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
        self.num_channels = num_channels

    def deform_input(self, inp, deformation):
        _, h_old, w_old, _ = deformation.shape
        _, _, h, w = inp.shape
        if h_old != h or w_old != w:
            deformation = deformation.permute(0, 3, 1, 2)
            deformation = F.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True)
            deformation = deformation.permute(0, 2, 3, 1)
        return F.grid_sample(inp, deformation, align_corners=True)

    def occlude_input(self, inp, occlusion_map):
        if not self.multi_mask:
            if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]:
                occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear', align_corners=True)
        out = inp * occlusion_map
        return out

    def forward(self, source_image, dense_motion):
        out = self.first(source_image)
        encoder_map = [out]
        for i in range(len(self.down_blocks)):
            out = self.down_blocks[i](out)
            encoder_map.append(out)

        output_dict = {}
        output_dict['contribution_maps'] = dense_motion['contribution_maps']
        output_dict['deformed_source'] = dense_motion['deformed_source']

        occlusion_map = dense_motion['occlusion_map']
        output_dict['occlusion_map'] = occlusion_map

        deformation = dense_motion['deformation']
        out_ij = self.deform_input(out.detach(), deformation)
        out = self.deform_input(out, deformation)

        out_ij = self.occlude_input(out_ij, occlusion_map[0].detach())
        out = self.occlude_input(out, occlusion_map[0])

        warped_encoder_maps = []
        warped_encoder_maps.append(out_ij)

        for i in range(self.num_down_blocks):

            out = self.resblock[2*i](out)
            out = self.resblock[2*i+1](out)
            out = self.up_blocks[i](out)

            encode_i = encoder_map[-(i+2)]
            encode_ij = self.deform_input(encode_i.detach(), deformation)
            encode_i = self.deform_input(encode_i, deformation)

            occlusion_ind = 0
            if self.multi_mask:
                occlusion_ind = i+1
            encode_ij = self.occlude_input(encode_ij, occlusion_map[occlusion_ind].detach())
            encode_i = self.occlude_input(encode_i, occlusion_map[occlusion_ind])
            warped_encoder_maps.append(encode_ij)

            if(i==self.num_down_blocks-1):
                break

            out = torch.cat([out, encode_i], 1)

        deformed_source = self.deform_input(source_image, deformation)
        output_dict["deformed"] = deformed_source
        output_dict["warped_encoder_maps"] = warped_encoder_maps

        occlusion_last = occlusion_map[-1]
        if not self.multi_mask:
            occlusion_last = F.interpolate(occlusion_last, size=out.shape[2:], mode='bilinear', align_corners=True)

        out = out * (1 - occlusion_last) + encode_i
        out = self.final(out)
        out = torch.sigmoid(out)
        out = out * (1 - occlusion_last) + deformed_source * occlusion_last
        output_dict["prediction"] = out

        return output_dict

    def get_encode(self, driver_image, occlusion_map):
        out = self.first(driver_image)
        encoder_map = []
        encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach()))
        for i in range(len(self.down_blocks)):
            out = self.down_blocks[i](out.detach())
            out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach())
            encoder_map.append(out_mask.detach())

        return encoder_map
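Putting the three modules together gives the inference step used later in reconstruction.py and the demo: keypoints, then dense motion, then inpainting. Weights are random here, so the prediction is meaningless; the point is only the wiring and the output shape. Constructor values are again vox-256-style placeholders.

import torch
from modules.keypoint_detector import KPDetector
from modules.dense_motion import DenseMotionNetwork
from modules.inpainting_network import InpaintingNetwork

kp_detector = KPDetector(num_tps=10)
dense_motion_network = DenseMotionNetwork(block_expansion=64, num_blocks=5, max_features=1024,
                                          num_tps=10, num_channels=3)
inpainting_network = InpaintingNetwork(num_channels=3, block_expansion=64,
                                       max_features=512, num_down_blocks=3)
source = torch.rand(1, 3, 256, 256)
driving = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    dense_motion = dense_motion_network(source_image=source,
                                        kp_driving=kp_detector(driving),
                                        kp_source=kp_detector(source))
    out = inpainting_network(source, dense_motion)
print(out['prediction'].shape)  # torch.Size([1, 3, 256, 256])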
modules/keypoint_detector.py
ADDED
@@ -0,0 +1,27 @@
from torch import nn
import torch
from torchvision import models

class KPDetector(nn.Module):
    """
    Predict K*5 keypoints.
    """

    def __init__(self, num_tps, **kwargs):
        super(KPDetector, self).__init__()
        self.num_tps = num_tps

        self.fg_encoder = models.resnet18(pretrained=False)
        num_features = self.fg_encoder.fc.in_features
        self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2)


    def forward(self, image):

        fg_kp = self.fg_encoder(image)
        bs, _, = fg_kp.shape
        fg_kp = torch.sigmoid(fg_kp)
        fg_kp = fg_kp * 2 - 1
        out = {'fg_kp': fg_kp.view(bs, self.num_tps*5, -1)}

        return out
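Each TPS transformation is driven by 5 keypoints, so a detector built with num_tps=10 returns 50 (x, y) pairs in (-1, 1). A quick check on random input:

import torch
from modules.keypoint_detector import KPDetector

kp_detector = KPDetector(num_tps=10)
image = torch.rand(2, 3, 256, 256)           # batch of two RGB frames in [0, 1]
kp = kp_detector(image)
print(kp['fg_kp'].shape)                     # torch.Size([2, 50, 2])
print(kp['fg_kp'].min(), kp['fg_kp'].max())  # inside (-1, 1) thanks to the sigmoid rescaling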
modules/model.py
ADDED
@@ -0,0 +1,182 @@
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, TPS
from torchvision import models
import numpy as np


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class ImagePyramide(torch.nn.Module):
    """
    Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3.
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict


def detach_kp(kp):
    return {key: value.detach() for key, value in kp.items()}


class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage.
    """

    def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, *kwargs):
        super(GeneratorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.inpainting_network = inpainting_network
        self.dense_motion_network = dense_motion_network

        self.bg_predictor = None
        if bg_predictor:
            self.bg_predictor = bg_predictor
            self.bg_start = train_params['bg_start']

        self.train_params = train_params
        self.scales = train_params['scales']

        self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']
        self.dropout_epoch = train_params['dropout_epoch']
        self.dropout_maxp = train_params['dropout_maxp']
        self.dropout_inc_epoch = train_params['dropout_inc_epoch']
        self.dropout_startp = train_params['dropout_startp']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()


    def forward(self, x, epoch):
        kp_source = self.kp_extractor(x['source'])
        kp_driving = self.kp_extractor(x['driving'])
        bg_param = None
        if self.bg_predictor:
            if(epoch>=self.bg_start):
                bg_param = self.bg_predictor(x['source'], x['driving'])

        if(epoch>=self.dropout_epoch):
            dropout_flag = False
            dropout_p = 0
        else:
            # dropout_p will linearly increase from dropout_startp to dropout_maxp
            dropout_flag = True
            dropout_p = min(epoch/self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp, self.dropout_maxp)

        dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving,
                                                 kp_source=kp_source, bg_param = bg_param,
                                                 dropout_flag = dropout_flag, dropout_p = dropout_p)
        generated = self.inpainting_network(x['source'], dense_motion)
        generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})

        loss_values = {}

        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'])

        # reconstruction loss
        if sum(self.loss_weights['perceptual']) != 0:
            value_total = 0
            for scale in self.scales:
                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                for i, weight in enumerate(self.loss_weights['perceptual']):
                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                    value_total += self.loss_weights['perceptual'][i] * value
            loss_values['perceptual'] = value_total

        # equivariance loss
        if self.loss_weights['equivariance_value'] != 0:
            transform_random = TPS(mode = 'random', bs = x['driving'].shape[0], **self.train_params['transform_params'])
            transform_grid = transform_random.transform_frame(x['driving'])
            transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection", align_corners=True)
            transformed_kp = self.kp_extractor(transformed_frame)

            generated['transformed_frame'] = transformed_frame
            generated['transformed_kp'] = transformed_kp

            warped = transform_random.warp_coordinates(transformed_kp['fg_kp'])
            kp_d = kp_driving['fg_kp']
            value = torch.abs(kp_d - warped).mean()
            loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value

        # warp loss
        if self.loss_weights['warp_loss'] != 0:
            occlusion_map = generated['occlusion_map']
            encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map)
            decode_map = generated['warped_encoder_maps']
            value = 0
            for i in range(len(encode_map)):
                value += torch.abs(encode_map[i]-decode_map[-i-1]).mean()

            loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value

        # bg loss
        if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0:
            bg_param_reverse = self.bg_predictor(x['driving'], x['source'])
            value = torch.matmul(bg_param, bg_param_reverse)
            eye = torch.eye(3).view(1, 1, 3, 3).type(value.type())
            value = torch.abs(eye - value).mean()
            loss_values['bg'] = self.loss_weights['bg'] * value

        return loss_values, generated
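GeneratorFullModel only needs a train_params dictionary; the sketch below shows the keys it actually reads. The numbers are placeholders that approximate the published configs rather than values taken from this commit, and the perceptual weights are zeroed so the example does not download VGG-19 weights.

import torch
from modules.keypoint_detector import KPDetector
from modules.dense_motion import DenseMotionNetwork
from modules.inpainting_network import InpaintingNetwork
from modules.model import GeneratorFullModel

train_params = {
    'scales': [1],                                   # image pyramid scales
    'loss_weights': {'perceptual': [0, 0, 0, 0, 0],  # zeroed here: skips the VGG-19 download
                     'equivariance_value': 10,
                     'warp_loss': 10,
                     'bg': 10},
    'dropout_epoch': 35, 'dropout_maxp': 0.7,
    'dropout_startp': 0.0, 'dropout_inc_epoch': 10,
    'bg_start': 0,
    'transform_params': {'sigma_affine': 0.05, 'sigma_tps': 0.005, 'points_tps': 5},
}
kp_detector = KPDetector(num_tps=10)
dense_motion_network = DenseMotionNetwork(block_expansion=64, num_blocks=5, max_features=1024,
                                          num_tps=10, num_channels=3)
inpainting_network = InpaintingNetwork(num_channels=3, block_expansion=64,
                                       max_features=512, num_down_blocks=3)
generator_full = GeneratorFullModel(kp_detector, None, dense_motion_network,
                                    inpainting_network, train_params)
batch = {'source': torch.rand(2, 3, 256, 256), 'driving': torch.rand(2, 3, 256, 256)}
losses, generated = generator_full(batch, epoch=0)
print({name: float(value) for name, value in losses.items()})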
modules/util.py
ADDED
@@ -0,0 +1,349 @@
from torch import nn
import torch.nn.functional as F
import torch


class TPS:
    '''
    TPS transformation, mode 'kp' for Eq(2) in the paper, mode 'random' for equivariance loss.
    '''
    def __init__(self, mode, bs, **kwargs):
        self.bs = bs
        self.mode = mode
        if mode == 'random':
            noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
            self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
            self.control_points = self.control_points.unsqueeze(0)
            self.control_params = torch.normal(mean=0,
                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
        elif mode == 'kp':
            kp_1 = kwargs["kp_1"]
            kp_2 = kwargs["kp_2"]
            device = kp_1.device
            kp_type = kp_1.type()
            self.gs = kp_1.shape[1]
            n = kp_1.shape[2]
            K = torch.norm(kp_1[:,:,:, None]-kp_1[:,:, None, :], dim=4, p=2)
            K = K**2
            K = K * torch.log(K+1e-9)

            one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2], 1).to(device).type(kp_type)
            kp_1p = torch.cat([kp_1,one1], 3)

            zero = torch.zeros(self.bs, kp_1.shape[1], 3, 3).to(device).type(kp_type)
            P = torch.cat([kp_1p, zero],2)
            L = torch.cat([K,kp_1p.permute(0,1,3,2)],2)
            L = torch.cat([L,P],3)

            zero = torch.zeros(self.bs, kp_1.shape[1], 3, 2).to(device).type(kp_type)
            Y = torch.cat([kp_2, zero], 2)
            one = torch.eye(L.shape[2]).expand(L.shape).to(device).type(kp_type)*0.01
            L = L + one

            param = torch.matmul(torch.inverse(L),Y)
            self.theta = param[:,:,n:,:].permute(0,1,3,2)

            self.control_points = kp_1
            self.control_params = param[:,:,:n,:]
        else:
            raise Exception("Error TPS mode")

    def transform_frame(self, frame):
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        shape = [self.bs, frame.shape[2], frame.shape[3], 2]
        if self.mode == 'kp':
            shape.insert(1, self.gs)
        grid = self.warp_coordinates(grid).view(*shape)
        return grid

    def warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type()).to(coordinates.device)
        control_points = self.control_points.type(coordinates.type()).to(coordinates.device)
        control_params = self.control_params.type(coordinates.type()).to(coordinates.device)

        if self.mode == 'kp':
            transformed = torch.matmul(theta[:, :, :, :2], coordinates.permute(0, 2, 1)) + theta[:, :, :, 2:]

            distances = coordinates.view(coordinates.shape[0], 1, 1, -1, 2) - control_points.view(self.bs, control_points.shape[1], -1, 1, 2)

            distances = distances ** 2
            result = distances.sum(-1)
            result = result * torch.log(result + 1e-9)
            result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
            transformed = transformed.permute(0, 1, 3, 2) + result

        elif self.mode == 'random':
            theta = theta.unsqueeze(1)
            transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
            transformed = transformed.squeeze(-1)
            ances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = ances ** 2

            result = distances.sum(-1)
            result = result * torch.log(result + 1e-9)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result
        else:
            raise Exception("Error TPS mode")

        return transformed


def kp2gaussian(kp, spatial_size, kp_variance):
    """
    Transform a keypoint into gaussian like representation
    """

    coordinate_grid = make_coordinate_grid(spatial_size, kp.type()).to(kp.device)
    number_of_leading_dimensions = len(kp.shape) - 1
    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
    coordinate_grid = coordinate_grid.view(*shape)
    repeats = kp.shape[:number_of_leading_dimensions] + (1, 1, 1)
    coordinate_grid = coordinate_grid.repeat(*repeats)

    # Preprocess kp shape
    shape = kp.shape[:number_of_leading_dimensions] + (1, 1, 2)
    kp = kp.view(*shape)

    mean_sub = (coordinate_grid - kp)

    out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)

    return out


def make_coordinate_grid(spatial_size, type):
    """
    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
    """
    h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)

    yy = y.view(-1, 1).repeat(1, w)
    xx = x.view(1, -1).repeat(h, 1)

    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)

    return meshed


class ResBlock2d(nn.Module):
    """
    Res block, preserve spatial resolution.
    """

    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.norm1 = nn.InstanceNorm2d(in_features, affine=True)
        self.norm2 = nn.InstanceNorm2d(in_features, affine=True)

    def forward(self, x):
        out = self.norm1(x)
        out = F.relu(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.relu(out)
        out = self.conv2(out)
        out += x
        return out


class UpBlock2d(nn.Module):
    """
    Upsampling block for use in decoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(UpBlock2d, self).__init__()

        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = nn.InstanceNorm2d(out_features, affine=True)

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2)
        out = self.conv(out)
        out = self.norm(out)
        out = F.relu(out)
        return out


class DownBlock2d(nn.Module):
    """
    Downsampling block for use in encoder.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = nn.InstanceNorm2d(out_features, affine=True)
        self.pool = nn.AvgPool2d(kernel_size=(2, 2))

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        out = self.pool(out)
        return out


class SameBlock2d(nn.Module):
    """
    Simple block, preserve spatial resolution.
    """

    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
        super(SameBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
                              kernel_size=kernel_size, padding=padding, groups=groups)
        self.norm = nn.InstanceNorm2d(out_features, affine=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        return out


class Encoder(nn.Module):
    """
    Hourglass Encoder
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Encoder, self).__init__()

        down_blocks = []
        for i in range(num_blocks):
            down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
                                           min(max_features, block_expansion * (2 ** (i + 1))),
                                           kernel_size=3, padding=1))
        self.down_blocks = nn.ModuleList(down_blocks)

    def forward(self, x):
        outs = [x]
        #print('encoder:' ,outs[-1].shape)
        for down_block in self.down_blocks:
            outs.append(down_block(outs[-1]))
            #print('encoder:' ,outs[-1].shape)
        return outs


class Decoder(nn.Module):
    """
    Hourglass Decoder
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Decoder, self).__init__()

        up_blocks = []
        self.out_channels = []
        for i in range(num_blocks)[::-1]:
            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
            self.out_channels.append(in_filters)
            out_filters = min(max_features, block_expansion * (2 ** i))
            up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))

        self.up_blocks = nn.ModuleList(up_blocks)
        self.out_channels.append(block_expansion + in_features)
        # self.out_filters = block_expansion + in_features

    def forward(self, x, mode = 0):
        out = x.pop()
        outs = []
        for up_block in self.up_blocks:
            out = up_block(out)
            skip = x.pop()
            out = torch.cat([out, skip], dim=1)
            outs.append(out)
        if(mode == 0):
            return out
        else:
            return outs


class Hourglass(nn.Module):
    """
    Hourglass architecture.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Hourglass, self).__init__()
        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
        self.out_channels = self.decoder.out_channels
        # self.out_filters = self.decoder.out_filters

    def forward(self, x, mode = 0):
        return self.decoder(self.encoder(x), mode)


class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.
    """
    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        sigma = (1 / scale - 1) / 2
        kernel_size = 2 * round(sigma * 4) + 1
        self.ka = kernel_size // 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka

        kernel_size = [kernel_size, kernel_size]
        sigma = [sigma, sigma]
        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale

    def forward(self, input):
        if self.scale == 1.0:
            return input

        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        out = F.interpolate(out, scale_factor=(self.scale, self.scale))

        return out


def to_homogeneous(coordinates):
    ones_shape = list(coordinates.shape)
    ones_shape[-1] = 1
    ones = torch.ones(ones_shape).type(coordinates.type())

    return torch.cat([coordinates, ones], dim=-1)

def from_homogeneous(coordinates):
    return coordinates[..., :2] / coordinates[..., 2:3]
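A small demonstration of the 'random' TPS mode that the equivariance loss relies on: warp a frame with a random thin-plate deformation, then push a set of coordinates through warp_coordinates. The sigma_affine, sigma_tps and points_tps values follow the usual transform_params and are placeholders here.

import torch
import torch.nn.functional as F
from modules.util import TPS

frame = torch.rand(2, 3, 64, 64)
tps = TPS(mode='random', bs=2, sigma_affine=0.05, sigma_tps=0.005, points_tps=5)

grid = tps.transform_frame(frame)                     # (2, 64, 64, 2) sampling grid
warped = F.grid_sample(frame, grid, padding_mode='reflection', align_corners=True)
print(warped.shape)                                   # torch.Size([2, 3, 64, 64])

coords = torch.rand(2, 50, 2) * 2 - 1                 # e.g. detected keypoints (placeholder)
print(tps.warp_coordinates(coords).shape)             # torch.Size([2, 50, 2])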
reconstruction.py
ADDED
@@ -0,0 +1,69 @@
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from logger import Logger, Visualizer
import numpy as np
import imageio


def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
                        bg_predictor=bg_predictor, dense_motion_network=dense_motion_network)
    else:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []

    inpainting_network.eval()
    kp_detector.eval()
    dense_motion_network.eval()
    if bg_predictor:
        bg_predictor.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            kp_source = kp_detector(x['video'][:, :, 0])
            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                kp_driving = kp_detector(driving)
                bg_params = None
                if bg_predictor:
                    bg_params = bg_predictor(source, driving)

                dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
                                                    kp_source=kp_source, bg_param = bg_params,
                                                    dropout_flag = False)
                out = inpainting_network(source, dense_motion)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving

                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

                visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
                                                                                    driving=driving, out=out)
                visualizations.append(visualization)
                loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy()

                loss_list.append(loss)
        # print(np.mean(loss_list))
        predictions = np.concatenate(predictions, axis=1)
        imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))

    print("Reconstruction loss: %s" % np.mean(loss_list))
requirements.txt
ADDED
@@ -0,0 +1,26 @@
cffi==1.14.6
cycler==0.10.0
decorator==5.1.0
face-alignment==1.3.5
imageio==2.9.0
imageio-ffmpeg==0.4.5
kiwisolver==1.3.2
matplotlib==3.4.3
networkx==2.6.3
numpy==1.20.3
pandas==1.3.3
Pillow==8.3.2
pycparser==2.20
pyparsing==2.4.7
python-dateutil==2.8.2
pytz==2021.1
PyWavelets==1.1.1
PyYAML==5.4.1
scikit-image==0.18.3
scikit-learn==1.0
scipy==1.7.1
six==1.16.0
torch==1.11.0
torchvision==0.12.0
tqdm==4.62.3
gradio
run.py
ADDED
@@ -0,0 +1,89 @@
import matplotlib
matplotlib.use('Agg')

import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy
from frames_dataset import FramesDataset

from modules.inpainting_network import InpaintingNetwork
from modules.keypoint_detector import KPDetector
from modules.bg_motion_predictor import BGMotionPredictor
from modules.dense_motion import DenseMotionNetwork
from modules.avd_network import AVDNetwork
import torch
from train import train
from train_avd import train_avd
from reconstruction import reconstruction
import os


if __name__ == "__main__":

    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")

    parser = ArgumentParser()
    parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
    parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"])
    parser.add_argument("--log_dir", default='log', help="path to log into")
    parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
    parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))),
                        help="Names of the devices comma separated.")

    opt = parser.parse_args()
    with open(opt.config) as f:
        config = yaml.load(f)

    if opt.checkpoint is not None:
        log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
    else:
        log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
        log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())

    inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
                                   **config['model_params']['common_params'])

    if torch.cuda.is_available():
        cuda_device = torch.device('cuda:'+str(opt.device_ids[0]))
        inpainting.to(cuda_device)

    kp_detector = KPDetector(**config['model_params']['common_params'])
    dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
                                              **config['model_params']['dense_motion_params'])

    if torch.cuda.is_available():
        kp_detector.to(opt.device_ids[0])
        dense_motion_network.to(opt.device_ids[0])

    bg_predictor = None
    if (config['model_params']['common_params']['bg']):
        bg_predictor = BGMotionPredictor()
        if torch.cuda.is_available():
            bg_predictor.to(opt.device_ids[0])

    avd_network = None
    if opt.mode == "train_avd":
        avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
                                 **config['model_params']['avd_network_params'])
        if torch.cuda.is_available():
            avd_network.to(opt.device_ids[0])

    dataset = FramesDataset(is_train=(opt.mode.startswith('train')), **config['dataset_params'])

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
        copy(opt.config, log_dir)

    if opt.mode == 'train':
        print("Training...")
        train(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
    elif opt.mode == 'train_avd':
        print("Training Animation via Disentanglement...")
        train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network, opt.checkpoint, log_dir, dataset)
    elif opt.mode == 'reconstruction':
        print("Reconstruction...")
        reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
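Typical invocations of this entry point, going by the argparse definitions above (the checkpoint path and GPU ids are illustrative):

python run.py --config config/vox-256.yaml --device_ids 0,1
python run.py --config config/vox-256.yaml --mode reconstruction --checkpoint path/to/checkpoint.pth.tar
python run.py --config config/vox-256.yaml --mode train_avd --checkpoint path/to/checkpoint.pth.tar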
train.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from tqdm import trange
import torch
from torch.utils.data import DataLoader
from logger import Logger
from modules.model import GeneratorFullModel
from torch.optim.lr_scheduler import MultiStepLR
from torch.nn.utils import clip_grad_norm_
from frames_dataset import DatasetRepeater
import math

def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
    train_params = config['train_params']
    optimizer = torch.optim.Adam(
        [{'params': list(inpainting_network.parameters()) +
                    list(dense_motion_network.parameters()) +
                    list(kp_detector.parameters()), 'initial_lr': train_params['lr_generator']}],lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)

    optimizer_bg_predictor = None
    if bg_predictor:
        optimizer_bg_predictor = torch.optim.Adam(
            [{'params':bg_predictor.parameters(),'initial_lr': train_params['lr_generator']}],
            lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)

    if checkpoint is not None:
        start_epoch = Logger.load_cpk(
            checkpoint, inpainting_network = inpainting_network, dense_motion_network = dense_motion_network,
            kp_detector = kp_detector, bg_predictor = bg_predictor,
            optimizer = optimizer, optimizer_bg_predictor = optimizer_bg_predictor)
        print('load success:', start_epoch)
        start_epoch += 1
    else:
        start_epoch = 0

    scheduler_optimizer = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1,
                                      last_epoch=start_epoch - 1)
    if bg_predictor:
        scheduler_bg_predictor = MultiStepLR(optimizer_bg_predictor, train_params['epoch_milestones'],
                                             gamma=0.1, last_epoch=start_epoch - 1)

    if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
                            num_workers=train_params['dataloader_workers'], drop_last=True)

    generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, train_params)

    if torch.cuda.is_available():
        generator_full = torch.nn.DataParallel(generator_full).cuda()

    bg_start = train_params['bg_start']

    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                if(torch.cuda.is_available()):
                    x['driving'] = x['driving'].cuda()
                    x['source'] = x['source'].cuda()

                losses_generator, generated = generator_full(x, epoch)
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)
                loss.backward()

                clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type = math.inf)
                clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type = math.inf)
                if bg_predictor and epoch>=bg_start:
                    clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type = math.inf)

                optimizer.step()
                optimizer.zero_grad()
                if bg_predictor and epoch>=bg_start:
                    optimizer_bg_predictor.step()
                    optimizer_bg_predictor.zero_grad()

                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
                logger.log_iter(losses=losses)

            scheduler_optimizer.step()
            if bg_predictor:
                scheduler_bg_predictor.step()

            model_save = {
                'inpainting_network': inpainting_network,
                'dense_motion_network': dense_motion_network,
                'kp_detector': kp_detector,
                'optimizer': optimizer,
            }
            if bg_predictor and epoch>=bg_start:
                model_save['bg_predictor'] = bg_predictor
                model_save['optimizer_bg_predictor'] = optimizer_bg_predictor

            logger.log_epoch(epoch, model_save, inp=x, out=generated)
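One detail of the loop above worth calling out is the gradient clipping: clip_grad_norm_ is called with norm_type=math.inf, so it bounds the largest absolute gradient entry rather than the usual 2-norm. A small self-contained sketch of that behaviour (toy linear layer and random data, not the project's models):

import math
import torch
from torch.nn.utils import clip_grad_norm_

layer = torch.nn.Linear(4, 4)
# Scale the outputs so the gradients are deliberately large before clipping.
loss = (layer(torch.randn(8, 4)) * 100).pow(2).mean()
loss.backward()

# With norm_type=inf the "total norm" is the max absolute gradient entry; if it
# exceeds max_norm, all gradients are scaled down so that entry is at most max_norm.
total_norm = clip_grad_norm_(layer.parameters(), max_norm=10, norm_type=math.inf)
print(float(total_norm))                                            # norm before clipping
print(max(p.grad.abs().max().item() for p in layer.parameters()))   # <= 10 after clipping

In train.py the same call is applied to the keypoint detector, the dense-motion network and, once bg_start is reached, the background predictor before each optimizer step.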
train_avd.py
ADDED
@@ -0,0 +1,91 @@
from tqdm import trange
import torch
from torch.utils.data import DataLoader
from logger import Logger
from torch.optim.lr_scheduler import MultiStepLR
from frames_dataset import DatasetRepeater


def random_scale(kp_params, scale):
    theta = torch.rand(kp_params['fg_kp'].shape[0], 2) * (2 * scale) + (1 - scale)
    theta = torch.diag_embed(theta).unsqueeze(1).type(kp_params['fg_kp'].type())
    new_kp_params = {'fg_kp': torch.matmul(theta, kp_params['fg_kp'].unsqueeze(-1)).squeeze(-1)}
    return new_kp_params


def train_avd(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network,
              avd_network, checkpoint, log_dir, dataset):
    train_params = config['train_avd_params']

    optimizer = torch.optim.Adam(avd_network.parameters(), lr=train_params['lr'], betas=(0.5, 0.999))

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
                        bg_predictor=bg_predictor, avd_network=avd_network,
                        dense_motion_network= dense_motion_network,optimizer_avd=optimizer)
        start_epoch = 0
    else:
        raise AttributeError("Checkpoint should be specified for mode='train_avd'.")

    scheduler = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1)

    if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])

    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
                            num_workers=train_params['dataloader_workers'], drop_last=True)

    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            avd_network.train()
            for x in dataloader:
                with torch.no_grad():
                    kp_source = kp_detector(x['source'].cuda())
                    kp_driving_gt = kp_detector(x['driving'].cuda())
                    kp_driving_random = random_scale(kp_driving_gt, scale=train_params['random_scale'])
                rec = avd_network(kp_source, kp_driving_random)

                reconstruction_kp = train_params['lambda_shift'] * \
                                    torch.abs(kp_driving_gt['fg_kp'] - rec['fg_kp']).mean()

                loss_dict = {'rec_kp': reconstruction_kp}
                loss = reconstruction_kp

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in loss_dict.items()}
                logger.log_iter(losses=losses)

            # Visualization
            avd_network.eval()
            with torch.no_grad():
                source = x['source'][:6].cuda()
                driving = torch.cat([x['driving'][[0, 1]].cuda(), source[[2, 3, 2, 1]]], dim=0)
                kp_source = kp_detector(source)
                kp_driving = kp_detector(driving)

                out = avd_network(kp_source, kp_driving)
                kp_driving = out
                dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
                                                    kp_source=kp_source)
                generated = inpainting_network(source, dense_motion)

                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})

            scheduler.step(epoch)
            model_save = {
                'inpainting_network': inpainting_network,
                'dense_motion_network': dense_motion_network,
                'kp_detector': kp_detector,
                'avd_network': avd_network,
                'optimizer_avd': optimizer
            }
            if bg_predictor :
                model_save['bg_predictor'] = bg_predictor

            logger.log_epoch(epoch, model_save,
                             inp={'source': source, 'driving': driving},
                             out=generated)
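To make the role of random_scale concrete, here is a minimal stand-alone sketch (the function is copied from train_avd.py above; the keypoint tensor is a dummy stand-in for real detector output): each sample in the batch gets its own random diagonal 2x2 scaling with entries in [1 - scale, 1 + scale], applied identically to all of its keypoints.

import torch

def random_scale(kp_params, scale):
    # Copied from train_avd.py above so the demo runs on its own.
    theta = torch.rand(kp_params['fg_kp'].shape[0], 2) * (2 * scale) + (1 - scale)
    theta = torch.diag_embed(theta).unsqueeze(1).type(kp_params['fg_kp'].type())
    new_kp_params = {'fg_kp': torch.matmul(theta, kp_params['fg_kp'].unsqueeze(-1)).squeeze(-1)}
    return new_kp_params

kp = {'fg_kp': torch.rand(2, 50, 2) * 2 - 1}     # dummy batch of 2 with 50 keypoints in [-1, 1]
scaled = random_scale(kp, scale=0.25)
print(scaled['fg_kp'].shape)                     # shape is preserved: torch.Size([2, 50, 2])
print((scaled['fg_kp'] / kp['fg_kp'])[0, :3])    # per-axis ratios lie in [0.75, 1.25]

These randomly rescaled driving keypoints are what the AVD network is asked to map back to the ground-truth keypoints, which is exactly the rec_kp reconstruction loss optimized in the loop above.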