Spanicin committed on
Commit 7f191c3 · verified · 1 Parent(s): f90a52f

Upload animate.py

Files changed (1)
  1. src/facerender/animate.py +262 -0
src/facerender/animate.py ADDED
@@ -0,0 +1,262 @@
import os
import cv2
import yaml
import numpy as np
import warnings
from skimage import img_as_ubyte

warnings.filterwarnings('ignore')

import imageio
import torch
import torchvision

from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
from src.facerender.modules.mapping import MappingNet
from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from src.facerender.modules.make_animation import make_animation

from pydub import AudioSegment
from src.utils.face_enhancer import enhancer as face_enhancer
from src.utils.paste_pic import paste_pic
from src.utils.videoio import save_video_with_watermark


class AnimateFromCoeff():

    def __init__(self, free_view_checkpoint, mapping_checkpoint,
                 config_path, device):

        with open(config_path) as f:
            config = yaml.safe_load(f)

        generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
                                                 **config['model_params']['common_params'])
        kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
                                  **config['model_params']['common_params'])
        he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
                                   **config['model_params']['common_params'])
        mapping = MappingNet(**config['model_params']['mapping_params'])

        generator.to(device)
        kp_extractor.to(device)
        he_estimator.to(device)
        mapping.to(device)

        # Wrap models in DataParallel for multi-GPU support
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs")
            generator = torch.nn.DataParallel(generator)
            kp_extractor = torch.nn.DataParallel(kp_extractor)
            he_estimator = torch.nn.DataParallel(he_estimator)
            mapping = torch.nn.DataParallel(mapping)

        for param in generator.parameters():
            param.requires_grad = False
        for param in kp_extractor.parameters():
            param.requires_grad = False
        for param in he_estimator.parameters():
            param.requires_grad = False
        for param in mapping.parameters():
            param.requires_grad = False

        if free_view_checkpoint is not None:
            self.load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
        else:
            raise AttributeError("Checkpoint should be specified for video head pose estimator.")

        if mapping_checkpoint is not None:
            self.load_cpk_mapping(mapping_checkpoint, mapping=mapping)
        else:
            raise AttributeError("Checkpoint should be specified for the mapping network.")

        self.kp_extractor = kp_extractor
        self.generator = generator
        self.he_estimator = he_estimator
        self.mapping = mapping

        self.kp_extractor.eval()
        self.generator.eval()
        self.he_estimator.eval()
        self.mapping.eval()

        self.device = device

    def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
                             kp_detector=None, he_estimator=None, optimizer_generator=None,
                             optimizer_discriminator=None, optimizer_kp_detector=None,
                             optimizer_he_estimator=None, device="cpu"):
        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))

        def adjust_state_dict(state_dict, model):
            # Checkpoints are saved without the DataParallel wrapper, so the
            # 'module.' prefix must be added when the model has been wrapped.
            new_state_dict = {}
            if isinstance(model, torch.nn.DataParallel):
                for k, v in state_dict.items():
                    new_key = f"module.{k}"  # Add 'module.' prefix
                    new_state_dict[new_key] = v
            else:
                new_state_dict = state_dict  # Keep original state_dict for single-GPU models
            return new_state_dict

        if generator is not None:
            generator_state_dict = adjust_state_dict(checkpoint['generator'], generator)
            generator.load_state_dict(generator_state_dict)
        if kp_detector is not None:
            kp_state_dict = adjust_state_dict(checkpoint['kp_detector'], kp_detector)
            kp_detector.load_state_dict(kp_state_dict)
        if he_estimator is not None:
            he_state_dict = adjust_state_dict(checkpoint['he_estimator'], he_estimator)
            he_estimator.load_state_dict(he_state_dict)
        if discriminator is not None:
            try:
                discriminator_dict = adjust_state_dict(checkpoint['discriminator'], discriminator)
                discriminator.load_state_dict(discriminator_dict)
            except (KeyError, RuntimeError):
                print('No discriminator in the state-dict. Discriminator will be randomly initialized')
        if optimizer_generator is not None:
            optimizer_generator_dict = adjust_state_dict(checkpoint['optimizer_generator'], optimizer_generator)
            optimizer_generator.load_state_dict(optimizer_generator_dict)
        if optimizer_discriminator is not None:
            try:
                optimizer_discriminator_dict = adjust_state_dict(checkpoint['optimizer_discriminator'], optimizer_discriminator)
                optimizer_discriminator.load_state_dict(optimizer_discriminator_dict)
            except (KeyError, RuntimeError):
                print('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
        if optimizer_kp_detector is not None:
            optimizer_kp_detector_dict = adjust_state_dict(checkpoint['optimizer_kp_detector'], optimizer_kp_detector)
            optimizer_kp_detector.load_state_dict(optimizer_kp_detector_dict)
        if optimizer_he_estimator is not None:
            optimizer_he_estimator_dict = adjust_state_dict(checkpoint['optimizer_he_estimator'], optimizer_he_estimator)
            optimizer_he_estimator.load_state_dict(optimizer_he_estimator_dict)

        return checkpoint['epoch']

    def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
                         optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))

        def adjust_state_dict(state_dict, model):
            new_state_dict = {}
            if isinstance(model, torch.nn.DataParallel):
                for k, v in state_dict.items():
                    new_key = f"module.{k}"  # Add 'module.' prefix
                    new_state_dict[new_key] = v
            else:
                new_state_dict = state_dict  # Keep original state_dict for single-GPU models
            return new_state_dict

        if mapping is not None:
            mapping_dict = adjust_state_dict(checkpoint['mapping'], mapping)
            mapping.load_state_dict(mapping_dict)
        if discriminator is not None:
            discriminator_dict = adjust_state_dict(checkpoint['discriminator'], discriminator)
            discriminator.load_state_dict(discriminator_dict)
        if optimizer_mapping is not None:
            optimizer_mapping_dict = adjust_state_dict(checkpoint['optimizer_mapping'], optimizer_mapping)
            optimizer_mapping.load_state_dict(optimizer_mapping_dict)
        if optimizer_discriminator is not None:
            optimizer_discriminator_dict = adjust_state_dict(checkpoint['optimizer_discriminator'], optimizer_discriminator)
            optimizer_discriminator.load_state_dict(optimizer_discriminator_dict)

        return checkpoint['epoch']

    def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop'):
        source_image = x['source_image'].type(torch.FloatTensor)
        source_semantics = x['source_semantics'].type(torch.FloatTensor)
        target_semantics = x['target_semantics_list'].type(torch.FloatTensor)
        source_image = source_image.to(self.device)
        source_semantics = source_semantics.to(self.device)
        target_semantics = target_semantics.to(self.device)
        if 'yaw_c_seq' in x:
            yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor)
            yaw_c_seq = yaw_c_seq.to(self.device)
        else:
            yaw_c_seq = None
        if 'pitch_c_seq' in x:
            pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor)
            pitch_c_seq = pitch_c_seq.to(self.device)
        else:
            pitch_c_seq = None
        if 'roll_c_seq' in x:
            roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor)
            roll_c_seq = roll_c_seq.to(self.device)
        else:
            roll_c_seq = None

        frame_num = x['frame_num']

        predictions_video = make_animation(source_image, source_semantics, target_semantics,
                                           self.generator, self.kp_extractor, self.he_estimator, self.mapping,
                                           yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp=True)

        predictions_video = predictions_video.reshape((-1,) + predictions_video.shape[2:])
        predictions_video = predictions_video[:frame_num]

        video = []
        for idx in range(predictions_video.shape[0]):
            image = predictions_video[idx]
            image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
            video.append(image)
        result = img_as_ubyte(video)
        ### the generated video is 256x256, so resize it while keeping the original aspect ratio
        original_size = crop_info[0]
        if original_size:
            result = [cv2.resize(result_i, (256, int(256.0 * original_size[1] / original_size[0]))) for result_i in result]

        video_name = x['video_name'] + '.mp4'
        path = os.path.join(video_save_dir, 'temp_' + video_name)

        imageio.mimsave(path, result, fps=float(25))

        av_path = os.path.join(video_save_dir, video_name)
        return_path = av_path

        audio_path = x['audio_path']
        audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
        new_audio_path = os.path.join(video_save_dir, audio_name + '.wav')
        print('new_audio_path', new_audio_path)
        start_time = 0
        # cog will not keep the .mp3 filename
        sound = AudioSegment.from_file(audio_path)
        frames = frame_num
        end_time = start_time + frames * 1 / 25 * 1000  # pydub slices in ms; 25 fps -> 40 ms per frame
        word1 = sound.set_frame_rate(16000)
        word = word1[start_time:end_time]
        word.export(new_audio_path, format="wav")

        base64_video, temp_file_path = save_video_with_watermark(path, new_audio_path, av_path, watermark=False)
        print(f'The generated video is named {video_name} in {video_save_dir}')

        if preprocess.lower() == 'full':
            # only add watermark to the full image.
            video_name_full = x['video_name'] + '_full.mp4'
            full_video_path = os.path.join(video_save_dir, video_name_full)
            return_path = full_video_path
            base64_video, temp_file_path = paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path)
            print(f'The generated video is named {video_save_dir}/{video_name_full}')
        else:
            full_video_path = av_path

        #### paste back, then apply the enhancers
        if enhancer:
            video_name_enhancer = x['video_name'] + '_enhanced.mp4'
            enhanced_path = os.path.join(video_save_dir, 'temp_' + video_name_enhancer)
            av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
            return_path = av_path_enhancer
            enhanced_images = face_enhancer(temp_file_path, method=enhancer, bg_upsampler=background_enhancer)

            imageio.mimsave(enhanced_path, enhanced_images, fps=float(25))

            base64_video, temp_file_path = save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark=False)
            print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
            os.remove(enhanced_path)

        os.remove(path)
        # os.remove(new_audio_path)

        return return_path, base64_video, temp_file_path, new_audio_path
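
For reference, a minimal, hypothetical driver sketch for the class added above; it is not part of this commit. The checkpoint, config, and example paths, and the way the coefficient batch `x` and `crop_info` are produced, are assumptions inferred from the constructor and `generate()` signatures, so adapt them to the actual SadTalker pipeline.

    # Hypothetical usage sketch; all paths and variable names below are placeholders.
    import torch
    from src.facerender.animate import AnimateFromCoeff

    device = "cuda" if torch.cuda.is_available() else "cpu"

    animate_from_coeff = AnimateFromCoeff(
        free_view_checkpoint="checkpoints/facevid2vid.pth.tar",   # assumed facevid2vid checkpoint path
        mapping_checkpoint="checkpoints/mapping.pth.tar",         # assumed MappingNet checkpoint path
        config_path="src/config/facerender.yaml",                 # assumed facerender config path
        device=device,
    )

    # `x` must be the coefficient batch built by the upstream stage; generate() expects at least
    # 'source_image', 'source_semantics', 'target_semantics_list', 'frame_num', 'video_name' and
    # 'audio_path', plus optional 'yaw_c_seq' / 'pitch_c_seq' / 'roll_c_seq' pose sequences.
    # `crop_info` is the crop metadata produced by the preprocessing step.
    return_path, base64_video, temp_file_path, new_audio_path = animate_from_coeff.generate(
        x,
        video_save_dir="./results",
        pic_path="./examples/source_image.png",   # original (uncropped) picture, used when pasting back
        crop_info=crop_info,
        enhancer="gfpgan",                        # optional face enhancer forwarded to face_enhancer()
        preprocess="crop",
    )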