AlekseyKorshuk committed
Commit f844f44 · 1 Parent(s): 7ea1f4a

feat: updates

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 yoyo-nb
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,100 @@
+ import torch
+ import imageio
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib.animation as animation
+ from skimage.transform import resize
+ import warnings
+ import os
+ from demo import make_animation
+ from skimage import img_as_ubyte
+ from demo import load_checkpoints
+ import gradio
+
+
+ def inference(source_image_path='./assets/source.png', driving_video_path='./assets/driving.mp4', dataset_name="vox"):
+     # edit the config
+     device = torch.device('cpu')
+     # dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']
+     # source_image_path = './assets/source.png'
+     # driving_video_path = './assets/driving.mp4'
+     output_video_path = './generated.mp4'
+
+     pixel = 256  # for vox, taichi and mgif, the resolution is 256*256
+     if (dataset_name == 'ted'):  # for ted, the resolution is 384*384
+         pixel = 384
+     config_path = f'config/{dataset_name}-{pixel}.yaml'
+     checkpoint_path = f'checkpoints/{dataset_name}.pth.tar'
+     predict_mode = 'relative'  # ['standard', 'relative', 'avd']
+
+     warnings.filterwarnings("ignore")
+
+     source_image = imageio.imread(source_image_path)
+     reader = imageio.get_reader(driving_video_path)
+
+     source_image = resize(source_image, (pixel, pixel))[..., :3]
+
+     fps = reader.get_meta_data()['fps']
+     driving_video = []
+     try:
+         for im in reader:
+             driving_video.append(im)
+     except RuntimeError:
+         pass
+     reader.close()
+
+     driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]
+
+     # driving_video = driving_video[:10]
+
+     def display(source, driving, generated=None) -> animation.ArtistAnimation:
+         fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))
+
+         ims = []
+         for i in range(len(driving)):
+             cols = [source]
+             cols.append(driving[i])
+             if generated is not None:
+                 cols.append(generated[i])
+             im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
+             plt.axis('off')
+             ims.append([im])
+
+         ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
+         # plt.show()
+         plt.close()
+         return ani
+
+     inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path=config_path,
+                                                                                   checkpoint_path=checkpoint_path,
+                                                                                   device=device)
+
+     predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network,
+                                  avd_network, device=device, mode=predict_mode)
+
+     # save resulting video
+     imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)
+
+     ani = display(source_image, driving_video, predictions)
+     ani.save('animation.mp4', writer='imagemagick', fps=60)
+     return 'animation.mp4'
+
+
+ demo = gradio.Interface(
+     fn=inference,
+     inputs=[
+         gradio.inputs.Image(type="filepath", label="Input image"),
+         gradio.inputs.Video(label="Input video"),
+         gradio.inputs.Dropdown(['vox', 'taichi', 'ted', 'mgif'], type="value", default="vox", label="Model",
+                                optional=False),
+     ],
+     outputs=["video"],
+     examples=[
+         ['./assets/source.png', './assets/driving.mp4', "vox"],
+         ['./assets/source_ted.png', './assets/driving_ted.mp4', "ted"],
+     ],
+ )
+
+
+ if __name__ == "__main__":
+     demo.launch()
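A quick way to sanity-check the new app.py entry point outside of the Gradio UI is to call inference() directly. The snippet below is only an illustrative sketch: it assumes the vox checkpoint has already been downloaded to checkpoints/vox.pth.tar (checkpoints themselves are not part of this commit) and that the bundled assets from this commit are present.

# Illustrative sketch, not part of this commit: smoke-test inference() locally.
from app import inference

result_path = inference(
    source_image_path='./assets/source.png',    # example image added in this commit
    driving_video_path='./assets/driving.mp4',  # example video added in this commit
    dataset_name='vox',                         # one of 'vox', 'taichi', 'ted', 'mgif'
)
print(result_path)  # 'animation.mp4', written alongside ./generated.mp4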
assets/driving.mp4 ADDED
Binary file (556 kB).
 
assets/driving_ted.mp4 ADDED
Binary file (206 kB).
 
assets/source.png ADDED
assets/source_ted.png ADDED
augmentation.py ADDED
@@ -0,0 +1,344 @@
1
+ """
2
+ Code from https://github.com/hassony2/torch_videovision
3
+ """
4
+
5
+ import numbers
6
+
7
+ import random
8
+ import numpy as np
9
+ import PIL
10
+
11
+ from skimage.transform import resize, rotate
12
+ import torchvision
13
+
14
+ import warnings
15
+
16
+ from skimage import img_as_ubyte, img_as_float
17
+
18
+
19
+ def crop_clip(clip, min_h, min_w, h, w):
20
+ if isinstance(clip[0], np.ndarray):
21
+ cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
22
+
23
+ elif isinstance(clip[0], PIL.Image.Image):
24
+ cropped = [
25
+ img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
26
+ ]
27
+ else:
28
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
29
+ 'but got list of {0}'.format(type(clip[0])))
30
+ return cropped
31
+
32
+
33
+ def pad_clip(clip, h, w):
34
+ im_h, im_w = clip[0].shape[:2]
35
+ pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
36
+ pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
37
+
38
+ return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
39
+
40
+
41
+ def resize_clip(clip, size, interpolation='bilinear'):
42
+ if isinstance(clip[0], np.ndarray):
43
+ if isinstance(size, numbers.Number):
44
+ im_h, im_w, im_c = clip[0].shape
45
+ # Min spatial dim already matches minimal size
46
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
47
+ and im_h == size):
48
+ return clip
49
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
50
+ size = (new_w, new_h)
51
+ else:
52
+ size = size[1], size[0]
53
+
54
+ scaled = [
55
+ resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
56
+ mode='constant', anti_aliasing=True) for img in clip
57
+ ]
58
+ elif isinstance(clip[0], PIL.Image.Image):
59
+ if isinstance(size, numbers.Number):
60
+ im_w, im_h = clip[0].size
61
+ # Min spatial dim already matches minimal size
62
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
63
+ and im_h == size):
64
+ return clip
65
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
66
+ size = (new_w, new_h)
67
+ else:
68
+ size = size[1], size[0]
69
+ if interpolation == 'bilinear':
70
+ pil_inter = PIL.Image.NEAREST
71
+ else:
72
+ pil_inter = PIL.Image.BILINEAR
73
+ scaled = [img.resize(size, pil_inter) for img in clip]
74
+ else:
75
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
76
+ 'but got list of {0}'.format(type(clip[0])))
77
+ return scaled
78
+
79
+
80
+ def get_resize_sizes(im_h, im_w, size):
81
+ if im_w < im_h:
82
+ ow = size
83
+ oh = int(size * im_h / im_w)
84
+ else:
85
+ oh = size
86
+ ow = int(size * im_w / im_h)
87
+ return oh, ow
88
+
89
+
90
+ class RandomFlip(object):
91
+ def __init__(self, time_flip=False, horizontal_flip=False):
92
+ self.time_flip = time_flip
93
+ self.horizontal_flip = horizontal_flip
94
+
95
+ def __call__(self, clip):
96
+ if random.random() < 0.5 and self.time_flip:
97
+ return clip[::-1]
98
+ if random.random() < 0.5 and self.horizontal_flip:
99
+ return [np.fliplr(img) for img in clip]
100
+
101
+ return clip
102
+
103
+
104
+ class RandomResize(object):
105
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
106
+ The larger the original image is, the more times it takes to
107
+ interpolate
108
+ Args:
109
+ interpolation (str): Can be one of 'nearest', 'bilinear'
110
+ defaults to nearest
111
+ size (tuple): (widht, height)
112
+ """
113
+
114
+ def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
115
+ self.ratio = ratio
116
+ self.interpolation = interpolation
117
+
118
+ def __call__(self, clip):
119
+ scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
120
+
121
+ if isinstance(clip[0], np.ndarray):
122
+ im_h, im_w, im_c = clip[0].shape
123
+ elif isinstance(clip[0], PIL.Image.Image):
124
+ im_w, im_h = clip[0].size
125
+
126
+ new_w = int(im_w * scaling_factor)
127
+ new_h = int(im_h * scaling_factor)
128
+ new_size = (new_w, new_h)
129
+ resized = resize_clip(
130
+ clip, new_size, interpolation=self.interpolation)
131
+
132
+ return resized
133
+
134
+
135
+ class RandomCrop(object):
136
+ """Extract random crop at the same location for a list of videos
137
+ Args:
138
+ size (sequence or int): Desired output size for the
139
+ crop in format (h, w)
140
+ """
141
+
142
+ def __init__(self, size):
143
+ if isinstance(size, numbers.Number):
144
+ size = (size, size)
145
+
146
+ self.size = size
147
+
148
+ def __call__(self, clip):
149
+ """
150
+ Args:
151
+ img (PIL.Image or numpy.ndarray): List of videos to be cropped
152
+ in format (h, w, c) in numpy.ndarray
153
+ Returns:
154
+ PIL.Image or numpy.ndarray: Cropped list of videos
155
+ """
156
+ h, w = self.size
157
+ if isinstance(clip[0], np.ndarray):
158
+ im_h, im_w, im_c = clip[0].shape
159
+ elif isinstance(clip[0], PIL.Image.Image):
160
+ im_w, im_h = clip[0].size
161
+ else:
162
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
163
+ 'but got list of {0}'.format(type(clip[0])))
164
+
165
+ clip = pad_clip(clip, h, w)
166
+ im_h, im_w = clip.shape[1:3]
167
+ x1 = 0 if h == im_h else random.randint(0, im_w - w)
168
+ y1 = 0 if w == im_w else random.randint(0, im_h - h)
169
+ cropped = crop_clip(clip, y1, x1, h, w)
170
+
171
+ return cropped
172
+
173
+
174
+ class RandomRotation(object):
175
+ """Rotate entire clip randomly by a random angle within
176
+ given bounds
177
+ Args:
178
+ degrees (sequence or int): Range of degrees to select from
179
+ If degrees is a number instead of sequence like (min, max),
180
+ the range of degrees, will be (-degrees, +degrees).
181
+ """
182
+
183
+ def __init__(self, degrees):
184
+ if isinstance(degrees, numbers.Number):
185
+ if degrees < 0:
186
+ raise ValueError('If degrees is a single number,'
187
+ 'must be positive')
188
+ degrees = (-degrees, degrees)
189
+ else:
190
+ if len(degrees) != 2:
191
+ raise ValueError('If degrees is a sequence,'
192
+ 'it must be of len 2.')
193
+
194
+ self.degrees = degrees
195
+
196
+ def __call__(self, clip):
197
+ """
198
+ Args:
199
+ img (PIL.Image or numpy.ndarray): List of videos to be cropped
200
+ in format (h, w, c) in numpy.ndarray
201
+ Returns:
202
+ PIL.Image or numpy.ndarray: Cropped list of videos
203
+ """
204
+ angle = random.uniform(self.degrees[0], self.degrees[1])
205
+ if isinstance(clip[0], np.ndarray):
206
+ rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
207
+ elif isinstance(clip[0], PIL.Image.Image):
208
+ rotated = [img.rotate(angle) for img in clip]
209
+ else:
210
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
211
+ 'but got list of {0}'.format(type(clip[0])))
212
+
213
+ return rotated
214
+
215
+
216
+ class ColorJitter(object):
217
+ """Randomly change the brightness, contrast and saturation and hue of the clip
218
+ Args:
219
+ brightness (float): How much to jitter brightness. brightness_factor
220
+ is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
221
+ contrast (float): How much to jitter contrast. contrast_factor
222
+ is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
223
+ saturation (float): How much to jitter saturation. saturation_factor
224
+ is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
225
+ hue(float): How much to jitter hue. hue_factor is chosen uniformly from
226
+ [-hue, hue]. Should be >=0 and <= 0.5.
227
+ """
228
+
229
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
230
+ self.brightness = brightness
231
+ self.contrast = contrast
232
+ self.saturation = saturation
233
+ self.hue = hue
234
+
235
+ def get_params(self, brightness, contrast, saturation, hue):
236
+ if brightness > 0:
237
+ brightness_factor = random.uniform(
238
+ max(0, 1 - brightness), 1 + brightness)
239
+ else:
240
+ brightness_factor = None
241
+
242
+ if contrast > 0:
243
+ contrast_factor = random.uniform(
244
+ max(0, 1 - contrast), 1 + contrast)
245
+ else:
246
+ contrast_factor = None
247
+
248
+ if saturation > 0:
249
+ saturation_factor = random.uniform(
250
+ max(0, 1 - saturation), 1 + saturation)
251
+ else:
252
+ saturation_factor = None
253
+
254
+ if hue > 0:
255
+ hue_factor = random.uniform(-hue, hue)
256
+ else:
257
+ hue_factor = None
258
+ return brightness_factor, contrast_factor, saturation_factor, hue_factor
259
+
260
+ def __call__(self, clip):
261
+ """
262
+ Args:
263
+ clip (list): list of PIL.Image
264
+ Returns:
265
+ list PIL.Image : list of transformed PIL.Image
266
+ """
267
+ if isinstance(clip[0], np.ndarray):
268
+ brightness, contrast, saturation, hue = self.get_params(
269
+ self.brightness, self.contrast, self.saturation, self.hue)
270
+
271
+ # Create img transform function sequence
272
+ img_transforms = []
273
+ if brightness is not None:
274
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
275
+ if saturation is not None:
276
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
277
+ if hue is not None:
278
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
279
+ if contrast is not None:
280
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
281
+ random.shuffle(img_transforms)
282
+ img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
283
+ img_as_float]
284
+
285
+ with warnings.catch_warnings():
286
+ warnings.simplefilter("ignore")
287
+ jittered_clip = []
288
+ for img in clip:
289
+ jittered_img = img
290
+ for func in img_transforms:
291
+ jittered_img = func(jittered_img)
292
+ jittered_clip.append(jittered_img.astype('float32'))
293
+ elif isinstance(clip[0], PIL.Image.Image):
294
+ brightness, contrast, saturation, hue = self.get_params(
295
+ self.brightness, self.contrast, self.saturation, self.hue)
296
+
297
+ # Create img transform function sequence
298
+ img_transforms = []
299
+ if brightness is not None:
300
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
301
+ if saturation is not None:
302
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
303
+ if hue is not None:
304
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
305
+ if contrast is not None:
306
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
307
+ random.shuffle(img_transforms)
308
+
309
+ # Apply to all videos
310
+ jittered_clip = []
311
+ for img in clip:
312
+ for func in img_transforms:
313
+ jittered_img = func(img)
314
+ jittered_clip.append(jittered_img)
315
+
316
+ else:
317
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
318
+ 'but got list of {0}'.format(type(clip[0])))
319
+ return jittered_clip
320
+
321
+
322
+ class AllAugmentationTransform:
323
+ def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
324
+ self.transforms = []
325
+
326
+ if flip_param is not None:
327
+ self.transforms.append(RandomFlip(**flip_param))
328
+
329
+ if rotation_param is not None:
330
+ self.transforms.append(RandomRotation(**rotation_param))
331
+
332
+ if resize_param is not None:
333
+ self.transforms.append(RandomResize(**resize_param))
334
+
335
+ if crop_param is not None:
336
+ self.transforms.append(RandomCrop(**crop_param))
337
+
338
+ if jitter_param is not None:
339
+ self.transforms.append(ColorJitter(**jitter_param))
340
+
341
+ def __call__(self, clip):
342
+ for t in self.transforms:
343
+ clip = t(clip)
344
+ return clip
checkpoints/README.md ADDED
@@ -0,0 +1 @@
+ # Checkpoints
config/mgif-256.yaml ADDED
@@ -0,0 +1,75 @@
+ dataset_params:
+   root_dir: ../moving-gif
+   frame_shape: null
+   id_sampling: False
+   augmentation_params:
+     flip_param:
+       horizontal_flip: True
+       time_flip: True
+     crop_param:
+       size: [256, 256]
+     resize_param:
+       ratio: [0.9, 1.1]
+     jitter_param:
+       hue: 0.5
+
+ model_params:
+   common_params:
+     num_tps: 10
+     num_channels: 3
+     bg: False
+     multi_mask: True
+   generator_params:
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 3
+   dense_motion_params:
+     block_expansion: 64
+     max_features: 1024
+     num_blocks: 5
+     scale_factor: 0.25
+   avd_network_params:
+     id_bottle_size: 128
+     pose_bottle_size: 128
+
+
+ train_params:
+   num_epochs: 100
+   num_repeats: 50
+   epoch_milestones: [70, 90]
+   lr_generator: 2.0e-4
+   batch_size: 28
+   scales: [1, 0.5, 0.25, 0.125]
+   dataloader_workers: 12
+   checkpoint_freq: 50
+   dropout_epoch: 35
+   dropout_maxp: 0.5
+   dropout_startp: 0.2
+   dropout_inc_epoch: 10
+   bg_start: 0
+   transform_params:
+     sigma_affine: 0.05
+     sigma_tps: 0.005
+     points_tps: 5
+   loss_weights:
+     perceptual: [10, 10, 10, 10, 10]
+     equivariance_value: 10
+     warp_loss: 10
+     bg: 10
+
+ train_avd_params:
+   num_epochs: 100
+   num_repeats: 50
+   batch_size: 256
+   dataloader_workers: 24
+   checkpoint_freq: 10
+   epoch_milestones: [70, 90]
+   lr: 1.0e-3
+   lambda_shift: 1
+   lambda_affine: 1
+   random_scale: 0.25
+
+ visualizer_params:
+   kp_size: 5
+   draw_border: True
+   colormap: 'gist_rainbow'
config/taichi-256.yaml ADDED
@@ -0,0 +1,134 @@
+ # Dataset parameters
+ # Each dataset should contain 2 folders: train and test
+ # Each video can be represented as:
+ #   - an image of concatenated frames
+ #   - a '.mp4' or '.gif' file
+ #   - a folder with all frames from a specific video
+ # In the case of TaiChi, the same (youtube) video can be split into many parts (chunks). Each part has the following
+ # format: (id)#other#info.mp4. For example '12335#adsbf.mp4' has the id 12335. For TaiChi the id stands for the youtube
+ # video id.
+ dataset_params:
+   # Path to data. Data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
+   root_dir: ../taichi
+   # Image shape, needed for the stacked .png format.
+   frame_shape: null
+   # In the case of TaiChi a single video can be split into many chunks, or there may be several videos for a single person.
+   # In this case an epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
+   # If the name of the video is '12335#adsbf.mp4' the id is assumed to be 12335
+   id_sampling: True
+   # Augmentation parameters, see augmentation.py for all possible augmentations
+   augmentation_params:
+     flip_param:
+       horizontal_flip: True
+       time_flip: True
+     jitter_param:
+       brightness: 0.1
+       contrast: 0.1
+       saturation: 0.1
+       hue: 0.1
+
+ # Defines model architecture
+ model_params:
+   common_params:
+     # Number of TPS transformations
+     num_tps: 10
+     # Number of channels per image
+     num_channels: 3
+     # Whether to estimate affine background transformation
+     bg: True
+     # Whether to estimate the multi-resolution occlusion masks
+     multi_mask: True
+   generator_params:
+     # Number of features multiplier
+     block_expansion: 64
+     # Maximum allowed number of features
+     max_features: 512
+     # Number of downsampling blocks and upsampling blocks.
+     num_down_blocks: 3
+   dense_motion_params:
+     # Number of features multiplier
+     block_expansion: 64
+     # Maximum allowed number of features
+     max_features: 1024
+     # Number of blocks in the Unet.
+     num_blocks: 5
+     # Optical flow is predicted on smaller images for better performance,
+     # scale_factor=0.25 means that a 256x256 image will be resized to 64x64
+     scale_factor: 0.25
+   avd_network_params:
+     # Bottleneck for the identity branch
+     id_bottle_size: 128
+     # Bottleneck for the pose branch
+     pose_bottle_size: 128
+
+ # Parameters of training
+ train_params:
+   # Number of training epochs
+   num_epochs: 100
+   # For better i/o performance when the number of videos is small, the number of epochs can be multiplied by this number.
+   # Thus effectively with num_repeats=100 each epoch is 100 times larger.
+   num_repeats: 150
+   # Drop learning rate by 10 times after these epochs
+   epoch_milestones: [70, 90]
+   # Initial learning rate for all modules
+   lr_generator: 2.0e-4
+   batch_size: 28
+   # Scales for the perceptual pyramid loss. If scales = [1, 0.5, 0.25, 0.125] and the image resolution is 256x256,
+   # then the loss will be computed on resolutions 256x256, 128x128, 64x64, 32x32.
+   scales: [1, 0.5, 0.25, 0.125]
+   # Dataset preprocessing cpu workers
+   dataloader_workers: 12
+   # Save checkpoint this frequently. If checkpoint_freq=50, a checkpoint will be saved every 50 epochs.
+   checkpoint_freq: 50
+   # Parameters of dropout
+   # The first dropout_epoch epochs of training use the dropout operation
+   dropout_epoch: 35
+   # The probability P will linearly increase from dropout_startp to dropout_maxp in dropout_inc_epoch epochs
+   dropout_maxp: 0.7
+   dropout_startp: 0.0
+   dropout_inc_epoch: 10
+   # Estimate affine background transformation from the bg_start epoch.
+   bg_start: 0
+   # Parameters of the random TPS transformation for the equivariance loss
+   transform_params:
+     # Sigma for the affine part
+     sigma_affine: 0.05
+     # Sigma for the deformation part
+     sigma_tps: 0.005
+     # Number of points in the deformation grid
+     points_tps: 5
+   loss_weights:
+     # Weights for the perceptual loss.
+     perceptual: [10, 10, 10, 10, 10]
+     # Weight for value equivariance.
+     equivariance_value: 10
+     # Weight for the warp loss.
+     warp_loss: 10
+     # Weight for the bg loss.
+     bg: 10
+
+ # Parameters of training (animation-via-disentanglement)
+ train_avd_params:
+   # Number of training epochs, visualization is produced after each epoch.
+   num_epochs: 100
+   # For better i/o performance when the number of videos is small, the number of epochs can be multiplied by this number.
+   # Thus effectively with num_repeats=100 each epoch is 100 times larger.
+   num_repeats: 150
+   # Batch size.
+   batch_size: 256
+   # Save checkpoint this frequently. If checkpoint_freq=50, a checkpoint will be saved every 50 epochs.
+   checkpoint_freq: 10
+   # Dataset preprocessing cpu workers
+   dataloader_workers: 24
+   # Drop learning rate by 10 times after these epochs
+   epoch_milestones: [70, 90]
+   # Initial learning rate
+   lr: 1.0e-3
+   # Weight for the equivariance loss.
+   lambda_shift: 1
+   random_scale: 0.25
+
+ visualizer_params:
+   kp_size: 5
+   draw_border: True
+   colormap: 'gist_rainbow'
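The dropout comments in this config describe a linear ramp of the dropout probability P. The helper below is a hypothetical illustration (it is not code from this commit) of how P would evolve under the taichi settings (dropout_startp=0.0, dropout_maxp=0.7, dropout_inc_epoch=10, dropout_epoch=35), assuming the ramp is linear per epoch and dropout is switched off after dropout_epoch.

# Hypothetical helper, for illustration only: the dropout schedule implied by the
# comments above (linear ramp over dropout_inc_epoch, disabled after dropout_epoch).
def dropout_p_for_epoch(epoch, startp=0.0, maxp=0.7, inc_epoch=10, dropout_epoch=35):
    if epoch >= dropout_epoch:
        return 0.0  # dropout is only used during the first dropout_epoch epochs
    return min(maxp, startp + (maxp - startp) * epoch / inc_epoch)

print([round(dropout_p_for_epoch(e), 2) for e in (0, 5, 10, 34, 35)])
# [0.0, 0.35, 0.7, 0.7, 0.0]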
config/ted-384.yaml ADDED
@@ -0,0 +1,73 @@
+ dataset_params:
+   root_dir: ../TED384-v2
+   frame_shape: null
+   id_sampling: True
+   augmentation_params:
+     flip_param:
+       horizontal_flip: True
+       time_flip: True
+     jitter_param:
+       brightness: 0.1
+       contrast: 0.1
+       saturation: 0.1
+       hue: 0.1
+
+ model_params:
+   common_params:
+     num_tps: 10
+     num_channels: 3
+     bg: True
+     multi_mask: True
+   generator_params:
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 3
+   dense_motion_params:
+     block_expansion: 64
+     max_features: 1024
+     num_blocks: 5
+     scale_factor: 0.25
+   avd_network_params:
+     id_bottle_size: 128
+     pose_bottle_size: 128
+
+
+ train_params:
+   num_epochs: 100
+   num_repeats: 150
+   epoch_milestones: [70, 90]
+   lr_generator: 2.0e-4
+   batch_size: 12
+   scales: [1, 0.5, 0.25, 0.125]
+   dataloader_workers: 6
+   checkpoint_freq: 50
+   dropout_epoch: 35
+   dropout_maxp: 0.5
+   dropout_startp: 0.0
+   dropout_inc_epoch: 10
+   bg_start: 0
+   transform_params:
+     sigma_affine: 0.05
+     sigma_tps: 0.005
+     points_tps: 5
+   loss_weights:
+     perceptual: [10, 10, 10, 10, 10]
+     equivariance_value: 10
+     warp_loss: 10
+     bg: 10
+
+ train_avd_params:
+   num_epochs: 30
+   num_repeats: 500
+   batch_size: 256
+   dataloader_workers: 24
+   checkpoint_freq: 10
+   epoch_milestones: [20, 25]
+   lr: 1.0e-3
+   lambda_shift: 1
+   random_scale: 0.25
+
+ visualizer_params:
+   kp_size: 5
+   draw_border: True
+   colormap: 'gist_rainbow'
config/vox-256.yaml ADDED
@@ -0,0 +1,74 @@
+ dataset_params:
+   root_dir: ../vox
+   frame_shape: null
+   id_sampling: True
+   augmentation_params:
+     flip_param:
+       horizontal_flip: True
+       time_flip: True
+     jitter_param:
+       brightness: 0.1
+       contrast: 0.1
+       saturation: 0.1
+       hue: 0.1
+
+
+ model_params:
+   common_params:
+     num_tps: 10
+     num_channels: 3
+     bg: True
+     multi_mask: True
+   generator_params:
+     block_expansion: 64
+     max_features: 512
+     num_down_blocks: 3
+   dense_motion_params:
+     block_expansion: 64
+     max_features: 1024
+     num_blocks: 5
+     scale_factor: 0.25
+   avd_network_params:
+     id_bottle_size: 128
+     pose_bottle_size: 128
+
+
+ train_params:
+   num_epochs: 100
+   num_repeats: 75
+   epoch_milestones: [70, 90]
+   lr_generator: 2.0e-4
+   batch_size: 28
+   scales: [1, 0.5, 0.25, 0.125]
+   dataloader_workers: 12
+   checkpoint_freq: 50
+   dropout_epoch: 35
+   dropout_maxp: 0.3
+   dropout_startp: 0.1
+   dropout_inc_epoch: 10
+   bg_start: 10
+   transform_params:
+     sigma_affine: 0.05
+     sigma_tps: 0.005
+     points_tps: 5
+   loss_weights:
+     perceptual: [10, 10, 10, 10, 10]
+     equivariance_value: 10
+     warp_loss: 10
+     bg: 10
+
+ train_avd_params:
+   num_epochs: 200
+   num_repeats: 300
+   batch_size: 256
+   dataloader_workers: 24
+   checkpoint_freq: 50
+   epoch_milestones: [140, 180]
+   lr: 1.0e-3
+   lambda_shift: 1
+   random_scale: 0.25
+
+ visualizer_params:
+   kp_size: 5
+   draw_border: True
+   colormap: 'gist_rainbow'
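These YAML files are what demo.load_checkpoints() reads before building the networks. A minimal sketch for inspecting one of them (assuming PyYAML is installed):

# Illustrative only: load a config the same way the rest of the code consumes it.
import yaml

with open('config/vox-256.yaml') as f:
    config = yaml.safe_load(f)

print(config['model_params']['common_params'])
# {'num_tps': 10, 'num_channels': 3, 'bg': True, 'multi_mask': True}
print(config['train_params']['epoch_milestones'])  # [70, 90]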
demo.ipynb ADDED
The diff for this file is too large to render.
 
demo.py ADDED
@@ -0,0 +1,176 @@
+ import matplotlib
+ matplotlib.use('Agg')
+ import sys
+ import yaml
+ from argparse import ArgumentParser
+ from tqdm import tqdm
+ from scipy.spatial import ConvexHull
+ import numpy as np
+ import imageio
+ from skimage.transform import resize
+ from skimage import img_as_ubyte
+ import torch
+ from modules.inpainting_network import InpaintingNetwork
+ from modules.keypoint_detector import KPDetector
+ from modules.dense_motion import DenseMotionNetwork
+ from modules.avd_network import AVDNetwork
+
+ if sys.version_info[0] < 3:
+     raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
+
+
+ def relative_kp(kp_source, kp_driving, kp_driving_initial):
+     source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume
+     driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume
+     adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+
+     kp_new = {k: v for k, v in kp_driving.items()}
+
+     kp_value_diff = (kp_driving['fg_kp'] - kp_driving_initial['fg_kp'])
+     kp_value_diff *= adapt_movement_scale
+     kp_new['fg_kp'] = kp_value_diff + kp_source['fg_kp']
+
+     return kp_new
+
+
+ def load_checkpoints(config_path, checkpoint_path, device):
+     with open(config_path) as f:
+         config = yaml.load(f)
+
+     inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
+                                    **config['model_params']['common_params'])
+     kp_detector = KPDetector(**config['model_params']['common_params'])
+     dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
+                                               **config['model_params']['dense_motion_params'])
+     avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
+                              **config['model_params']['avd_network_params'])
+     kp_detector.to(device)
+     dense_motion_network.to(device)
+     inpainting.to(device)
+     avd_network.to(device)
+
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     inpainting.load_state_dict(checkpoint['inpainting_network'])
+     kp_detector.load_state_dict(checkpoint['kp_detector'])
+     dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
+     if 'avd_network' in checkpoint:
+         avd_network.load_state_dict(checkpoint['avd_network'])
+
+     inpainting.eval()
+     kp_detector.eval()
+     dense_motion_network.eval()
+     avd_network.eval()
+
+     return inpainting, kp_detector, dense_motion_network, avd_network
+
+
+ def make_animation(source_image, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, device, mode='relative'):
+     assert mode in ['standard', 'relative', 'avd']
+     with torch.no_grad():
+         predictions = []
+         source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
+         source = source.to(device)
+         driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
+         kp_source = kp_detector(source)
+         kp_driving_initial = kp_detector(driving[:, :, 0])
+
+         for frame_idx in tqdm(range(driving.shape[2])):
+             driving_frame = driving[:, :, frame_idx]
+             driving_frame = driving_frame.to(device)
+             kp_driving = kp_detector(driving_frame)
+             if mode == 'standard':
+                 kp_norm = kp_driving
+             elif mode == 'relative':
+                 kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
+                                       kp_driving_initial=kp_driving_initial)
+             elif mode == 'avd':
+                 kp_norm = avd_network(kp_source, kp_driving)
+             dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm,
+                                                 kp_source=kp_source, bg_param=None,
+                                                 dropout_flag=False)
+             out = inpainting_network(source, dense_motion)
+
+             predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
+     return predictions
+
+
+ def find_best_frame(source, driving, cpu):
+     import face_alignment
+
+     def normalize_kp(kp):
+         kp = kp - kp.mean(axis=0, keepdims=True)
+         area = ConvexHull(kp[:, :2]).volume
+         area = np.sqrt(area)
+         kp[:, :2] = kp[:, :2] / area
+         return kp
+
+     fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
+                                       device='cpu' if cpu else 'cuda')
+     kp_source = fa.get_landmarks(255 * source)[0]
+     kp_source = normalize_kp(kp_source)
+     norm = float('inf')
+     frame_num = 0
+     for i, image in tqdm(enumerate(driving)):
+         kp_driving = fa.get_landmarks(255 * image)[0]
+         kp_driving = normalize_kp(kp_driving)
+         new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
+         if new_norm < norm:
+             norm = new_norm
+             frame_num = i
+     return frame_num
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument("--config", required=True, help="path to config")
+     parser.add_argument("--checkpoint", default='checkpoints/vox.pth.tar', help="path to checkpoint to restore")
+
+     parser.add_argument("--source_image", default='./assets/source.png', help="path to source image")
+     parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video")
+     parser.add_argument("--result_video", default='./result.mp4', help="path to output")
+
+     parser.add_argument("--img_shape", default="256,256", type=lambda x: list(map(int, x.split(','))),
+                         help='Shape of image that the model was trained on.')
+
+     parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'],
+                         help="Animate mode: ['standard', 'relative', 'avd']; when using the relative mode to animate a face, '--find_best_frame' can give a better quality result")
+
+     parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
+                         help="Generate from the frame that is the most aligned with the source. (Only for faces, requires the face_alignment lib)")
+
+     parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
+
+     opt = parser.parse_args()
+
+     source_image = imageio.imread(opt.source_image)
+     reader = imageio.get_reader(opt.driving_video)
+     fps = reader.get_meta_data()['fps']
+     driving_video = []
+     try:
+         for im in reader:
+             driving_video.append(im)
+     except RuntimeError:
+         pass
+     reader.close()
+
+     if opt.cpu:
+         device = torch.device('cpu')
+     else:
+         device = torch.device('cuda')
+
+     source_image = resize(source_image, opt.img_shape)[..., :3]
+     driving_video = [resize(frame, opt.img_shape)[..., :3] for frame in driving_video]
+     inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, device=device)
+
+     if opt.find_best_frame:
+         i = find_best_frame(source_image, driving_video, opt.cpu)
+         print("Best frame: " + str(i))
+         driving_forward = driving_video[i:]
+         driving_backward = driving_video[:(i + 1)][::-1]
+         predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=opt.mode)
+         predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=opt.mode)
+         predictions = predictions_backward[::-1] + predictions_forward[1:]
+     else:
+         predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device=device, mode=opt.mode)
+
+     imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
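A small, self-contained sketch of what the 'relative' mode in demo.py does with keypoints. Random tensors stand in for real detector output, and the shapes assume num_tps=10 (i.e. 10 * 5 = 50 foreground keypoints); it also assumes the full modules/ package from this commit is importable.

# Illustrative only: exercise relative_kp() from demo.py on dummy keypoints.
import torch
from demo import relative_kp

torch.manual_seed(0)
kp_source = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}           # source keypoints in [-1, 1]
kp_driving_initial = {'fg_kp': torch.rand(1, 50, 2) * 2 - 1}  # first driving frame
kp_driving = {'fg_kp': kp_driving_initial['fg_kp'] + 0.05}    # a slightly moved driving frame

kp_new = relative_kp(kp_source, kp_driving, kp_driving_initial)
print(kp_new['fg_kp'].shape)  # torch.Size([1, 50, 2]); the scaled displacement re-applied to the source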
frames_dataset.py ADDED
@@ -0,0 +1,173 @@
+ import os
+ from skimage import io, img_as_float32
+ from skimage.color import gray2rgb
+ from sklearn.model_selection import train_test_split
+ from imageio import mimread
+ from skimage.transform import resize
+ import numpy as np
+ from torch.utils.data import Dataset
+ from augmentation import AllAugmentationTransform
+ import glob
+ from functools import partial
+
+
+ def read_video(name, frame_shape):
+     """
+     Read video which can be:
+       - an image of concatenated frames
+       - '.mp4' and '.gif'
+       - folder with videos
+     """
+     if os.path.isdir(name):
+         frames = sorted(os.listdir(name))
+         num_frames = len(frames)
+         video_array = np.array(
+             [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
+     elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
+         image = io.imread(name)
+
+         if len(image.shape) == 2 or image.shape[2] == 1:
+             image = gray2rgb(image)
+
+         if image.shape[2] == 4:
+             image = image[..., :3]
+
+         image = img_as_float32(image)
+
+         video_array = np.moveaxis(image, 1, 0)
+
+         video_array = video_array.reshape((-1,) + frame_shape)
+         video_array = np.moveaxis(video_array, 1, 2)
+     elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
+         video = mimread(name)
+         if len(video[0].shape) == 2:
+             video = [gray2rgb(frame) for frame in video]
+         if frame_shape is not None:
+             video = np.array([resize(frame, frame_shape) for frame in video])
+         video = np.array(video)
+         if video.shape[-1] == 4:
+             video = video[..., :3]
+         video_array = img_as_float32(video)
+     else:
+         raise Exception("Unknown file extensions %s" % name)
+
+     return video_array
+
+
+ class FramesDataset(Dataset):
+     """
+     Dataset of videos, each video can be represented as:
+       - an image of concatenated frames
+       - '.mp4' or '.gif'
+       - folder with all frames
+     """
+
+     def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
+                  random_seed=0, pairs_list=None, augmentation_params=None):
+         self.root_dir = root_dir
+         self.videos = os.listdir(root_dir)
+         self.frame_shape = frame_shape
+         print(self.frame_shape)
+         self.pairs_list = pairs_list
+         self.id_sampling = id_sampling
+
+         if os.path.exists(os.path.join(root_dir, 'train')):
+             assert os.path.exists(os.path.join(root_dir, 'test'))
+             print("Use predefined train-test split.")
+             if id_sampling:
+                 train_videos = {os.path.basename(video).split('#')[0] for video in
+                                 os.listdir(os.path.join(root_dir, 'train'))}
+                 train_videos = list(train_videos)
+             else:
+                 train_videos = os.listdir(os.path.join(root_dir, 'train'))
+             test_videos = os.listdir(os.path.join(root_dir, 'test'))
+             self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
+         else:
+             print("Use random train-test split.")
+             train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
+
+         if is_train:
+             self.videos = train_videos
+         else:
+             self.videos = test_videos
+
+         self.is_train = is_train
+
+         if self.is_train:
+             self.transform = AllAugmentationTransform(**augmentation_params)
+         else:
+             self.transform = None
+
+     def __len__(self):
+         return len(self.videos)
+
+     def __getitem__(self, idx):
+         if self.is_train and self.id_sampling:
+             name = self.videos[idx]
+             path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
+         else:
+             name = self.videos[idx]
+             path = os.path.join(self.root_dir, name)
+
+         video_name = os.path.basename(path)
+         if self.is_train and os.path.isdir(path):
+             frames = os.listdir(path)
+             num_frames = len(frames)
+             frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
+
+             if self.frame_shape is not None:
+                 resize_fn = partial(resize, output_shape=self.frame_shape)
+             else:
+                 resize_fn = img_as_float32
+
+             if type(frames[0]) is bytes:
+                 video_array = [resize_fn(io.imread(os.path.join(path, frames[idx].decode('utf-8')))) for idx in
+                                frame_idx]
+             else:
+                 video_array = [resize_fn(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
+         else:
+             video_array = read_video(path, frame_shape=self.frame_shape)
+
+             num_frames = len(video_array)
+             frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
+                 num_frames)
+             video_array = video_array[frame_idx]
+
+         if self.transform is not None:
+             video_array = self.transform(video_array)
+
+         out = {}
+         if self.is_train:
+             source = np.array(video_array[0], dtype='float32')
+             driving = np.array(video_array[1], dtype='float32')
+
+             out['driving'] = driving.transpose((2, 0, 1))
+             out['source'] = source.transpose((2, 0, 1))
+         else:
+             video = np.array(video_array, dtype='float32')
+             out['video'] = video.transpose((3, 0, 1, 2))
+
+         out['name'] = video_name
+         return out
+
+
+ class DatasetRepeater(Dataset):
+     """
+     Pass several times over the same dataset for better i/o performance
+     """
+
+     def __init__(self, dataset, num_repeats=100):
+         self.dataset = dataset
+         self.num_repeats = num_repeats
+
+     def __len__(self):
+         return self.num_repeats * self.dataset.__len__()
+
+     def __getitem__(self, idx):
+         return self.dataset[idx % self.dataset.__len__()]
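A hedged sketch of how FramesDataset and DatasetRepeater would typically be wired into a training DataLoader from a config's dataset_params and train_params blocks. The training script itself is not part of the portion shown here, and the snippet assumes the dataset pointed to by root_dir exists locally.

# Illustrative only: build a training loader from a config, assuming ../vox exists on disk.
import yaml
from torch.utils.data import DataLoader
from frames_dataset import FramesDataset, DatasetRepeater

with open('config/vox-256.yaml') as f:
    config = yaml.safe_load(f)

dataset = FramesDataset(is_train=True, **config['dataset_params'])
dataset = DatasetRepeater(dataset, num_repeats=config['train_params']['num_repeats'])
loader = DataLoader(dataset, batch_size=config['train_params']['batch_size'],
                    shuffle=True, num_workers=config['train_params']['dataloader_workers'])

batch = next(iter(loader))
print(batch['source'].shape, batch['driving'].shape)  # pairs of (B, 3, H, W) frames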
generated.mp4 ADDED
Binary file (11.3 kB).
 
logger.py ADDED
@@ -0,0 +1,212 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import imageio
5
+
6
+ import os
7
+ from skimage.draw import circle
8
+
9
+ import matplotlib.pyplot as plt
10
+ import collections
11
+
12
+
13
+ class Logger:
14
+ def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
15
+
16
+ self.loss_list = []
17
+ self.cpk_dir = log_dir
18
+ self.visualizations_dir = os.path.join(log_dir, 'train-vis')
19
+ if not os.path.exists(self.visualizations_dir):
20
+ os.makedirs(self.visualizations_dir)
21
+ self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
22
+ self.zfill_num = zfill_num
23
+ self.visualizer = Visualizer(**visualizer_params)
24
+ self.checkpoint_freq = checkpoint_freq
25
+ self.epoch = 0
26
+ self.best_loss = float('inf')
27
+ self.names = None
28
+
29
+ def log_scores(self, loss_names):
30
+ loss_mean = np.array(self.loss_list).mean(axis=0)
31
+
32
+ loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
33
+ loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
34
+
35
+ print(loss_string, file=self.log_file)
36
+ self.loss_list = []
37
+ self.log_file.flush()
38
+
39
+ def visualize_rec(self, inp, out):
40
+ image = self.visualizer.visualize(inp['driving'], inp['source'], out)
41
+ imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
42
+
43
+ def save_cpk(self, emergent=False):
44
+ cpk = {k: v.state_dict() for k, v in self.models.items()}
45
+ cpk['epoch'] = self.epoch
46
+ cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
47
+ if not (os.path.exists(cpk_path) and emergent):
48
+ torch.save(cpk, cpk_path)
49
+
50
+ @staticmethod
51
+ def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network =None, kp_detector=None,
52
+ bg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None,
53
+ optimizer_avd=None):
54
+ checkpoint = torch.load(checkpoint_path)
55
+ if inpainting_network is not None:
56
+ inpainting_network.load_state_dict(checkpoint['inpainting_network'])
57
+ if kp_detector is not None:
58
+ kp_detector.load_state_dict(checkpoint['kp_detector'])
59
+ if bg_predictor is not None and 'bg_predictor' in checkpoint:
60
+ bg_predictor.load_state_dict(checkpoint['bg_predictor'])
61
+ if dense_motion_network is not None:
62
+ dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
63
+ if avd_network is not None:
64
+ if 'avd_network' in checkpoint:
65
+ avd_network.load_state_dict(checkpoint['avd_network'])
66
+ if optimizer_bg_predictor is not None and 'optimizer_bg_predictor' in checkpoint:
67
+ optimizer_bg_predictor.load_state_dict(checkpoint['optimizer_bg_predictor'])
68
+ if optimizer is not None and 'optimizer' in checkpoint:
69
+ optimizer.load_state_dict(checkpoint['optimizer'])
70
+ if optimizer_avd is not None:
71
+ if 'optimizer_avd' in checkpoint:
72
+ optimizer_avd.load_state_dict(checkpoint['optimizer_avd'])
73
+ epoch = -1
74
+ if 'epoch' in checkpoint:
75
+ epoch = checkpoint['epoch']
76
+ return epoch
77
+
78
+ def __enter__(self):
79
+ return self
80
+
81
+ def __exit__(self):
82
+ if 'models' in self.__dict__:
83
+ self.save_cpk()
84
+ self.log_file.close()
85
+
86
+ def log_iter(self, losses):
87
+ losses = collections.OrderedDict(losses.items())
88
+ self.names = list(losses.keys())
89
+ self.loss_list.append(list(losses.values()))
90
+
91
+ def log_epoch(self, epoch, models, inp, out):
92
+ self.epoch = epoch
93
+ self.models = models
94
+ if (self.epoch + 1) % self.checkpoint_freq == 0:
95
+ self.save_cpk()
96
+ self.log_scores(self.names)
97
+ self.visualize_rec(inp, out)
98
+
99
+
100
+ class Visualizer:
101
+ def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
102
+ self.kp_size = kp_size
103
+ self.draw_border = draw_border
104
+ self.colormap = plt.get_cmap(colormap)
105
+
106
+ def draw_image_with_kp(self, image, kp_array):
107
+ image = np.copy(image)
108
+ spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
109
+ kp_array = spatial_size * (kp_array + 1) / 2
110
+ num_kp = kp_array.shape[0]
111
+ for kp_ind, kp in enumerate(kp_array):
112
+ rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
113
+ image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
114
+ return image
115
+
116
+ def create_image_column_with_kp(self, images, kp):
117
+ image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
118
+ return self.create_image_column(image_array)
119
+
120
+ def create_image_column(self, images):
121
+ if self.draw_border:
122
+ images = np.copy(images)
123
+ images[:, :, [0, -1]] = (1, 1, 1)
124
+ images[:, :, [0, -1]] = (1, 1, 1)
125
+ return np.concatenate(list(images), axis=0)
126
+
127
+ def create_image_grid(self, *args):
128
+ out = []
129
+ for arg in args:
130
+ if type(arg) == tuple:
131
+ out.append(self.create_image_column_with_kp(arg[0], arg[1]))
132
+ else:
133
+ out.append(self.create_image_column(arg))
134
+ return np.concatenate(out, axis=1)
135
+
136
+ def visualize(self, driving, source, out):
137
+ images = []
138
+
139
+ # Source image with keypoints
140
+ source = source.data.cpu()
141
+ kp_source = out['kp_source']['fg_kp'].data.cpu().numpy()
142
+ source = np.transpose(source, [0, 2, 3, 1])
143
+ images.append((source, kp_source))
144
+
145
+ # Equivariance visualization
146
+ if 'transformed_frame' in out:
147
+ transformed = out['transformed_frame'].data.cpu().numpy()
148
+ transformed = np.transpose(transformed, [0, 2, 3, 1])
149
+ transformed_kp = out['transformed_kp']['fg_kp'].data.cpu().numpy()
150
+ images.append((transformed, transformed_kp))
151
+
152
+ # Driving image with keypoints
153
+ kp_driving = out['kp_driving']['fg_kp'].data.cpu().numpy()
154
+ driving = driving.data.cpu().numpy()
155
+ driving = np.transpose(driving, [0, 2, 3, 1])
156
+ images.append((driving, kp_driving))
157
+
158
+ # Deformed image
159
+ if 'deformed' in out:
160
+ deformed = out['deformed'].data.cpu().numpy()
161
+ deformed = np.transpose(deformed, [0, 2, 3, 1])
162
+ images.append(deformed)
163
+
164
+ # Result with and without keypoints
165
+ prediction = out['prediction'].data.cpu().numpy()
166
+ prediction = np.transpose(prediction, [0, 2, 3, 1])
167
+ if 'kp_norm' in out:
168
+ kp_norm = out['kp_norm']['fg_kp'].data.cpu().numpy()
169
+ images.append((prediction, kp_norm))
170
+ images.append(prediction)
171
+
172
+
173
+ ## Occlusion map
174
+ if 'occlusion_map' in out:
175
+ for i in range(len(out['occlusion_map'])):
176
+ occlusion_map = out['occlusion_map'][i].data.cpu().repeat(1, 3, 1, 1)
177
+ occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
178
+ occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
179
+ images.append(occlusion_map)
180
+
181
+ # Deformed images according to each individual transform
182
+ if 'deformed_source' in out:
183
+ full_mask = []
184
+ for i in range(out['deformed_source'].shape[1]):
185
+ image = out['deformed_source'][:, i].data.cpu()
186
+ # import ipdb;ipdb.set_trace()
187
+ image = F.interpolate(image, size=source.shape[1:3])
188
+ mask = out['contribution_maps'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
189
+ mask = F.interpolate(mask, size=source.shape[1:3])
190
+ image = np.transpose(image.numpy(), (0, 2, 3, 1))
191
+ mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
192
+
193
+ if i != 0:
194
+ color = np.array(self.colormap((i - 1) / (out['deformed_source'].shape[1] - 1)))[:3]
195
+ else:
196
+ color = np.array((0, 0, 0))
197
+
198
+ color = color.reshape((1, 1, 1, 3))
199
+
200
+ images.append(image)
201
+ if i != 0:
202
+ images.append(mask * color)
203
+ else:
204
+ images.append(mask)
205
+
206
+ full_mask.append(mask * color)
207
+
208
+ images.append(sum(full_mask))
209
+
210
+ image = self.create_image_grid(*images)
211
+ image = (255 * image).astype(np.uint8)
212
+ return image
modules/avd_network.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ from torch import nn
+
+
+ class AVDNetwork(nn.Module):
+     """
+     Animation via Disentanglement network
+     """
+
+     def __init__(self, num_tps, id_bottle_size=64, pose_bottle_size=64):
+         super(AVDNetwork, self).__init__()
+         input_size = 5 * 2 * num_tps
+         self.num_tps = num_tps
+
+         self.id_encoder = nn.Sequential(
+             nn.Linear(input_size, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(inplace=True),
+             nn.Linear(256, 512),
+             nn.BatchNorm1d(512),
+             nn.ReLU(inplace=True),
+             nn.Linear(512, 1024),
+             nn.BatchNorm1d(1024),
+             nn.ReLU(inplace=True),
+             nn.Linear(1024, id_bottle_size)
+         )
+
+         self.pose_encoder = nn.Sequential(
+             nn.Linear(input_size, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(inplace=True),
+             nn.Linear(256, 512),
+             nn.BatchNorm1d(512),
+             nn.ReLU(inplace=True),
+             nn.Linear(512, 1024),
+             nn.BatchNorm1d(1024),
+             nn.ReLU(inplace=True),
+             nn.Linear(1024, pose_bottle_size)
+         )
+
+         self.decoder = nn.Sequential(
+             nn.Linear(pose_bottle_size + id_bottle_size, 1024),
+             nn.BatchNorm1d(1024),
+             nn.ReLU(),
+             nn.Linear(1024, 512),
+             nn.BatchNorm1d(512),
+             nn.ReLU(),
+             nn.Linear(512, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(),
+             nn.Linear(256, input_size)
+         )
+
+     def forward(self, kp_source, kp_random):
+         bs = kp_source['fg_kp'].shape[0]
+
+         pose_emb = self.pose_encoder(kp_random['fg_kp'].view(bs, -1))
+         id_emb = self.id_encoder(kp_source['fg_kp'].view(bs, -1))
+
+         rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))
+
+         rec = {'fg_kp': rec.view(bs, self.num_tps * 5, -1)}
+         return rec
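A minimal, illustrative shape check for AVDNetwork: it recombines the identity (source) and pose (random/driving) keypoint sets into a new keypoint set of the same shape. Random keypoints are used here, with num_tps=10 so each set has 10 * 5 = 50 points.

# Illustrative only: AVDNetwork input/output shapes on dummy keypoint dictionaries.
import torch
from modules.avd_network import AVDNetwork

net = AVDNetwork(num_tps=10, id_bottle_size=128, pose_bottle_size=128).eval()
kp_source = {'fg_kp': torch.rand(2, 50, 2)}
kp_random = {'fg_kp': torch.rand(2, 50, 2)}
with torch.no_grad():
    rec = net(kp_source, kp_random)
print(rec['fg_kp'].shape)  # torch.Size([2, 50, 2])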
modules/bg_motion_predictor.py ADDED
@@ -0,0 +1,24 @@
+ from torch import nn
+ import torch
+ from torchvision import models
+
+
+ class BGMotionPredictor(nn.Module):
+     """
+     Module for background estimation. Returns a single transformation, parametrized as a 3x3 matrix. The third row is [0 0 1].
+     """
+
+     def __init__(self):
+         super(BGMotionPredictor, self).__init__()
+         self.bg_encoder = models.resnet18(pretrained=False)
+         self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+         num_features = self.bg_encoder.fc.in_features
+         self.bg_encoder.fc = nn.Linear(num_features, 6)
+         self.bg_encoder.fc.weight.data.zero_()
+         self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
+
+     def forward(self, source_image, driving_image):
+         bs = source_image.shape[0]
+         out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type())
+         prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1))
+         out[:, :2, :] = prediction.view(bs, 2, 3)
+         return out
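And a similar illustrative check for BGMotionPredictor, which should return one 3x3 affine matrix per batch element when given a pair of source and driving frames.

# Illustrative only: BGMotionPredictor consumes a (source, driving) pair of frames.
import torch
from modules.bg_motion_predictor import BGMotionPredictor

bg_predictor = BGMotionPredictor().eval()
source = torch.rand(2, 3, 256, 256)
driving = torch.rand(2, 3, 256, 256)
with torch.no_grad():
    bg_param = bg_predictor(source, driving)
print(bg_param.shape)  # torch.Size([2, 3, 3]); the last row of each matrix is [0, 0, 1]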
modules/dense_motion.py ADDED
@@ -0,0 +1,164 @@
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
5
+ from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS
6
+ import math
7
+
8
+ class DenseMotionNetwork(nn.Module):
9
+ """
10
+ Module that estimating an optical flow and multi-resolution occlusion masks
11
+ from K TPS transformations and an affine transformation.
12
+ """
13
+
14
+ def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels,
15
+ scale_factor=0.25, bg = False, multi_mask = True, kp_variance=0.01):
16
+ super(DenseMotionNetwork, self).__init__()
17
+
18
+ if scale_factor != 1:
19
+ self.down = AntiAliasInterpolation2d(num_channels, scale_factor)
20
+ self.scale_factor = scale_factor
21
+ self.multi_mask = multi_mask
22
+
23
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_channels * (num_tps+1) + num_tps*5+1),
24
+ max_features=max_features, num_blocks=num_blocks)
25
+
26
+ hourglass_output_size = self.hourglass.out_channels
27
+ self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3))
28
+
29
+ if multi_mask:
30
+ up = []
31
+ self.up_nums = int(math.log(1/scale_factor, 2))
32
+ self.occlusion_num = 4
33
+
34
+ channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)]
35
+ for i in range(self.up_nums):
36
+ up.append(UpBlock2d(channel[i], channel[i]//2, kernel_size=3, padding=1))
37
+ self.up = nn.ModuleList(up)
38
+
39
+ channel = [hourglass_output_size[-i-1] for i in range(self.occlusion_num-self.up_nums)[::-1]]
40
+ for i in range(self.up_nums):
41
+ channel.append(hourglass_output_size[-1]//(2**(i+1)))
42
+ occlusion = []
43
+
44
+ for i in range(self.occlusion_num):
45
+ occlusion.append(nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3)))
46
+ self.occlusion = nn.ModuleList(occlusion)
47
+ else:
48
+ occlusion = [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))]
49
+ self.occlusion = nn.ModuleList(occlusion)
50
+
51
+ self.num_tps = num_tps
52
+ self.bg = bg
53
+ self.kp_variance = kp_variance
54
+
55
+
56
+ def create_heatmap_representations(self, source_image, kp_driving, kp_source):
57
+
58
+ spatial_size = source_image.shape[2:]
59
+ gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
60
+ gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
61
+ heatmap = gaussian_driving - gaussian_source
62
+
63
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()).to(heatmap.device)
64
+ heatmap = torch.cat([zeros, heatmap], dim=1)
65
+
66
+ return heatmap
67
+
68
+ def create_transformations(self, source_image, kp_driving, kp_source, bg_param):
69
+ # K TPS transformations
70
+ bs, _, h, w = source_image.shape
71
+ kp_1 = kp_driving['fg_kp']
72
+ kp_2 = kp_source['fg_kp']
73
+ kp_1 = kp_1.view(bs, -1, 5, 2)
74
+ kp_2 = kp_2.view(bs, -1, 5, 2)
75
+ trans = TPS(mode = 'kp', bs = bs, kp_1 = kp_1, kp_2 = kp_2)
76
+ driving_to_source = trans.transform_frame(source_image)
77
+
78
+ identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device)
79
+ identity_grid = identity_grid.view(1, 1, h, w, 2)
80
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
81
+
82
+ # affine background transformation
83
+ if not (bg_param is None):
84
+ identity_grid = to_homogeneous(identity_grid)
85
+ identity_grid = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid.unsqueeze(-1)).squeeze(-1)
86
+ identity_grid = from_homogeneous(identity_grid)
87
+
88
+ transformations = torch.cat([identity_grid, driving_to_source], dim=1)
89
+ return transformations
90
+
91
+ def create_deformed_source_image(self, source_image, transformations):
92
+
93
+ bs, _, h, w = source_image.shape
94
+ source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_tps + 1, 1, 1, 1, 1)
95
+ source_repeat = source_repeat.view(bs * (self.num_tps + 1), -1, h, w)
96
+ transformations = transformations.view((bs * (self.num_tps + 1), h, w, -1))
97
+ deformed = F.grid_sample(source_repeat, transformations, align_corners=True)
98
+ deformed = deformed.view((bs, self.num_tps+1, -1, h, w))
99
+ return deformed
100
+
101
+ def dropout_softmax(self, X, P):
102
+ '''
103
+ Dropout for TPS transformations. Eq(7) and Eq(8) in the paper.
104
+ '''
105
+ drop = (torch.rand(X.shape[0],X.shape[1]) < (1-P)).type(X.type()).to(X.device)
106
+ drop[..., 0] = 1
107
+ drop = drop.repeat(X.shape[2],X.shape[3],1,1).permute(2,3,0,1)
108
+
109
+ maxx = X.max(1).values.unsqueeze_(1)
110
+ X = X - maxx
111
+ X_exp = X.exp()
112
+ X[:,1:,...] /= (1-P)
113
+ mask_bool =(drop == 0)
114
+ X_exp = X_exp.masked_fill(mask_bool, 0)
115
+ partition = X_exp.sum(dim=1, keepdim=True) + 1e-6
116
+ return X_exp / partition
117
+
118
+ def forward(self, source_image, kp_driving, kp_source, bg_param = None, dropout_flag=False, dropout_p = 0):
119
+ if self.scale_factor != 1:
120
+ source_image = self.down(source_image)
121
+
122
+ bs, _, h, w = source_image.shape
123
+
124
+ out_dict = dict()
125
+ heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
126
+ transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param)
127
+ deformed_source = self.create_deformed_source_image(source_image, transformations)
128
+ out_dict['deformed_source'] = deformed_source
129
+ # out_dict['transformations'] = transformations
130
+ deformed_source = deformed_source.view(bs,-1,h,w)
131
+ input = torch.cat([heatmap_representation, deformed_source], dim=1)
132
+ input = input.view(bs, -1, h, w)
133
+
134
+ prediction = self.hourglass(input, mode = 1)
135
+
136
+ contribution_maps = self.maps(prediction[-1])
137
+ if(dropout_flag):
138
+ contribution_maps = self.dropout_softmax(contribution_maps, dropout_p)
139
+ else:
140
+ contribution_maps = F.softmax(contribution_maps, dim=1)
141
+ out_dict['contribution_maps'] = contribution_maps
142
+
143
+ # Combine the K+1 transformations
144
+ # Eq(6) in the paper
145
+ contribution_maps = contribution_maps.unsqueeze(2)
146
+ transformations = transformations.permute(0, 1, 4, 2, 3)
147
+ deformation = (transformations * contribution_maps).sum(dim=1)
148
+ deformation = deformation.permute(0, 2, 3, 1)
149
+
150
+ out_dict['deformation'] = deformation # Optical Flow
151
+
152
+ occlusion_map = []
153
+ if self.multi_mask:
154
+ for i in range(self.occlusion_num-self.up_nums):
155
+ occlusion_map.append(torch.sigmoid(self.occlusion[i](prediction[self.up_nums-self.occlusion_num+i])))
156
+ prediction = prediction[-1]
157
+ for i in range(self.up_nums):
158
+ prediction = self.up[i](prediction)
159
+ occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction)))
160
+ else:
161
+ occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1])))
162
+
163
+ out_dict['occlusion_map'] = occlusion_map # Multi-resolution Occlusion Masks
164
+ return out_dict
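A small self-contained sketch, not part of this commit, of the Eq(6) combination performed at the end of forward(): the K+1 candidate flows (the identity/background grid plus K TPS flows) are blended per pixel by the softmaxed contribution maps into one dense optical flow. All tensor sizes below are arbitrary assumptions.

    # Hypothetical illustration of Eq(6): per-pixel blending of K+1 candidate flows.
    import torch
    import torch.nn.functional as F

    bs, K, h, w = 2, 10, 64, 64                        # assumed sizes
    transformations = torch.rand(bs, K + 1, h, w, 2)   # identity/background + K TPS flows
    contribution_maps = F.softmax(torch.rand(bs, K + 1, h, w), dim=1)

    flows = transformations.permute(0, 1, 4, 2, 3)                     # (bs, K+1, 2, h, w)
    deformation = (flows * contribution_maps.unsqueeze(2)).sum(dim=1)  # (bs, 2, h, w)
    deformation = deformation.permute(0, 2, 3, 1)                      # (bs, h, w, 2), ready for grid_sample
    print(deformation.shape)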
modules/inpainting_network.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
5
+ from modules.dense_motion import DenseMotionNetwork
6
+
7
+
8
+ class InpaintingNetwork(nn.Module):
9
+ """
10
+ Inpaint the missing regions and reconstruct the Driving image.
11
+ """
12
+ def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, **kwargs):
13
+ super(InpaintingNetwork, self).__init__()
14
+
15
+ self.num_down_blocks = num_down_blocks
16
+ self.multi_mask = multi_mask
17
+ self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
18
+
19
+ down_blocks = []
20
+ for i in range(num_down_blocks):
21
+ in_features = min(max_features, block_expansion * (2 ** i))
22
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
23
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
24
+ self.down_blocks = nn.ModuleList(down_blocks)
25
+
26
+ up_blocks = []
27
+ in_features = [max_features, max_features, max_features//2]
28
+ out_features = [max_features//2, max_features//4, max_features//8]
29
+ for i in range(num_down_blocks):
30
+ up_blocks.append(UpBlock2d(in_features[i], out_features[i], kernel_size=(3, 3), padding=(1, 1)))
31
+ self.up_blocks = nn.ModuleList(up_blocks)
32
+
33
+ resblock = []
34
+ for i in range(num_down_blocks):
35
+ resblock.append(ResBlock2d(in_features[i], kernel_size=(3, 3), padding=(1, 1)))
36
+ resblock.append(ResBlock2d(in_features[i], kernel_size=(3, 3), padding=(1, 1)))
37
+ self.resblock = nn.ModuleList(resblock)
38
+
39
+ self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
40
+ self.num_channels = num_channels
41
+
42
+ def deform_input(self, inp, deformation):
43
+ _, h_old, w_old, _ = deformation.shape
44
+ _, _, h, w = inp.shape
45
+ if h_old != h or w_old != w:
46
+ deformation = deformation.permute(0, 3, 1, 2)
47
+ deformation = F.interpolate(deformation, size=(h, w), mode='bilinear', align_corners=True)
48
+ deformation = deformation.permute(0, 2, 3, 1)
49
+ return F.grid_sample(inp, deformation,align_corners=True)
50
+
51
+ def occlude_input(self, inp, occlusion_map):
52
+ if not self.multi_mask:
53
+ if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]:
54
+ occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear',align_corners=True)
55
+ out = inp * occlusion_map
56
+ return out
57
+
58
+ def forward(self, source_image, dense_motion):
59
+ out = self.first(source_image)
60
+ encoder_map = [out]
61
+ for i in range(len(self.down_blocks)):
62
+ out = self.down_blocks[i](out)
63
+ encoder_map.append(out)
64
+
65
+ output_dict = {}
66
+ output_dict['contribution_maps'] = dense_motion['contribution_maps']
67
+ output_dict['deformed_source'] = dense_motion['deformed_source']
68
+
69
+ occlusion_map = dense_motion['occlusion_map']
70
+ output_dict['occlusion_map'] = occlusion_map
71
+
72
+ deformation = dense_motion['deformation']
73
+ out_ij = self.deform_input(out.detach(), deformation)
74
+ out = self.deform_input(out, deformation)
75
+
76
+ out_ij = self.occlude_input(out_ij, occlusion_map[0].detach())
77
+ out = self.occlude_input(out, occlusion_map[0])
78
+
79
+ warped_encoder_maps = []
80
+ warped_encoder_maps.append(out_ij)
81
+
82
+ for i in range(self.num_down_blocks):
83
+
84
+ out = self.resblock[2*i](out)
85
+ out = self.resblock[2*i+1](out)
86
+ out = self.up_blocks[i](out)
87
+
88
+ encode_i = encoder_map[-(i+2)]
89
+ encode_ij = self.deform_input(encode_i.detach(), deformation)
90
+ encode_i = self.deform_input(encode_i, deformation)
91
+
92
+ occlusion_ind = 0
93
+ if self.multi_mask:
94
+ occlusion_ind = i+1
95
+ encode_ij = self.occlude_input(encode_ij, occlusion_map[occlusion_ind].detach())
96
+ encode_i = self.occlude_input(encode_i, occlusion_map[occlusion_ind])
97
+ warped_encoder_maps.append(encode_ij)
98
+
99
+ if(i==self.num_down_blocks-1):
100
+ break
101
+
102
+ out = torch.cat([out, encode_i], 1)
103
+
104
+ deformed_source = self.deform_input(source_image, deformation)
105
+ output_dict["deformed"] = deformed_source
106
+ output_dict["warped_encoder_maps"] = warped_encoder_maps
107
+
108
+ occlusion_last = occlusion_map[-1]
109
+ if not self.multi_mask:
110
+ occlusion_last = F.interpolate(occlusion_last, size=out.shape[2:], mode='bilinear',align_corners=True)
111
+
112
+ out = out * (1 - occlusion_last) + encode_i
113
+ out = self.final(out)
114
+ out = torch.sigmoid(out)
115
+ out = out * (1 - occlusion_last) + deformed_source * occlusion_last
116
+ output_dict["prediction"] = out
117
+
118
+ return output_dict
119
+
120
+ def get_encode(self, driver_image, occlusion_map):
121
+ out = self.first(driver_image)
122
+ encoder_map = []
123
+ encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach()))
124
+ for i in range(len(self.down_blocks)):
125
+ out = self.down_blocks[i](out.detach())
126
+ out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach())
127
+ encoder_map.append(out_mask.detach())
128
+
129
+ return encoder_map
130
+
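For orientation, a sketch (not part of this commit) of the last blending step in InpaintingNetwork.forward(): the finest occlusion mask decides, per pixel, whether the output comes from the warped source (where the flow explains the pixel) or from the decoder's inpainted estimate (where it does not). The tensors below are random placeholders with assumed sizes.

    # Hypothetical illustration of the final composition in forward().
    import torch

    bs, c, h, w = 1, 3, 256, 256
    decoded = torch.rand(bs, c, h, w)          # sigmoid output of self.final (assumed)
    deformed_source = torch.rand(bs, c, h, w)  # source warped by the dense flow (assumed)
    occlusion_last = torch.rand(bs, 1, h, w)   # finest occlusion mask in [0, 1] (assumed)

    prediction = decoded * (1 - occlusion_last) + deformed_source * occlusion_last
    print(prediction.shape)                    # torch.Size([1, 3, 256, 256])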
modules/keypoint_detector.py ADDED
@@ -0,0 +1,27 @@
1
+ from torch import nn
2
+ import torch
3
+ from torchvision import models
4
+
5
+ class KPDetector(nn.Module):
6
+ """
7
+ Predict K*5 keypoints.
8
+ """
9
+
10
+ def __init__(self, num_tps, **kwargs):
11
+ super(KPDetector, self).__init__()
12
+ self.num_tps = num_tps
13
+
14
+ self.fg_encoder = models.resnet18(pretrained=False)
15
+ num_features = self.fg_encoder.fc.in_features
16
+ self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2)
17
+
18
+
19
+ def forward(self, image):
20
+
21
+ fg_kp = self.fg_encoder(image)
22
+ bs, _, = fg_kp.shape
23
+ fg_kp = torch.sigmoid(fg_kp)
24
+ fg_kp = fg_kp * 2 - 1
25
+ out = {'fg_kp': fg_kp.view(bs, self.num_tps*5, -1)}
26
+
27
+ return out
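A quick shape check, not part of this commit: KPDetector maps an image to num_tps * 5 foreground keypoints in normalized [-1, 1] coordinates, which the dense motion module later groups into num_tps TPS transformations of five control points each. The num_tps=10 value and the 256x256 input are assumptions, and the import assumes the repository root is on the path.

    # Hypothetical shape check for KPDetector (assumed num_tps and input size).
    import torch
    from modules.keypoint_detector import KPDetector

    kp_detector = KPDetector(num_tps=10).eval()
    with torch.no_grad():
        kp = kp_detector(torch.rand(2, 3, 256, 256))
    print(kp['fg_kp'].shape)                                 # torch.Size([2, 50, 2])
    print(kp['fg_kp'].min() >= -1, kp['fg_kp'].max() <= 1)   # tensor(True) tensor(True)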
modules/model.py ADDED
@@ -0,0 +1,182 @@
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from modules.util import AntiAliasInterpolation2d, TPS
5
+ from torchvision import models
6
+ import numpy as np
7
+
8
+
9
+ class Vgg19(torch.nn.Module):
10
+ """
11
+ Vgg19 network for perceptual loss. See Sec 3.3.
12
+ """
13
+ def __init__(self, requires_grad=False):
14
+ super(Vgg19, self).__init__()
15
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
16
+ self.slice1 = torch.nn.Sequential()
17
+ self.slice2 = torch.nn.Sequential()
18
+ self.slice3 = torch.nn.Sequential()
19
+ self.slice4 = torch.nn.Sequential()
20
+ self.slice5 = torch.nn.Sequential()
21
+ for x in range(2):
22
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
23
+ for x in range(2, 7):
24
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
25
+ for x in range(7, 12):
26
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
27
+ for x in range(12, 21):
28
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
29
+ for x in range(21, 30):
30
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
31
+
32
+ self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
33
+ requires_grad=False)
34
+ self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
35
+ requires_grad=False)
36
+
37
+ if not requires_grad:
38
+ for param in self.parameters():
39
+ param.requires_grad = False
40
+
41
+ def forward(self, X):
42
+ X = (X - self.mean) / self.std
43
+ h_relu1 = self.slice1(X)
44
+ h_relu2 = self.slice2(h_relu1)
45
+ h_relu3 = self.slice3(h_relu2)
46
+ h_relu4 = self.slice4(h_relu3)
47
+ h_relu5 = self.slice5(h_relu4)
48
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
49
+ return out
50
+
51
+
52
+ class ImagePyramide(torch.nn.Module):
53
+ """
54
+ Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3.
55
+ """
56
+ def __init__(self, scales, num_channels):
57
+ super(ImagePyramide, self).__init__()
58
+ downs = {}
59
+ for scale in scales:
60
+ downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
61
+ self.downs = nn.ModuleDict(downs)
62
+
63
+ def forward(self, x):
64
+ out_dict = {}
65
+ for scale, down_module in self.downs.items():
66
+ out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
67
+ return out_dict
68
+
69
+
70
+ def detach_kp(kp):
71
+ return {key: value.detach() for key, value in kp.items()}
72
+
73
+
74
+ class GeneratorFullModel(torch.nn.Module):
75
+ """
76
+ Merge all generator-related updates into a single model for better multi-GPU usage.
77
+ """
78
+
79
+ def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, **kwargs):
80
+ super(GeneratorFullModel, self).__init__()
81
+ self.kp_extractor = kp_extractor
82
+ self.inpainting_network = inpainting_network
83
+ self.dense_motion_network = dense_motion_network
84
+
85
+ self.bg_predictor = None
86
+ if bg_predictor:
87
+ self.bg_predictor = bg_predictor
88
+ self.bg_start = train_params['bg_start']
89
+
90
+ self.train_params = train_params
91
+ self.scales = train_params['scales']
92
+
93
+ self.pyramid = ImagePyramide(self.scales, inpainting_network.num_channels)
94
+ if torch.cuda.is_available():
95
+ self.pyramid = self.pyramid.cuda()
96
+
97
+ self.loss_weights = train_params['loss_weights']
98
+ self.dropout_epoch = train_params['dropout_epoch']
99
+ self.dropout_maxp = train_params['dropout_maxp']
100
+ self.dropout_inc_epoch = train_params['dropout_inc_epoch']
101
+ self.dropout_startp =train_params['dropout_startp']
102
+
103
+ if sum(self.loss_weights['perceptual']) != 0:
104
+ self.vgg = Vgg19()
105
+ if torch.cuda.is_available():
106
+ self.vgg = self.vgg.cuda()
107
+
108
+
109
+ def forward(self, x, epoch):
110
+ kp_source = self.kp_extractor(x['source'])
111
+ kp_driving = self.kp_extractor(x['driving'])
112
+ bg_param = None
113
+ if self.bg_predictor:
114
+ if(epoch>=self.bg_start):
115
+ bg_param = self.bg_predictor(x['source'], x['driving'])
116
+
117
+ if(epoch>=self.dropout_epoch):
118
+ dropout_flag = False
119
+ dropout_p = 0
120
+ else:
121
+ # dropout_p will linearly increase from dropout_startp to dropout_maxp
122
+ dropout_flag = True
123
+ dropout_p = min(epoch/self.dropout_inc_epoch * self.dropout_maxp + self.dropout_startp, self.dropout_maxp)
124
+
125
+ dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving,
126
+ kp_source=kp_source, bg_param = bg_param,
127
+ dropout_flag = dropout_flag, dropout_p = dropout_p)
128
+ generated = self.inpainting_network(x['source'], dense_motion)
129
+ generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
130
+
131
+ loss_values = {}
132
+
133
+ pyramide_real = self.pyramid(x['driving'])
134
+ pyramide_generated = self.pyramid(generated['prediction'])
135
+
136
+ # reconstruction loss
137
+ if sum(self.loss_weights['perceptual']) != 0:
138
+ value_total = 0
139
+ for scale in self.scales:
140
+ x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
141
+ y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
142
+
143
+ for i, weight in enumerate(self.loss_weights['perceptual']):
144
+ value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
145
+ value_total += self.loss_weights['perceptual'][i] * value
146
+ loss_values['perceptual'] = value_total
147
+
148
+ # equivariance loss
149
+ if self.loss_weights['equivariance_value'] != 0:
150
+ transform_random = TPS(mode = 'random', bs = x['driving'].shape[0], **self.train_params['transform_params'])
151
+ transform_grid = transform_random.transform_frame(x['driving'])
152
+ transformed_frame = F.grid_sample(x['driving'], transform_grid, padding_mode="reflection",align_corners=True)
153
+ transformed_kp = self.kp_extractor(transformed_frame)
154
+
155
+ generated['transformed_frame'] = transformed_frame
156
+ generated['transformed_kp'] = transformed_kp
157
+
158
+ warped = transform_random.warp_coordinates(transformed_kp['fg_kp'])
159
+ kp_d = kp_driving['fg_kp']
160
+ value = torch.abs(kp_d - warped).mean()
161
+ loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
162
+
163
+ # warp loss
164
+ if self.loss_weights['warp_loss'] != 0:
165
+ occlusion_map = generated['occlusion_map']
166
+ encode_map = self.inpainting_network.get_encode(x['driving'], occlusion_map)
167
+ decode_map = generated['warped_encoder_maps']
168
+ value = 0
169
+ for i in range(len(encode_map)):
170
+ value += torch.abs(encode_map[i]-decode_map[-i-1]).mean()
171
+
172
+ loss_values['warp_loss'] = self.loss_weights['warp_loss'] * value
173
+
174
+ # bg loss
175
+ if self.bg_predictor and epoch >= self.bg_start and self.loss_weights['bg'] != 0:
176
+ bg_param_reverse = self.bg_predictor(x['driving'], x['source'])
177
+ value = torch.matmul(bg_param, bg_param_reverse)
178
+ eye = torch.eye(3).view(1, 1, 3, 3).type(value.type())
179
+ value = torch.abs(eye - value).mean()
180
+ loss_values['bg'] = self.loss_weights['bg'] * value
181
+
182
+ return loss_values, generated
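A standalone sketch, not part of this commit, of the background-consistency term computed above: the source-to-driving and driving-to-source affine matrices should be mutual inverses, so their product is penalized against the 3x3 identity. The matrices below are synthetic assumptions.

    # Hypothetical illustration of the 'bg' loss: the product of forward and
    # reverse background transforms should be close to the identity.
    import torch

    bs = 4
    bg_param = torch.eye(3).repeat(bs, 1, 1) + 0.01 * torch.randn(bs, 3, 3)   # assumed forward transforms
    bg_param_reverse = torch.inverse(bg_param)                                # ideal reverse transforms
    prod = torch.matmul(bg_param, bg_param_reverse)
    bg_loss = torch.abs(torch.eye(3).view(1, 3, 3) - prod).mean()
    print(float(bg_loss))   # ~0 when the two transforms are consistent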
modules/util.py ADDED
@@ -0,0 +1,349 @@
1
+ from torch import nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+
5
+
6
+ class TPS:
7
+ '''
8
+ TPS transformation, mode 'kp' for Eq(2) in the paper, mode 'random' for equivariance loss.
9
+ '''
10
+ def __init__(self, mode, bs, **kwargs):
11
+ self.bs = bs
12
+ self.mode = mode
13
+ if mode == 'random':
14
+ noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
15
+ self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
16
+ self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
17
+ self.control_points = self.control_points.unsqueeze(0)
18
+ self.control_params = torch.normal(mean=0,
19
+ std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
20
+ elif mode == 'kp':
21
+ kp_1 = kwargs["kp_1"]
22
+ kp_2 = kwargs["kp_2"]
23
+ device = kp_1.device
24
+ kp_type = kp_1.type()
25
+ self.gs = kp_1.shape[1]
26
+ n = kp_1.shape[2]
27
+ K = torch.norm(kp_1[:,:,:, None]-kp_1[:,:, None, :], dim=4, p=2)
28
+ K = K**2
29
+ K = K * torch.log(K+1e-9)
30
+
31
+ one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2], 1).to(device).type(kp_type)
32
+ kp_1p = torch.cat([kp_1,one1], 3)
33
+
34
+ zero = torch.zeros(self.bs, kp_1.shape[1], 3, 3).to(device).type(kp_type)
35
+ P = torch.cat([kp_1p, zero],2)
36
+ L = torch.cat([K,kp_1p.permute(0,1,3,2)],2)
37
+ L = torch.cat([L,P],3)
38
+
39
+ zero = torch.zeros(self.bs, kp_1.shape[1], 3, 2).to(device).type(kp_type)
40
+ Y = torch.cat([kp_2, zero], 2)
41
+ one = torch.eye(L.shape[2]).expand(L.shape).to(device).type(kp_type)*0.01
42
+ L = L + one
43
+
44
+ param = torch.matmul(torch.inverse(L),Y)
45
+ self.theta = param[:,:,n:,:].permute(0,1,3,2)
46
+
47
+ self.control_points = kp_1
48
+ self.control_params = param[:,:,:n,:]
49
+ else:
50
+ raise Exception("Error TPS mode")
51
+
52
+ def transform_frame(self, frame):
53
+ grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
54
+ grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
55
+ shape = [self.bs, frame.shape[2], frame.shape[3], 2]
56
+ if self.mode == 'kp':
57
+ shape.insert(1, self.gs)
58
+ grid = self.warp_coordinates(grid).view(*shape)
59
+ return grid
60
+
61
+ def warp_coordinates(self, coordinates):
62
+ theta = self.theta.type(coordinates.type()).to(coordinates.device)
63
+ control_points = self.control_points.type(coordinates.type()).to(coordinates.device)
64
+ control_params = self.control_params.type(coordinates.type()).to(coordinates.device)
65
+
66
+ if self.mode == 'kp':
67
+ transformed = torch.matmul(theta[:, :, :, :2], coordinates.permute(0, 2, 1)) + theta[:, :, :, 2:]
68
+
69
+ distances = coordinates.view(coordinates.shape[0], 1, 1, -1, 2) - control_points.view(self.bs, control_points.shape[1], -1, 1, 2)
70
+
71
+ distances = distances ** 2
72
+ result = distances.sum(-1)
73
+ result = result * torch.log(result + 1e-9)
74
+ result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
75
+ transformed = transformed.permute(0, 1, 3, 2) + result
76
+
77
+ elif self.mode == 'random':
78
+ theta = theta.unsqueeze(1)
79
+ transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
80
+ transformed = transformed.squeeze(-1)
81
+ ances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
82
+ distances = ances ** 2
83
+
84
+ result = distances.sum(-1)
85
+ result = result * torch.log(result + 1e-9)
86
+ result = result * control_params
87
+ result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
88
+ transformed = transformed + result
89
+ else:
90
+ raise Exception("Error TPS mode")
91
+
92
+ return transformed
93
+
94
+
95
+ def kp2gaussian(kp, spatial_size, kp_variance):
96
+ """
97
+ Transform a keypoint into a Gaussian-like representation.
98
+ """
99
+
100
+ coordinate_grid = make_coordinate_grid(spatial_size, kp.type()).to(kp.device)
101
+ number_of_leading_dimensions = len(kp.shape) - 1
102
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
103
+ coordinate_grid = coordinate_grid.view(*shape)
104
+ repeats = kp.shape[:number_of_leading_dimensions] + (1, 1, 1)
105
+ coordinate_grid = coordinate_grid.repeat(*repeats)
106
+
107
+ # Preprocess kp shape
108
+ shape = kp.shape[:number_of_leading_dimensions] + (1, 1, 2)
109
+ kp = kp.view(*shape)
110
+
111
+ mean_sub = (coordinate_grid - kp)
112
+
113
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
114
+
115
+ return out
116
+
117
+
118
+ def make_coordinate_grid(spatial_size, type):
119
+ """
120
+ Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
121
+ """
122
+ h, w = spatial_size
123
+ x = torch.arange(w).type(type)
124
+ y = torch.arange(h).type(type)
125
+
126
+ x = (2 * (x / (w - 1)) - 1)
127
+ y = (2 * (y / (h - 1)) - 1)
128
+
129
+ yy = y.view(-1, 1).repeat(1, w)
130
+ xx = x.view(1, -1).repeat(h, 1)
131
+
132
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
133
+
134
+ return meshed
135
+
136
+
137
+ class ResBlock2d(nn.Module):
138
+ """
139
+ Res block, preserve spatial resolution.
140
+ """
141
+
142
+ def __init__(self, in_features, kernel_size, padding):
143
+ super(ResBlock2d, self).__init__()
144
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
145
+ padding=padding)
146
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
147
+ padding=padding)
148
+ self.norm1 = nn.InstanceNorm2d(in_features, affine=True)
149
+ self.norm2 = nn.InstanceNorm2d(in_features, affine=True)
150
+
151
+ def forward(self, x):
152
+ out = self.norm1(x)
153
+ out = F.relu(out)
154
+ out = self.conv1(out)
155
+ out = self.norm2(out)
156
+ out = F.relu(out)
157
+ out = self.conv2(out)
158
+ out += x
159
+ return out
160
+
161
+
162
+ class UpBlock2d(nn.Module):
163
+ """
164
+ Upsampling block for use in decoder.
165
+ """
166
+
167
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
168
+ super(UpBlock2d, self).__init__()
169
+
170
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
171
+ padding=padding, groups=groups)
172
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
173
+
174
+ def forward(self, x):
175
+ out = F.interpolate(x, scale_factor=2)
176
+ out = self.conv(out)
177
+ out = self.norm(out)
178
+ out = F.relu(out)
179
+ return out
180
+
181
+
182
+ class DownBlock2d(nn.Module):
183
+ """
184
+ Downsampling block for use in encoder.
185
+ """
186
+
187
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
188
+ super(DownBlock2d, self).__init__()
189
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
190
+ padding=padding, groups=groups)
191
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
192
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
193
+
194
+ def forward(self, x):
195
+ out = self.conv(x)
196
+ out = self.norm(out)
197
+ out = F.relu(out)
198
+ out = self.pool(out)
199
+ return out
200
+
201
+
202
+ class SameBlock2d(nn.Module):
203
+ """
204
+ Simple block, preserve spatial resolution.
205
+ """
206
+
207
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
208
+ super(SameBlock2d, self).__init__()
209
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
210
+ kernel_size=kernel_size, padding=padding, groups=groups)
211
+ self.norm = nn.InstanceNorm2d(out_features, affine=True)
212
+
213
+ def forward(self, x):
214
+ out = self.conv(x)
215
+ out = self.norm(out)
216
+ out = F.relu(out)
217
+ return out
218
+
219
+
220
+ class Encoder(nn.Module):
221
+ """
222
+ Hourglass Encoder
223
+ """
224
+
225
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
226
+ super(Encoder, self).__init__()
227
+
228
+ down_blocks = []
229
+ for i in range(num_blocks):
230
+ down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
231
+ min(max_features, block_expansion * (2 ** (i + 1))),
232
+ kernel_size=3, padding=1))
233
+ self.down_blocks = nn.ModuleList(down_blocks)
234
+
235
+ def forward(self, x):
236
+ outs = [x]
237
+ #print('encoder:' ,outs[-1].shape)
238
+ for down_block in self.down_blocks:
239
+ outs.append(down_block(outs[-1]))
240
+ #print('encoder:' ,outs[-1].shape)
241
+ return outs
242
+
243
+
244
+ class Decoder(nn.Module):
245
+ """
246
+ Hourglass Decoder
247
+ """
248
+
249
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
250
+ super(Decoder, self).__init__()
251
+
252
+ up_blocks = []
253
+ self.out_channels = []
254
+ for i in range(num_blocks)[::-1]:
255
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
256
+ self.out_channels.append(in_filters)
257
+ out_filters = min(max_features, block_expansion * (2 ** i))
258
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
259
+
260
+ self.up_blocks = nn.ModuleList(up_blocks)
261
+ self.out_channels.append(block_expansion + in_features)
262
+ # self.out_filters = block_expansion + in_features
263
+
264
+ def forward(self, x, mode = 0):
265
+ out = x.pop()
266
+ outs = []
267
+ for up_block in self.up_blocks:
268
+ out = up_block(out)
269
+ skip = x.pop()
270
+ out = torch.cat([out, skip], dim=1)
271
+ outs.append(out)
272
+ if(mode == 0):
273
+ return out
274
+ else:
275
+ return outs
276
+
277
+
278
+ class Hourglass(nn.Module):
279
+ """
280
+ Hourglass architecture.
281
+ """
282
+
283
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
284
+ super(Hourglass, self).__init__()
285
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
286
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
287
+ self.out_channels = self.decoder.out_channels
288
+ # self.out_filters = self.decoder.out_filters
289
+
290
+ def forward(self, x, mode = 0):
291
+ return self.decoder(self.encoder(x), mode)
292
+
293
+
294
+ class AntiAliasInterpolation2d(nn.Module):
295
+ """
296
+ Band-limited downsampling, for better preservation of the input signal.
297
+ """
298
+ def __init__(self, channels, scale):
299
+ super(AntiAliasInterpolation2d, self).__init__()
300
+ sigma = (1 / scale - 1) / 2
301
+ kernel_size = 2 * round(sigma * 4) + 1
302
+ self.ka = kernel_size // 2
303
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
304
+
305
+ kernel_size = [kernel_size, kernel_size]
306
+ sigma = [sigma, sigma]
307
+ # The gaussian kernel is the product of the
308
+ # gaussian function of each dimension.
309
+ kernel = 1
310
+ meshgrids = torch.meshgrid(
311
+ [
312
+ torch.arange(size, dtype=torch.float32)
313
+ for size in kernel_size
314
+ ]
315
+ )
316
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
317
+ mean = (size - 1) / 2
318
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
319
+
320
+ # Make sure sum of values in gaussian kernel equals 1.
321
+ kernel = kernel / torch.sum(kernel)
322
+ # Reshape to depthwise convolutional weight
323
+ kernel = kernel.view(1, 1, *kernel.size())
324
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
325
+
326
+ self.register_buffer('weight', kernel)
327
+ self.groups = channels
328
+ self.scale = scale
329
+
330
+ def forward(self, input):
331
+ if self.scale == 1.0:
332
+ return input
333
+
334
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
335
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
336
+ out = F.interpolate(out, scale_factor=(self.scale, self.scale))
337
+
338
+ return out
339
+
340
+
341
+ def to_homogeneous(coordinates):
342
+ ones_shape = list(coordinates.shape)
343
+ ones_shape[-1] = 1
344
+ ones = torch.ones(ones_shape).type(coordinates.type())
345
+
346
+ return torch.cat([coordinates, ones], dim=-1)
347
+
348
+ def from_homogeneous(coordinates):
349
+ return coordinates[..., :2] / coordinates[..., 2:3]
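A minimal usage sketch, not part of this commit, of TPS in 'random' mode, the warp used for the equivariance loss in modules/model.py: a random thin-plate deformation is sampled, turned into a sampling grid, and applied to a frame with grid_sample. The sigma/points values are illustrative assumptions rather than the training configuration, and the import assumes the repository root is on the path.

    # Hypothetical equivariance-style warp with TPS(mode='random') (assumed parameters).
    import torch
    import torch.nn.functional as F
    from modules.util import TPS

    frame = torch.rand(2, 3, 64, 64)
    tps = TPS(mode='random', bs=2, sigma_affine=0.05, sigma_tps=0.005, points_tps=5)
    grid = tps.transform_frame(frame)       # (2, 64, 64, 2) sampling grid
    warped = F.grid_sample(frame, grid, padding_mode='reflection', align_corners=True)
    print(warped.shape)                     # torch.Size([2, 3, 64, 64])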
reconstruction.py ADDED
@@ -0,0 +1,69 @@
1
+ import os
2
+ from tqdm import tqdm
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ from logger import Logger, Visualizer
6
+ import numpy as np
7
+ import imageio
8
+
9
+
10
+ def reconstruction(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
11
+ png_dir = os.path.join(log_dir, 'reconstruction/png')
12
+ log_dir = os.path.join(log_dir, 'reconstruction')
13
+
14
+ if checkpoint is not None:
15
+ Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
16
+ bg_predictor=bg_predictor, dense_motion_network=dense_motion_network)
17
+ else:
18
+ raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
19
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
20
+
21
+ if not os.path.exists(log_dir):
22
+ os.makedirs(log_dir)
23
+
24
+ if not os.path.exists(png_dir):
25
+ os.makedirs(png_dir)
26
+
27
+ loss_list = []
28
+
29
+ inpainting_network.eval()
30
+ kp_detector.eval()
31
+ dense_motion_network.eval()
32
+ if bg_predictor:
33
+ bg_predictor.eval()
34
+
35
+ for it, x in tqdm(enumerate(dataloader)):
36
+ with torch.no_grad():
37
+ predictions = []
38
+ visualizations = []
39
+ if torch.cuda.is_available():
40
+ x['video'] = x['video'].cuda()
41
+ kp_source = kp_detector(x['video'][:, :, 0])
42
+ for frame_idx in range(x['video'].shape[2]):
43
+ source = x['video'][:, :, 0]
44
+ driving = x['video'][:, :, frame_idx]
45
+ kp_driving = kp_detector(driving)
46
+ bg_params = None
47
+ if bg_predictor:
48
+ bg_params = bg_predictor(source, driving)
49
+
50
+ dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
51
+ kp_source=kp_source, bg_param = bg_params,
52
+ dropout_flag = False)
53
+ out = inpainting_network(source, dense_motion)
54
+ out['kp_source'] = kp_source
55
+ out['kp_driving'] = kp_driving
56
+
57
+ predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
58
+
59
+ visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
60
+ driving=driving, out=out)
61
+ visualizations.append(visualization)
62
+ loss = torch.abs(out['prediction'] - driving).mean().cpu().numpy()
63
+
64
+ loss_list.append(loss)
65
+ # print(np.mean(loss_list))
66
+ predictions = np.concatenate(predictions, axis=1)
67
+ imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
68
+
69
+ print("Reconstruction loss: %s" % np.mean(loss_list))
requirements.txt ADDED
@@ -0,0 +1,26 @@
1
+ cffi==1.14.6
2
+ cycler==0.10.0
3
+ decorator==5.1.0
4
+ face-alignment==1.3.5
5
+ imageio==2.9.0
6
+ imageio-ffmpeg==0.4.5
7
+ kiwisolver==1.3.2
8
+ matplotlib==3.4.3
9
+ networkx==2.6.3
10
+ numpy==1.20.3
11
+ pandas==1.3.3
12
+ Pillow==8.3.2
13
+ pycparser==2.20
14
+ pyparsing==2.4.7
15
+ python-dateutil==2.8.2
16
+ pytz==2021.1
17
+ PyWavelets==1.1.1
18
+ PyYAML==5.4.1
19
+ scikit-image==0.18.3
20
+ scikit-learn==1.0
21
+ scipy==1.7.1
22
+ six==1.16.0
23
+ torch==1.11.0
24
+ torchvision==0.12.0
25
+ tqdm==4.62.3
26
+ gradio
run.py ADDED
@@ -0,0 +1,89 @@
1
+ import matplotlib
2
+ matplotlib.use('Agg')
3
+
4
+ import os, sys
5
+ import yaml
6
+ from argparse import ArgumentParser
7
+ from time import gmtime, strftime
8
+ from shutil import copy
9
+ from frames_dataset import FramesDataset
10
+
11
+ from modules.inpainting_network import InpaintingNetwork
12
+ from modules.keypoint_detector import KPDetector
13
+ from modules.bg_motion_predictor import BGMotionPredictor
14
+ from modules.dense_motion import DenseMotionNetwork
15
+ from modules.avd_network import AVDNetwork
16
+ import torch
17
+ from train import train
18
+ from train_avd import train_avd
19
+ from reconstruction import reconstruction
20
+ import os
21
+
22
+
23
+ if __name__ == "__main__":
24
+
25
+ if sys.version_info[0] < 3:
26
+ raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
27
+
28
+ parser = ArgumentParser()
29
+ parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
30
+ parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"])
31
+ parser.add_argument("--log_dir", default='log', help="path to log into")
32
+ parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
33
+ parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))),
34
+ help="Names of the devices comma separated.")
35
+
36
+ opt = parser.parse_args()
37
+ with open(opt.config) as f:
38
+ config = yaml.load(f, Loader=yaml.FullLoader)
39
+
40
+ if opt.checkpoint is not None:
41
+ log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
42
+ else:
43
+ log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
44
+ log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
45
+
46
+ inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
47
+ **config['model_params']['common_params'])
48
+
49
+ if torch.cuda.is_available():
50
+ cuda_device = torch.device('cuda:'+str(opt.device_ids[0]))
51
+ inpainting.to(cuda_device)
52
+
53
+ kp_detector = KPDetector(**config['model_params']['common_params'])
54
+ dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
55
+ **config['model_params']['dense_motion_params'])
56
+
57
+ if torch.cuda.is_available():
58
+ kp_detector.to(opt.device_ids[0])
59
+ dense_motion_network.to(opt.device_ids[0])
60
+
61
+ bg_predictor = None
62
+ if (config['model_params']['common_params']['bg']):
63
+ bg_predictor = BGMotionPredictor()
64
+ if torch.cuda.is_available():
65
+ bg_predictor.to(opt.device_ids[0])
66
+
67
+ avd_network = None
68
+ if opt.mode == "train_avd":
69
+ avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'],
70
+ **config['model_params']['avd_network_params'])
71
+ if torch.cuda.is_available():
72
+ avd_network.to(opt.device_ids[0])
73
+
74
+ dataset = FramesDataset(is_train=(opt.mode.startswith('train')), **config['dataset_params'])
75
+
76
+ if not os.path.exists(log_dir):
77
+ os.makedirs(log_dir)
78
+ if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
79
+ copy(opt.config, log_dir)
80
+
81
+ if opt.mode == 'train':
82
+ print("Training...")
83
+ train(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
84
+ elif opt.mode == 'train_avd':
85
+ print("Training Animation via Disentaglement...")
86
+ train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network, opt.checkpoint, log_dir, dataset)
87
+ elif opt.mode == 'reconstruction':
88
+ print("Reconstruction...")
89
+ reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
train.py ADDED
@@ -0,0 +1,94 @@
1
+ from tqdm import trange
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from logger import Logger
5
+ from modules.model import GeneratorFullModel
6
+ from torch.optim.lr_scheduler import MultiStepLR
7
+ from torch.nn.utils import clip_grad_norm_
8
+ from frames_dataset import DatasetRepeater
9
+ import math
10
+
11
+ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset):
12
+ train_params = config['train_params']
13
+ optimizer = torch.optim.Adam(
14
+ [{'params': list(inpainting_network.parameters()) +
15
+ list(dense_motion_network.parameters()) +
16
+ list(kp_detector.parameters()), 'initial_lr': train_params['lr_generator']}],lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)
17
+
18
+ optimizer_bg_predictor = None
19
+ if bg_predictor:
20
+ optimizer_bg_predictor = torch.optim.Adam(
21
+ [{'params':bg_predictor.parameters(),'initial_lr': train_params['lr_generator']}],
22
+ lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4)
23
+
24
+ if checkpoint is not None:
25
+ start_epoch = Logger.load_cpk(
26
+ checkpoint, inpainting_network = inpainting_network, dense_motion_network = dense_motion_network,
27
+ kp_detector = kp_detector, bg_predictor = bg_predictor,
28
+ optimizer = optimizer, optimizer_bg_predictor = optimizer_bg_predictor)
29
+ print('load success:', start_epoch)
30
+ start_epoch += 1
31
+ else:
32
+ start_epoch = 0
33
+
34
+ scheduler_optimizer = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1,
35
+ last_epoch=start_epoch - 1)
36
+ if bg_predictor:
37
+ scheduler_bg_predictor = MultiStepLR(optimizer_bg_predictor, train_params['epoch_milestones'],
38
+ gamma=0.1, last_epoch=start_epoch - 1)
39
+
40
+ if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
41
+ dataset = DatasetRepeater(dataset, train_params['num_repeats'])
42
+ dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
43
+ num_workers=train_params['dataloader_workers'], drop_last=True)
44
+
45
+ generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, train_params)
46
+
47
+ if torch.cuda.is_available():
48
+ generator_full = torch.nn.DataParallel(generator_full).cuda()
49
+
50
+ bg_start = train_params['bg_start']
51
+
52
+ with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
53
+ checkpoint_freq=train_params['checkpoint_freq']) as logger:
54
+ for epoch in trange(start_epoch, train_params['num_epochs']):
55
+ for x in dataloader:
56
+ if(torch.cuda.is_available()):
57
+ x['driving'] = x['driving'].cuda()
58
+ x['source'] = x['source'].cuda()
59
+
60
+ losses_generator, generated = generator_full(x, epoch)
61
+ loss_values = [val.mean() for val in losses_generator.values()]
62
+ loss = sum(loss_values)
63
+ loss.backward()
64
+
65
+ clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type = math.inf)
66
+ clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type = math.inf)
67
+ if bg_predictor and epoch>=bg_start:
68
+ clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type = math.inf)
69
+
70
+ optimizer.step()
71
+ optimizer.zero_grad()
72
+ if bg_predictor and epoch>=bg_start:
73
+ optimizer_bg_predictor.step()
74
+ optimizer_bg_predictor.zero_grad()
75
+
76
+ losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
77
+ logger.log_iter(losses=losses)
78
+
79
+ scheduler_optimizer.step()
80
+ if bg_predictor:
81
+ scheduler_bg_predictor.step()
82
+
83
+ model_save = {
84
+ 'inpainting_network': inpainting_network,
85
+ 'dense_motion_network': dense_motion_network,
86
+ 'kp_detector': kp_detector,
87
+ 'optimizer': optimizer,
88
+ }
89
+ if bg_predictor and epoch>=bg_start:
90
+ model_save['bg_predictor'] = bg_predictor
91
+ model_save['optimizer_bg_predictor'] = optimizer_bg_predictor
92
+
93
+ logger.log_epoch(epoch, model_save, inp=x, out=generated)
94
+
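A standalone sketch, not part of this commit, of the TPS-dropout schedule that GeneratorFullModel.forward() applies during these epochs: dropout_p ramps linearly from dropout_startp towards dropout_maxp over dropout_inc_epoch epochs, and dropout is switched off once epoch reaches dropout_epoch. The default values below are assumptions, not the shipped config.

    # Hypothetical reimplementation of the dropout schedule (assumed defaults).
    def dropout_schedule(epoch, dropout_epoch=35, dropout_maxp=0.7,
                         dropout_startp=0.0, dropout_inc_epoch=10):
        if epoch >= dropout_epoch:
            return False, 0.0
        p = min(epoch / dropout_inc_epoch * dropout_maxp + dropout_startp, dropout_maxp)
        return True, p

    for epoch in (0, 5, 20, 40):
        print(epoch, dropout_schedule(epoch))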
train_avd.py ADDED
@@ -0,0 +1,91 @@
1
+ from tqdm import trange
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from logger import Logger
5
+ from torch.optim.lr_scheduler import MultiStepLR
6
+ from frames_dataset import DatasetRepeater
7
+
8
+
9
+ def random_scale(kp_params, scale):
10
+ theta = torch.rand(kp_params['fg_kp'].shape[0], 2) * (2 * scale) + (1 - scale)
11
+ theta = torch.diag_embed(theta).unsqueeze(1).type(kp_params['fg_kp'].type())
12
+ new_kp_params = {'fg_kp': torch.matmul(theta, kp_params['fg_kp'].unsqueeze(-1)).squeeze(-1)}
13
+ return new_kp_params
14
+
15
+
16
+ def train_avd(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network,
17
+ avd_network, checkpoint, log_dir, dataset):
18
+ train_params = config['train_avd_params']
19
+
20
+ optimizer = torch.optim.Adam(avd_network.parameters(), lr=train_params['lr'], betas=(0.5, 0.999))
21
+
22
+ if checkpoint is not None:
23
+ Logger.load_cpk(checkpoint, inpainting_network=inpainting_network, kp_detector=kp_detector,
24
+ bg_predictor=bg_predictor, avd_network=avd_network,
25
+ dense_motion_network= dense_motion_network,optimizer_avd=optimizer)
26
+ start_epoch = 0
27
+ else:
28
+ raise AttributeError("Checkpoint should be specified for mode='train_avd'.")
29
+
30
+ scheduler = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1)
31
+
32
+ if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
33
+ dataset = DatasetRepeater(dataset, train_params['num_repeats'])
34
+
35
+ dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
36
+ num_workers=train_params['dataloader_workers'], drop_last=True)
37
+
38
+ with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
39
+ checkpoint_freq=train_params['checkpoint_freq']) as logger:
40
+ for epoch in trange(start_epoch, train_params['num_epochs']):
41
+ avd_network.train()
42
+ for x in dataloader:
43
+ with torch.no_grad():
44
+ kp_source = kp_detector(x['source'].cuda())
45
+ kp_driving_gt = kp_detector(x['driving'].cuda())
46
+ kp_driving_random = random_scale(kp_driving_gt, scale=train_params['random_scale'])
47
+ rec = avd_network(kp_source, kp_driving_random)
48
+
49
+ reconstruction_kp = train_params['lambda_shift'] * \
50
+ torch.abs(kp_driving_gt['fg_kp'] - rec['fg_kp']).mean()
51
+
52
+ loss_dict = {'rec_kp': reconstruction_kp}
53
+ loss = reconstruction_kp
54
+
55
+ loss.backward()
56
+ optimizer.step()
57
+ optimizer.zero_grad()
58
+
59
+ losses = {key: value.mean().detach().data.cpu().numpy() for key, value in loss_dict.items()}
60
+ logger.log_iter(losses=losses)
61
+
62
+ # Visualization
63
+ avd_network.eval()
64
+ with torch.no_grad():
65
+ source = x['source'][:6].cuda()
66
+ driving = torch.cat([x['driving'][[0, 1]].cuda(), source[[2, 3, 2, 1]]], dim=0)
67
+ kp_source = kp_detector(source)
68
+ kp_driving = kp_detector(driving)
69
+
70
+ out = avd_network(kp_source, kp_driving)
71
+ kp_driving = out
72
+ dense_motion = dense_motion_network(source_image=source, kp_driving=kp_driving,
73
+ kp_source=kp_source)
74
+ generated = inpainting_network(source, dense_motion)
75
+
76
+ generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
77
+
78
+ scheduler.step(epoch)
79
+ model_save = {
80
+ 'inpainting_network': inpainting_network,
81
+ 'dense_motion_network': dense_motion_network,
82
+ 'kp_detector': kp_detector,
83
+ 'avd_network': avd_network,
84
+ 'optimizer_avd': optimizer
85
+ }
86
+ if bg_predictor :
87
+ model_save['bg_predictor'] = bg_predictor
88
+
89
+ logger.log_epoch(epoch, model_save,
90
+ inp={'source': source, 'driving': driving},
91
+ out=generated)
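A minimal sketch, not part of this commit, of what random_scale() above does: every sample in the batch gets one random scale factor per axis, drawn from [1 - scale, 1 + scale], applied to all of its keypoints; the AVD network is then trained to map these rescaled driving keypoints back toward the ground-truth layout. The shapes and scale value below are assumptions.

    # Hypothetical illustration of random_scale() on fake keypoints (assumed shapes).
    import torch

    fg_kp = torch.rand(2, 50, 2) * 2 - 1                                # assumed detector output in [-1, 1]
    scale = 0.25
    theta = torch.rand(fg_kp.shape[0], 2) * (2 * scale) + (1 - scale)   # per-axis factors
    theta = torch.diag_embed(theta).unsqueeze(1).type(fg_kp.type())     # (bs, 1, 2, 2)
    scaled = torch.matmul(theta, fg_kp.unsqueeze(-1)).squeeze(-1)       # (bs, 50, 2)
    print(scaled.shape)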