Spanicin committed on
Commit f90a52f · verified · 1 Parent(s): 7989b7e

Delete src/facerender/animate.py

Files changed (1)
  1. src/facerender/animate.py +0 -219
src/facerender/animate.py DELETED
@@ -1,219 +0,0 @@
- import os
- import cv2
- import yaml
- import numpy as np
- import warnings
- from skimage import img_as_ubyte
-
- warnings.filterwarnings('ignore')
-
- import imageio
- import torch
- import torchvision
-
- from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
- from src.facerender.modules.mapping import MappingNet
- from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
- from src.facerender.modules.make_animation import make_animation
-
- from pydub import AudioSegment
- from src.utils.face_enhancer import enhancer as face_enhancer
- from src.utils.paste_pic import paste_pic
- from src.utils.videoio import save_video_with_watermark
-
-
- class AnimateFromCoeff():
-
-     def __init__(self, free_view_checkpoint, mapping_checkpoint,
-                  config_path, device):
-
-         with open(config_path) as f:
-             config = yaml.safe_load(f)
-
-         generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
-                                                  **config['model_params']['common_params'])
-         kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
-                                   **config['model_params']['common_params'])
-         he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
-                                    **config['model_params']['common_params'])
-         mapping = MappingNet(**config['model_params']['mapping_params'])
-
-         generator.to(device)
-         kp_extractor.to(device)
-         he_estimator.to(device)
-         mapping.to(device)
-         # All networks are used for inference only, so freeze their weights.
-         for param in generator.parameters():
-             param.requires_grad = False
-         for param in kp_extractor.parameters():
-             param.requires_grad = False
-         for param in he_estimator.parameters():
-             param.requires_grad = False
-         for param in mapping.parameters():
-             param.requires_grad = False
-
-         if free_view_checkpoint is not None:
-             self.load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor,
-                                       generator=generator, he_estimator=he_estimator)
-         else:
-             raise AttributeError("Checkpoint should be specified for the video head pose estimator.")
-
-         if mapping_checkpoint is not None:
-             self.load_cpk_mapping(mapping_checkpoint, mapping=mapping)
-         else:
-             raise AttributeError("Checkpoint should be specified for the mapping network.")
-
-         self.kp_extractor = kp_extractor
-         self.generator = generator
-         self.he_estimator = he_estimator
-         self.mapping = mapping
-
-         self.kp_extractor.eval()
-         self.generator.eval()
-         self.he_estimator.eval()
-         self.mapping.eval()
-
-         self.device = device
-
-     def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
-                              kp_detector=None, he_estimator=None, optimizer_generator=None,
-                              optimizer_discriminator=None, optimizer_kp_detector=None,
-                              optimizer_he_estimator=None, device="cpu"):
-         checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
-         if generator is not None:
-             generator.load_state_dict(checkpoint['generator'])
-         if kp_detector is not None:
-             kp_detector.load_state_dict(checkpoint['kp_detector'])
-         if he_estimator is not None:
-             he_estimator.load_state_dict(checkpoint['he_estimator'])
-         if discriminator is not None:
-             try:
-                 discriminator.load_state_dict(checkpoint['discriminator'])
-             except KeyError:
-                 print('No discriminator in the state-dict. Discriminator will be randomly initialized')
-         if optimizer_generator is not None:
-             optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
-         if optimizer_discriminator is not None:
-             try:
-                 optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
-             except (KeyError, RuntimeError):
-                 print('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
-         if optimizer_kp_detector is not None:
-             optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
-         if optimizer_he_estimator is not None:
-             optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
-
-         return checkpoint['epoch']
-
-     def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
-                          optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
-         checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
-         if mapping is not None:
-             mapping.load_state_dict(checkpoint['mapping'])
-         if discriminator is not None:
-             discriminator.load_state_dict(checkpoint['discriminator'])
-         if optimizer_mapping is not None:
-             optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
-         if optimizer_discriminator is not None:
-             optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
-
-         return checkpoint['epoch']
-
-     def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop'):
-         source_image = x['source_image'].type(torch.FloatTensor).to(self.device)
-         source_semantics = x['source_semantics'].type(torch.FloatTensor).to(self.device)
-         target_semantics = x['target_semantics_list'].type(torch.FloatTensor).to(self.device)
-         if 'yaw_c_seq' in x:
-             yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor).to(self.device)
-         else:
-             yaw_c_seq = None
-         if 'pitch_c_seq' in x:
-             pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor).to(self.device)
-         else:
-             pitch_c_seq = None
-         if 'roll_c_seq' in x:
-             roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor).to(self.device)
-         else:
-             roll_c_seq = None
-
-         frame_num = x['frame_num']
-
-         predictions_video = make_animation(source_image, source_semantics, target_semantics,
-                                            self.generator, self.kp_extractor, self.he_estimator, self.mapping,
-                                            yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp=True)
-
-         predictions_video = predictions_video.reshape((-1,) + predictions_video.shape[2:])
-         predictions_video = predictions_video[:frame_num]
-
-         video = []
-         for idx in range(predictions_video.shape[0]):
-             image = predictions_video[idx]
-             image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
-             video.append(image)
-         result = img_as_ubyte(video)
-
-         # The generated video is 256x256, so rescale the height to keep the source aspect ratio.
-         original_size = crop_info[0]
-         if original_size:
-             result = [cv2.resize(result_i, (256, int(256.0 * original_size[1] / original_size[0]))) for result_i in result]
-
-         video_name = x['video_name'] + '.mp4'
-         path = os.path.join(video_save_dir, 'temp_' + video_name)
-
-         imageio.mimsave(path, result, fps=float(25))
-
-         av_path = os.path.join(video_save_dir, video_name)
-         return_path = av_path
-
-         audio_path = x['audio_path']
-         audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
-         new_audio_path = os.path.join(video_save_dir, audio_name + '.wav')
-         print('new_audio_path', new_audio_path)
-         start_time = 0
-         # cog will not keep the .mp3 filename
-         sound = AudioSegment.from_file(audio_path)
-         frames = frame_num
-         # pydub slices in milliseconds: frames / 25 fps * 1000 ms.
-         end_time = start_time + frames * 1 / 25 * 1000
-         word = sound.set_frame_rate(16000)[start_time:end_time]
-         word.export(new_audio_path, format="wav")
-
-         base64_video, temp_file_path = save_video_with_watermark(path, new_audio_path, av_path, watermark=False)
-         print(f'The generated video is named {video_name} in {video_save_dir}')
-
-         if preprocess.lower() == 'full':
-             # Only add the watermark to the full image.
-             video_name_full = x['video_name'] + '_full.mp4'
-             full_video_path = os.path.join(video_save_dir, video_name_full)
-             return_path = full_video_path
-             base64_video, temp_file_path = paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path)
-             print(f'The generated video is named {video_save_dir}/{video_name_full}')
-         else:
-             full_video_path = av_path
-
-         # Paste back first, then run the enhancers.
-         if enhancer:
-             video_name_enhancer = x['video_name'] + '_enhanced.mp4'
-             enhanced_path = os.path.join(video_save_dir, 'temp_' + video_name_enhancer)
-             av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
-             return_path = av_path_enhancer
-             enhanced_images = face_enhancer(temp_file_path, method=enhancer, bg_upsampler=background_enhancer)
-
-             imageio.mimsave(enhanced_path, enhanced_images, fps=float(25))
-
-             base64_video, temp_file_path = save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark=False)
-             print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
-             os.remove(enhanced_path)
-
-         os.remove(path)
-         # os.remove(new_audio_path)
-
-         return return_path, base64_video, temp_file_path, new_audio_path
-
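For reference, a minimal sketch of how the deleted class was typically driven. The checkpoint and config paths below are hypothetical placeholders, not values taken from this commit; the constructor arguments and the keys expected in the coefficient dict come from the `__init__` and `generate()` signatures shown in the diff above.

# Hypothetical usage sketch of the class removed by this commit.
# Paths and the contents of `batch` are assumptions for illustration.
import torch
from src.facerender.animate import AnimateFromCoeff  # the deleted module

device = 'cuda' if torch.cuda.is_available() else 'cpu'
animate_from_coeff = AnimateFromCoeff(
    free_view_checkpoint='checkpoints/facevid2vid.pth.tar',  # assumed path
    mapping_checkpoint='checkpoints/mapping.pth.tar',        # assumed path
    config_path='src/config/facerender.yaml',                # assumed path
    device=device,
)

# `batch` is assembled elsewhere in the pipeline; generate() expects at
# least these keys (see the method body above):
# batch = {'source_image': ..., 'source_semantics': ...,
#          'target_semantics_list': ..., 'frame_num': ...,
#          'video_name': ..., 'audio_path': ...}
# return_path, base64_video, temp_file_path, new_audio_path = \
#     animate_from_coeff.generate(batch, video_save_dir='results',
#                                 pic_path='source.png', crop_info=crop_info,
#                                 enhancer='gfpgan', preprocess='crop')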