MohamedIFQ committed on
Commit 45ddb95 · verified · 1 Parent(s): c748a8f

Upload 4 files

Files changed (4)
  1. cog.yaml +30 -0
  2. demo.py +307 -0
  3. predict.py +308 -0
  4. requirements.txt +10 -0
cog.yaml ADDED
@@ -0,0 +1,30 @@
+ build:
+   gpu: true
+   python_version: "3.8"
+   system_packages:
+     - "libgl1-mesa-glx"
+     - "libglib2.0-0"
+     - "libsox-fmt-mp3"
+   python_packages:
+     - "torch==1.7.1"
+     - "torchvision==0.8.2"
+     - "numpy==1.18.1"
+     - "ipython==7.21.0"
+     - "Pillow==8.3.1"
+     - "scikit-image==0.18.3"
+     - "librosa==0.7.2"
+     - "tqdm==4.62.3"
+     - "scipy==1.7.1"
+     - "dominate==2.6.0"
+     - "albumentations==0.5.2"
+     - "beautifulsoup4==4.10.0"
+     - "sox==1.4.1"
+     - "h5py==3.4.0"
+     - "numba==0.48"
+     - "moviepy==1.0.3"
+   run:
+     - apt update -y && apt-get install ffmpeg -y
+     - apt-get install sox libsox-fmt-mp3 -y
+     - pip install opencv-python==4.1.2.30
+
+ predict: "predict.py:Predictor"
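
The configuration above pins the GPU build environment and points Cog at `predict.py:Predictor` as the prediction entry point. Assuming the standard Cog CLI conventions (inputs passed as `-i name=value`, files as `@path`), a prediction would be invoked roughly as `cog predict -i driving_audio=@input.wav -i talking_head=May`; the input names must match the parameters of `Predictor.predict` defined in predict.py below.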
demo.py ADDED
@@ -0,0 +1,307 @@
+ import os
+ import subprocess
+ from os.path import join
+ from tqdm import tqdm
+ import numpy as np
+ import torch
+ from collections import OrderedDict
+ import librosa
+ from skimage.io import imread
+ import cv2
+ import scipy.io as sio
+ import argparse
+ import yaml
+ import albumentations as A
+ import albumentations.pytorch
+ from pathlib import Path
+
+ from options.test_audio2feature_options import TestOptions as FeatureOptions
+ from options.test_audio2headpose_options import TestOptions as HeadposeOptions
+ from options.test_feature2face_options import TestOptions as RenderOptions
+
+ from datasets import create_dataset
+ from models import create_model
+ from models.networks import APC_encoder
+ import util.util as util
+ from util.visualizer import Visualizer
+ from funcs import utils
+ from funcs import audio_funcs
+ import soundfile as sf
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+
+ def write_video_with_audio(audio_path, output_path, prefix='pred_'):
+     fps, fourcc = 60, cv2.VideoWriter_fourcc(*'DIVX')
+     video_tmp_path = join(save_root, 'tmp.avi')
+     out = cv2.VideoWriter(video_tmp_path, fourcc, fps, (Renderopt.loadSize, Renderopt.loadSize))
+     for j in tqdm(range(nframe), position=0, desc='writing video'):
+         img = cv2.imread(join(save_root, prefix + str(j+1) + '.jpg'))
+         out.write(img)
+     out.release()
+     cmd = 'ffmpeg -i "' + video_tmp_path + '" -i "' + audio_path + '" -codec copy -shortest "' + output_path + '"'
+     subprocess.call(cmd, shell=True)
+     os.remove(video_tmp_path)  # remove the temporary video
+
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--id', default='May', help="person name, e.g. Obama1, Obama2, May, Nadella, McStay")
+     parser.add_argument('--driving_audio', default='./data/input/00083.wav', help="path to driving audio")
+     parser.add_argument('--save_intermediates', default=0, help="whether to save intermediate results")
+     parser.add_argument('--device', type=str, default='cpu', help='use cuda for GPU or use cpu for CPU')
+
+
+     ############################### I/O Settings ##############################
+     # load config files
+     opt = parser.parse_args()
+     device = torch.device(opt.device)
+     with open(join('./config/', opt.id + '.yaml')) as f:
+         config = yaml.load(f, Loader=yaml.SafeLoader)
+     data_root = join('./data/', opt.id)
+     # create the results folder
+     audio_name = os.path.split(opt.driving_audio)[1][:-4]
+     save_root = join('./results/', opt.id, audio_name)
+     if not os.path.exists(save_root):
+         os.makedirs(save_root)
+
+
+
+     ############################ Hyper Parameters #############################
+     h, w, sr, FPS = 512, 512, 16000, 60
+     mouth_indices = np.concatenate([np.arange(4, 11), np.arange(46, 64)])
+     eye_brow_indices = [27, 65, 28, 68, 29, 67, 30, 66, 31, 72, 32, 69, 33, 70, 34, 71]
+     eye_brow_indices = np.array(eye_brow_indices, np.int32)
+
+
+
+     ############################ Pre-defined Data #############################
+     mean_pts3d = np.load(join(data_root, 'mean_pts3d.npy'))
+     fit_data = np.load(config['dataset_params']['fit_data_path'])
+     pts3d = np.load(config['dataset_params']['pts3d_path']) - mean_pts3d
+     trans = fit_data['trans'][:, :, 0].astype(np.float32)
+     mean_translation = trans.mean(axis=0)
+     candidate_eye_brow = pts3d[10:, eye_brow_indices]
+     std_mean_pts3d = np.load(config['dataset_params']['pts3d_path']).mean(axis=0)
+     # candidate images
+     img_candidates = []
+     for j in range(4):
+         output = imread(join(data_root, 'candidates', f'normalized_full_{j}.jpg'))
+         output = A.pytorch.transforms.ToTensor(normalize={'mean': (0.5, 0.5, 0.5),
+                                                           'std': (0.5, 0.5, 0.5)})(image=output)['image']
+         img_candidates.append(output)
+     img_candidates = torch.cat(img_candidates).unsqueeze(0).to(device)
+
+     # shoulders
+     shoulders = np.load(join(data_root, 'normalized_shoulder_points.npy'))
+     shoulder3D = np.load(join(data_root, 'shoulder_points3D.npy'))[1]
+     ref_trans = trans[1]
+
+     # camera matrix, we always use training set intrinsic parameters.
+     camera = utils.camera()
+     camera_intrinsic = np.load(join(data_root, 'camera_intrinsic.npy')).astype(np.float32)
+     APC_feat_database = np.load(join(data_root, 'APC_feature_base.npy'))
+
+     # load reconstruction data
+     scale = sio.loadmat(join(data_root, 'id_scale.mat'))['scale'][0, 0]
+     # Audio2Mel_torch = audio_funcs.Audio2Mel(n_fft=512, hop_length=int(16000/120), win_length=int(16000/60), sampling_rate=16000,
+     #                                         n_mel_channels=80, mel_fmin=90, mel_fmax=7600.0).to(device)
+
+
+
+     ########################### Experiment Settings ###########################
+     #### user config
+     use_LLE = config['model_params']['APC']['use_LLE']
+     Knear = config['model_params']['APC']['Knear']
+     LLE_percent = config['model_params']['APC']['LLE_percent']
+     headpose_sigma = config['model_params']['Headpose']['sigma']
+     Feat_smooth_sigma = config['model_params']['Audio2Mouth']['smooth']
+     Head_smooth_sigma = config['model_params']['Headpose']['smooth']
+     Feat_center_smooth_sigma, Head_center_smooth_sigma = 0, 0
+     AMP_method = config['model_params']['Audio2Mouth']['AMP'][0]
+     Feat_AMPs = config['model_params']['Audio2Mouth']['AMP'][1:]
+     rot_AMP, trans_AMP = config['model_params']['Headpose']['AMP']
+     shoulder_AMP = config['model_params']['Headpose']['shoulder_AMP']
+     save_feature_maps = config['model_params']['Image2Image']['save_input']
+
+     #### common settings
+     Featopt = FeatureOptions().parse()
+     Headopt = HeadposeOptions().parse()
+     Renderopt = RenderOptions().parse()
+     Featopt.load_epoch = config['model_params']['Audio2Mouth']['ckp_path']
+     Headopt.load_epoch = config['model_params']['Headpose']['ckp_path']
+     Renderopt.dataroot = config['dataset_params']['root']
+     Renderopt.load_epoch = config['model_params']['Image2Image']['ckp_path']
+     Renderopt.size = config['model_params']['Image2Image']['size']
+     ## GPU or CPU
+     if opt.device == 'cpu':
+         Featopt.gpu_ids = Headopt.gpu_ids = Renderopt.gpu_ids = []
+
+
+
+     ############################# Load Models #################################
+     print('---------- Loading Model: APC-------------')
+     APC_model = APC_encoder(config['model_params']['APC']['mel_dim'],
+                             config['model_params']['APC']['hidden_size'],
+                             config['model_params']['APC']['num_layers'],
+                             config['model_params']['APC']['residual'])
+     APC_model.load_state_dict(torch.load(config['model_params']['APC']['ckp_path'], map_location=device), strict=False)
+     if opt.device == 'cuda':
+         APC_model.cuda()
+     APC_model.eval()
+     print('---------- Loading Model: {} -------------'.format(Featopt.task))
+     Audio2Feature = create_model(Featopt)
+     Audio2Feature.setup(Featopt)
+     Audio2Feature.eval()
+     print('---------- Loading Model: {} -------------'.format(Headopt.task))
+     Audio2Headpose = create_model(Headopt)
+     Audio2Headpose.setup(Headopt)
+     Audio2Headpose.eval()
+     if Headopt.feature_decoder == 'WaveNet':
+         if opt.device == 'cuda':
+             Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.module.WaveNet.receptive_field
+         else:
+             Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.WaveNet.receptive_field
+     print('---------- Loading Model: {} -------------'.format(Renderopt.task))
+     facedataset = create_dataset(Renderopt)
+     Feature2Face = create_model(Renderopt)
+     Feature2Face.setup(Renderopt)
+     Feature2Face.eval()
+     visualizer = Visualizer(Renderopt)
+
+
+
+     ############################## Inference ##################################
+     print('Processing audio: {} ...'.format(audio_name))
+     # read audio
+     audio, _ = librosa.load(opt.driving_audio, sr=sr)
+     total_frames = np.int32(audio.shape[0] / sr * FPS)
+
+
+     #### 1. compute APC features
+     print('1. Computing APC features...')
+     mel80 = utils.compute_mel_one_sequence(audio, device=opt.device)
+     mel_nframe = mel80.shape[0]
+     with torch.no_grad():
+         length = torch.Tensor([mel_nframe])
+         mel80_torch = torch.from_numpy(mel80.astype(np.float32)).to(device).unsqueeze(0)
+         hidden_reps = APC_model.forward(mel80_torch, length)[0]  # [mel_nframe, 512]
+         hidden_reps = hidden_reps.cpu().numpy()
+     audio_feats = hidden_reps
+
+
+     #### 2. manifold projection
+     if use_LLE:
+         print('2. Manifold projection...')
+         ind = utils.KNN_with_torch(audio_feats, APC_feat_database, K=Knear)
+         weights, feat_fuse = utils.compute_LLE_projection_all_frame(audio_feats, APC_feat_database, ind, audio_feats.shape[0])
+         audio_feats = audio_feats * (1 - LLE_percent) + feat_fuse * LLE_percent
+
+
+     #### 3. Audio2Mouth
+     print('3. Audio2Mouth inference...')
+     pred_Feat = Audio2Feature.generate_sequences(audio_feats, sr, FPS, fill_zero=True, opt=Featopt)
+
+
+     #### 4. Audio2Headpose
+     print('4. Headpose inference...')
+     # set history headposes to zero
+     pre_headpose = np.zeros(Headopt.A2H_wavenet_input_channels, np.float32)
+     pred_Head = Audio2Headpose.generate_sequences(audio_feats, pre_headpose, fill_zero=True, sigma_scale=0.3, opt=Headopt)
+
+
+     #### 5. Post-Processing
+     print('5. Post-processing...')
+     nframe = min(pred_Feat.shape[0], pred_Head.shape[0])
+     pred_pts3d = np.zeros([nframe, 73, 3])
+     pred_pts3d[:, mouth_indices] = pred_Feat.reshape(-1, 25, 3)[:nframe]
+
+     ## mouth
+     pred_pts3d = utils.landmark_smooth_3d(pred_pts3d, Feat_smooth_sigma, area='only_mouth')
+     pred_pts3d = utils.mouth_pts_AMP(pred_pts3d, True, AMP_method, Feat_AMPs)
+     pred_pts3d = pred_pts3d + mean_pts3d
+     pred_pts3d = utils.solve_intersect_mouth(pred_pts3d)  # solve intersecting lips if they exist
+
+     ## headpose
+     pred_Head[:, 0:3] *= rot_AMP
+     pred_Head[:, 3:6] *= trans_AMP
+     pred_headpose = utils.headpose_smooth(pred_Head[:, :6], Head_smooth_sigma).astype(np.float32)
+     pred_headpose[:, 3:] += mean_translation
+     pred_headpose[:, 0] += 180
+
+     ## compute projected landmarks
+     pred_landmarks = np.zeros([nframe, 73, 2], dtype=np.float32)
+     final_pts3d = np.zeros([nframe, 73, 3], dtype=np.float32)
+     final_pts3d[:] = std_mean_pts3d.copy()
+     final_pts3d[:, 46:64] = pred_pts3d[:nframe, 46:64]
+     for k in tqdm(range(nframe)):
+         ind = k % candidate_eye_brow.shape[0]
+         final_pts3d[k, eye_brow_indices] = candidate_eye_brow[ind] + mean_pts3d[eye_brow_indices]
+         pred_landmarks[k], _, _ = utils.project_landmarks(camera_intrinsic, camera.relative_rotation,
+                                                           camera.relative_translation, scale,
+                                                           pred_headpose[k], final_pts3d[k])
+
+     ## Upper Body Motion
+     pred_shoulders = np.zeros([nframe, 18, 2], dtype=np.float32)
+     pred_shoulders3D = np.zeros([nframe, 18, 3], dtype=np.float32)
+     for k in range(nframe):
+         diff_trans = pred_headpose[k][3:] - ref_trans
+         pred_shoulders3D[k] = shoulder3D + diff_trans * shoulder_AMP
+         # project
+         project = camera_intrinsic.dot(pred_shoulders3D[k].T)
+         project[:2, :] /= project[2, :]  # divide by z
+         pred_shoulders[k] = project[:2, :].T
+
+
+     #### 6. Image2Image translation & Save results
+     print('6. Image2Image translation & Saving results...')
+     for ind in tqdm(range(0, nframe), desc='Image2Image translation inference'):
+         # feature_map: [input_nc, h, w]
+         current_pred_feature_map = facedataset.dataset.get_data_test_mode(pred_landmarks[ind],
+                                                                           pred_shoulders[ind],
+                                                                           facedataset.dataset.image_pad)
+         input_feature_maps = current_pred_feature_map.unsqueeze(0).to(device)
+         pred_fake = Feature2Face.inference(input_feature_maps, img_candidates)
+         # save results
+         visual_list = [('pred', util.tensor2im(pred_fake[0]))]
+         if save_feature_maps:
+             visual_list += [('input', np.uint8(current_pred_feature_map[0].cpu().numpy() * 255))]
+         visuals = OrderedDict(visual_list)
+         visualizer.save_images(save_root, visuals, str(ind + 1))
+
+
+     ## make videos
+     # generate corresponding audio, reused for all results
+     tmp_audio_path = join(save_root, 'tmp.wav')
+     tmp_audio_clip = audio[: np.int32(nframe * sr / FPS)]
+     sf.write(tmp_audio_path, tmp_audio_clip, sr)
+     # librosa.output.write_wav(tmp_audio_path, tmp_audio_clip, sr)
+
+
+     final_path = join(save_root, audio_name + '.avi')
+     write_video_with_audio(tmp_audio_path, final_path, 'pred_')
+     feature_maps_path = join(save_root, audio_name + '_feature_maps.avi')
+     write_video_with_audio(tmp_audio_path, feature_maps_path, 'input_')
+
+     if os.path.exists(tmp_audio_path):
+         os.remove(tmp_audio_path)
+     if not opt.save_intermediates:
+         _img_paths = list(map(lambda x: str(x), list(Path(save_root).glob('*.jpg'))))
+         for i in tqdm(range(len(_img_paths)), desc='deleting intermediate images'):
+             os.remove(_img_paths[i])
+
+
+     print('Finish!')
+
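
As a reading aid for step 2 of the pipeline (manifold projection), here is a toy numpy sketch of the kind of operation `compute_LLE_projection_all_frame` appears to perform; the real helper lives in `funcs/utils.py` and may constrain the reconstruction weights differently, so only the final blend with `LLE_percent` mirrors demo.py exactly:

# Toy sketch, not the repository implementation: re-express one APC feature frame as a
# least-squares combination of its K nearest database features, then blend it back with
# the original feature using LLE_percent, as in step 2 above.
import numpy as np

def lle_project(feat, database, K=10):
    dists = np.linalg.norm(database - feat, axis=1)          # distance to every database feature
    neighbors = database[np.argsort(dists)[:K]]              # [K, D] nearest neighbours
    w, *_ = np.linalg.lstsq(neighbors.T, feat, rcond=None)   # unconstrained reconstruction weights
    return w @ neighbors                                     # projection onto the local neighbourhood

rng = np.random.default_rng(0)
database = rng.normal(size=(500, 512))   # stand-in for APC_feature_base.npy
feat = rng.normal(size=512)              # one frame of APC features
LLE_percent = 0.5
feat_fuse = lle_project(feat, database)
blended = feat * (1 - LLE_percent) + feat_fuse * LLE_percent
print(blended.shape)                     # (512,)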
predict.py ADDED
@@ -0,0 +1,308 @@
+ import os
+ import subprocess
+ import shutil
+ from os.path import join
+ import yaml
+ import tempfile
+ import argparse
+ from skimage.io import imread
+ import numpy as np
+ import librosa
+ from util import util
+ from tqdm import tqdm
+ import torch
+ from collections import OrderedDict
+ import cv2
+ from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
+ from cog import BasePredictor, Input, Path
+ import scipy.io as sio
+ import albumentations as A
+ import albumentations.pytorch
+ from options.test_audio2feature_options import TestOptions as FeatureOptions
+ from options.test_audio2headpose_options import TestOptions as HeadposeOptions
+ from options.test_feature2face_options import TestOptions as RenderOptions
+ from datasets import create_dataset
+ from models import create_model
+ from models.networks import APC_encoder
+ from util.visualizer import Visualizer
+ from funcs import utils, audio_funcs
+ from demo import write_video_with_audio
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
+
+ class Predictor(BasePredictor):
+     def setup(self):
+         self.parser = argparse.ArgumentParser()
+         self.parser.add_argument('--id', default='May', help="person name, e.g. Obama1, Obama2, May, Nadella, McStay")
+         self.parser.add_argument('--driving_audio', default='data/Input/00083.wav', help="path to driving audio")
+         self.parser.add_argument('--save_intermediates', default=0, help="whether to save intermediate results")
+
+     def predict(self,
+                 driving_audio: Path = Input(description='driving audio; if the file is more than 20 seconds, only the first 20 seconds will be processed for video generation'),
+                 talking_head: str = Input(description="choose a talking head", choices=['May', 'Obama1', 'Obama2', 'Nadella', 'McStay'], default='May')
+                 ) -> Path:
+
+         ############################### I/O Settings ##############################
+         # load config files
+         opt = self.parser.parse_args('')
+         opt.driving_audio = str(driving_audio)
+         opt.id = talking_head
+         with open(join('config', opt.id + '.yaml')) as f:
+             config = yaml.safe_load(f)
+         data_root = join('data', opt.id)
+
+         ############################ Hyper Parameters #############################
+         h, w, sr, FPS = 512, 512, 16000, 60
+         mouth_indices = np.concatenate([np.arange(4, 11), np.arange(46, 64)])
+         eye_brow_indices = [27, 65, 28, 68, 29, 67, 30, 66, 31, 72, 32, 69, 33, 70, 34, 71]
+         eye_brow_indices = np.array(eye_brow_indices, np.int32)
+
+         ############################ Pre-defined Data #############################
+         mean_pts3d = np.load(join(data_root, 'mean_pts3d.npy'))
+         fit_data = np.load(config['dataset_params']['fit_data_path'])
+         pts3d = np.load(config['dataset_params']['pts3d_path']) - mean_pts3d
+         trans = fit_data['trans'][:, :, 0].astype(np.float32)
+         mean_translation = trans.mean(axis=0)
+         candidate_eye_brow = pts3d[10:, eye_brow_indices]
+         std_mean_pts3d = np.load(config['dataset_params']['pts3d_path']).mean(axis=0)
+         # candidate images
+         img_candidates = []
+         for j in range(4):
+             output = imread(join(data_root, 'candidates', f'normalized_full_{j}.jpg'))
+             output = A.pytorch.transforms.ToTensor(normalize={'mean': (0.5, 0.5, 0.5),
+                                                               'std': (0.5, 0.5, 0.5)})(image=output)['image']
+             img_candidates.append(output)
+         img_candidates = torch.cat(img_candidates).unsqueeze(0).cuda()
+
+         # shoulders
+         shoulders = np.load(join(data_root, 'normalized_shoulder_points.npy'))
+         shoulder3D = np.load(join(data_root, 'shoulder_points3D.npy'))[1]
+         ref_trans = trans[1]
+
+         # camera matrix, we always use training set intrinsic parameters.
+         camera = utils.camera()
+         camera_intrinsic = np.load(join(data_root, 'camera_intrinsic.npy')).astype(np.float32)
+         APC_feat_database = np.load(join(data_root, 'APC_feature_base.npy'))
+
+         # load reconstruction data
+         scale = sio.loadmat(join(data_root, 'id_scale.mat'))['scale'][0, 0]
+         Audio2Mel_torch = audio_funcs.Audio2Mel(n_fft=512, hop_length=int(16000 / 120), win_length=int(16000 / 60),
+                                                 sampling_rate=16000,
+                                                 n_mel_channels=80, mel_fmin=90, mel_fmax=7600.0).cuda()
+
+         ########################### Experiment Settings ###########################
+         #### user config
+         use_LLE = config['model_params']['APC']['use_LLE']
+         Knear = config['model_params']['APC']['Knear']
+         LLE_percent = config['model_params']['APC']['LLE_percent']
+         headpose_sigma = config['model_params']['Headpose']['sigma']
+         Feat_smooth_sigma = config['model_params']['Audio2Mouth']['smooth']
+         Head_smooth_sigma = config['model_params']['Headpose']['smooth']
+         Feat_center_smooth_sigma, Head_center_smooth_sigma = 0, 0
+         AMP_method = config['model_params']['Audio2Mouth']['AMP'][0]
+         Feat_AMPs = config['model_params']['Audio2Mouth']['AMP'][1:]
+         rot_AMP, trans_AMP = config['model_params']['Headpose']['AMP']
+         shoulder_AMP = config['model_params']['Headpose']['shoulder_AMP']
+         save_feature_maps = config['model_params']['Image2Image']['save_input']
+
+         #### common settings
+         Featopt = FeatureOptions().parse()
+         Headopt = HeadposeOptions().parse()
+         Renderopt = RenderOptions().parse()
+         Featopt.load_epoch = config['model_params']['Audio2Mouth']['ckp_path']
+         Headopt.load_epoch = config['model_params']['Headpose']['ckp_path']
+         Renderopt.dataroot = config['dataset_params']['root']
+         Renderopt.load_epoch = config['model_params']['Image2Image']['ckp_path']
+         Renderopt.size = config['model_params']['Image2Image']['size']
+
+         ############################# Load Models #################################
+         print('---------- Loading Model: APC-------------')
+         APC_model = APC_encoder(config['model_params']['APC']['mel_dim'],
+                                 config['model_params']['APC']['hidden_size'],
+                                 config['model_params']['APC']['num_layers'],
+                                 config['model_params']['APC']['residual'])
+         # load all 5 here?
+         APC_model.load_state_dict(torch.load(config['model_params']['APC']['ckp_path']), strict=False)
+         APC_model.cuda()
+         APC_model.eval()
+         print('---------- Loading Model: {} -------------'.format(Featopt.task))
+         Audio2Feature = create_model(Featopt)
+         Audio2Feature.setup(Featopt)
+         Audio2Feature.eval()
+         print('---------- Loading Model: {} -------------'.format(Headopt.task))
+         Audio2Headpose = create_model(Headopt)
+         Audio2Headpose.setup(Headopt)
+         Audio2Headpose.eval()
+         if Headopt.feature_decoder == 'WaveNet':
+             Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.module.WaveNet.receptive_field
+         print('---------- Loading Model: {} -------------'.format(Renderopt.task))
+         facedataset = create_dataset(Renderopt)
+         Feature2Face = create_model(Renderopt)
+         Feature2Face.setup(Renderopt)
+         Feature2Face.eval()
+         visualizer = Visualizer(Renderopt)
+
+         # check audio duration and trim audio
+         extension_name = os.path.basename(opt.driving_audio).split('.')[-1]
+         audio_threshold = 10
+         duration = librosa.get_duration(filename=opt.driving_audio)
+         if duration > audio_threshold:
+             print(f'audio file is longer than {audio_threshold} seconds, trimming to the first {audio_threshold} seconds '
+                   f'for further processing')
+             ffmpeg_extract_subclip(opt.driving_audio, 0, audio_threshold, targetname=f'shorter_input.{extension_name}')
+             opt.driving_audio = f'shorter_input.{extension_name}'
+
+         # create the results folder
+         audio_name = os.path.basename(opt.driving_audio).split('.')[0]
+         save_root = join('results', opt.id, audio_name)
+         os.makedirs(save_root, exist_ok=True)
+         clean_folder(save_root)
+         out_path = Path(tempfile.mkdtemp()) / "out.mp4"
+
+         ############################## Inference ##################################
+         print('Processing audio: {} ...'.format(audio_name))
+         # read audio
+         audio, _ = librosa.load(opt.driving_audio, sr=sr)
+         total_frames = np.int32(audio.shape[0] / sr * FPS)
+
+         #### 1. compute APC features
+         print('1. Computing APC features...')
+         mel80 = utils.compute_mel_one_sequence(audio)
+         mel_nframe = mel80.shape[0]
+         with torch.no_grad():
+             length = torch.Tensor([mel_nframe])
+             mel80_torch = torch.from_numpy(mel80.astype(np.float32)).cuda().unsqueeze(0)
+             hidden_reps = APC_model.forward(mel80_torch, length)[0]  # [mel_nframe, 512]
+             hidden_reps = hidden_reps.cpu().numpy()
+         audio_feats = hidden_reps
+
+         #### 2. manifold projection
+         if use_LLE:
+             print('2. Manifold projection...')
+             ind = utils.KNN_with_torch(audio_feats, APC_feat_database, K=Knear)
+             weights, feat_fuse = utils.compute_LLE_projection_all_frame(audio_feats, APC_feat_database, ind,
+                                                                         audio_feats.shape[0])
+             audio_feats = audio_feats * (1 - LLE_percent) + feat_fuse * LLE_percent
+
+         #### 3. Audio2Mouth
+         print('3. Audio2Mouth inference...')
+         pred_Feat = Audio2Feature.generate_sequences(audio_feats, sr, FPS, fill_zero=True, opt=Featopt)
+
+         #### 4. Audio2Headpose
+         print('4. Headpose inference...')
+         # set history headposes to zero
+         pre_headpose = np.zeros(Headopt.A2H_wavenet_input_channels, np.float32)
+         pred_Head = Audio2Headpose.generate_sequences(audio_feats, pre_headpose, fill_zero=True, sigma_scale=0.3,
+                                                       opt=Headopt)
+
+         #### 5. Post-Processing
+         print('5. Post-processing...')
+         nframe = min(pred_Feat.shape[0], pred_Head.shape[0])
+         pred_pts3d = np.zeros([nframe, 73, 3])
+         pred_pts3d[:, mouth_indices] = pred_Feat.reshape(-1, 25, 3)[:nframe]
+
+         ## mouth
+         pred_pts3d = utils.landmark_smooth_3d(pred_pts3d, Feat_smooth_sigma, area='only_mouth')
+         pred_pts3d = utils.mouth_pts_AMP(pred_pts3d, True, AMP_method, Feat_AMPs)
+         pred_pts3d = pred_pts3d + mean_pts3d
+         pred_pts3d = utils.solve_intersect_mouth(pred_pts3d)  # solve intersecting lips if they exist
+
+         ## headpose
+         pred_Head[:, 0:3] *= rot_AMP
+         pred_Head[:, 3:6] *= trans_AMP
+         pred_headpose = utils.headpose_smooth(pred_Head[:, :6], Head_smooth_sigma).astype(np.float32)
+         pred_headpose[:, 3:] += mean_translation
+         pred_headpose[:, 0] += 180
+
+         ## compute projected landmarks
+         pred_landmarks = np.zeros([nframe, 73, 2], dtype=np.float32)
+         final_pts3d = np.zeros([nframe, 73, 3], dtype=np.float32)
+         final_pts3d[:] = std_mean_pts3d.copy()
+         final_pts3d[:, 46:64] = pred_pts3d[:nframe, 46:64]
+         for k in tqdm(range(nframe)):
+             ind = k % candidate_eye_brow.shape[0]
+             final_pts3d[k, eye_brow_indices] = candidate_eye_brow[ind] + mean_pts3d[eye_brow_indices]
+             pred_landmarks[k], _, _ = utils.project_landmarks(camera_intrinsic, camera.relative_rotation,
+                                                               camera.relative_translation, scale,
+                                                               pred_headpose[k], final_pts3d[k])
+
+         ## Upper Body Motion
+         pred_shoulders = np.zeros([nframe, 18, 2], dtype=np.float32)
+         pred_shoulders3D = np.zeros([nframe, 18, 3], dtype=np.float32)
+         for k in range(nframe):
+             diff_trans = pred_headpose[k][3:] - ref_trans
+             pred_shoulders3D[k] = shoulder3D + diff_trans * shoulder_AMP
+             # project
+             project = camera_intrinsic.dot(pred_shoulders3D[k].T)
+             project[:2, :] /= project[2, :]  # divide by z
+             pred_shoulders[k] = project[:2, :].T
+
+         #### 6. Image2Image translation & Save results
+         print('6. Image2Image translation & Saving results...')
+         for ind in tqdm(range(0, nframe), desc='Image2Image translation inference'):
+             # feature_map: [input_nc, h, w]
+             current_pred_feature_map = facedataset.dataset.get_data_test_mode(pred_landmarks[ind],
+                                                                               pred_shoulders[ind],
+                                                                               facedataset.dataset.image_pad)
+             input_feature_maps = current_pred_feature_map.unsqueeze(0).cuda()
+             pred_fake = Feature2Face.inference(input_feature_maps, img_candidates)
+             # save results
+             visual_list = [('pred', util.tensor2im(pred_fake[0]))]
+             if save_feature_maps:
+                 visual_list += [('input', np.uint8(current_pred_feature_map[0].cpu().numpy() * 255))]
+             visuals = OrderedDict(visual_list)
+             visualizer.save_images(save_root, visuals, str(ind + 1))
+
+         ## make videos
+         # generate corresponding audio, reused for all results
+         tmp_audio_path = join(save_root, 'tmp.wav')
+         tmp_audio_clip = audio[: np.int32(nframe * sr / FPS)]
+         librosa.output.write_wav(tmp_audio_path, tmp_audio_clip, sr)
+
+         def write_video_with_audio(audio_path, output_path, prefix='pred_'):
+             fps, fourcc = 60, cv2.VideoWriter_fourcc(*'DIVX')
+             video_tmp_path = join(save_root, 'tmp.avi')
+             out = cv2.VideoWriter(video_tmp_path, fourcc, fps, (Renderopt.loadSize, Renderopt.loadSize))
+             for j in tqdm(range(nframe), position=0, desc='writing video'):
+                 img = cv2.imread(join(save_root, prefix + str(j + 1) + '.jpg'))
+                 out.write(img)
+             out.release()
+             cmd = 'ffmpeg -i "' + video_tmp_path + '" -i "' + audio_path + '" -codec copy -shortest "' + output_path + '"'
+             subprocess.call(cmd, shell=True)
+             os.remove(video_tmp_path)  # remove the temporary video
+
+         temp_out = 'temp_video.avi'
+         write_video_with_audio(tmp_audio_path, temp_out, 'pred_')
+         # convert to mp4
+         cmd = ("ffmpeg -i "
+                + temp_out + " -strict -2 "
+                + str(out_path)
+                )
+         subprocess.call(cmd, shell=True)
+
+         if os.path.exists(tmp_audio_path):
+             os.remove(tmp_audio_path)
+         if os.path.exists(temp_out):
+             os.remove(temp_out)
+         if os.path.exists(f'shorter_input.{extension_name}'):
+             os.remove(f'shorter_input.{extension_name}')
+         if not opt.save_intermediates:
+             _img_paths = list(map(lambda x: str(x), list(Path(save_root).glob('*.jpg'))))
+             for i in tqdm(range(len(_img_paths)), desc='deleting intermediate images'):
+                 os.remove(_img_paths[i])
+
+         print('Finish!')
+
+         return out_path
+
+
+ def clean_folder(folder):
+     for filename in os.listdir(folder):
+         file_path = os.path.join(folder, filename)
+         try:
+             if os.path.isfile(file_path) or os.path.islink(file_path):
+                 os.unlink(file_path)
+             elif os.path.isdir(file_path):
+                 shutil.rmtree(file_path)
+         except Exception as e:
+             print('Failed to delete %s. Reason: %s' % (file_path, e))
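
The shoulder motion in both scripts is projected to the image with a plain pinhole model: multiply the 3D points by the camera intrinsic matrix, then divide by depth. A minimal, self-contained illustration (the intrinsics below are placeholders, not the values stored in camera_intrinsic.npy):

# Pinhole projection as used for the shoulder points above.
import numpy as np

# Placeholder intrinsics (fx, fy, cx, cy), not the repository's calibration.
camera_intrinsic = np.array([[1200.0,    0.0, 256.0],
                             [   0.0, 1200.0, 256.0],
                             [   0.0,    0.0,   1.0]], dtype=np.float32)

pts3d = np.array([[ 0.10, 0.20, 2.0],
                  [-0.30, 0.10, 2.5]], dtype=np.float32)   # [N, 3] points in camera coordinates

project = camera_intrinsic.dot(pts3d.T)   # [3, N] homogeneous image coordinates
project[:2, :] /= project[2, :]           # divide by z, as in the scripts
pts2d = project[:2, :].T                  # [N, 2] pixel positions
print(pts2d)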
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ tqdm
+ librosa==0.7.0
+ scikit_image
+ opencv_python==4.4.0.40
+ scipy
+ dominate
+ albumentations==0.5.2
+ numpy
+ beautifulsoup4
+ scikit-image