Delik committed on
Commit
d85560f
·
verified ·
1 Parent(s): 67ae65b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +397 -4
app.py CHANGED
@@ -1,9 +1,402 @@
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!"
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="textbox", outputs="textbox")
 
 
 
 
7
 
8
- if __name__ == "__main__":
9
  demo.launch()
 
1
+
2
+ import argparse
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
  import gradio as gr
9
+ import shutil
10
+ import librosa
11
+ import python_speech_features
12
+ import time
13
+ from LIA_Model import LIA_Model
14
+ import os
15
+ from tqdm import tqdm
16
+ import argparse
17
+ import numpy as np
18
+ from torchvision import transforms
19
+ from templates import *
20
+ import argparse
21
+ import shutil
22
+ from moviepy.editor import *
23
+ import librosa
24
+ import python_speech_features
25
+ import importlib.util
26
+ import time
27
+ import os
28
+ import time
29
+ import numpy as np
30
+
31
+
32
# Disable Gradio analytics to avoid network-related issues.
# NOTE(review): Gradio is documented to read the GRADIO_ANALYTICS_ENABLED
# environment variable; the bare module attribute below is kept for backward
# compatibility but is presumably not consulted by the library -- confirm
# against the installed Gradio version.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
gr.analytics_enabled = False
34
+
35
+
36
def check_package_installed(package_name):
    """Report whether ``package_name`` can be located, without importing it.

    Prints a one-line status message and returns the result.

    Args:
        package_name: Importable module/package name (dotted names allowed).

    Returns:
        True if an import spec was found for the name, False otherwise.
    """
    try:
        package_spec = importlib.util.find_spec(package_name)
    except (ImportError, ValueError):
        # find_spec raises instead of returning None in some cases, e.g.
        # ModuleNotFoundError for "missing_parent.child" and ValueError for an
        # empty name. For this helper those all mean "not installed".
        package_spec = None

    if package_spec is None:
        print(f"{package_name} is not installed.")
        return False

    print(f"{package_name} is installed.")
    return True
44
+
45
def frames_to_video(input_path, audio_path, output_path, fps=25):
    """Assemble the images in a directory into a video with an audio track.

    Every entry of ``input_path`` is used as one frame, in sorted filename
    order, each shown for ``1 / fps`` seconds.

    Args:
        input_path: Directory containing the frame images.
        audio_path: Audio file to attach to the video.
        output_path: Destination video file.
        fps: Frames per second of the output video.

    Raises:
        ValueError: If ``input_path`` contains no files.
    """
    image_files = [os.path.join(input_path, name) for name in sorted(os.listdir(input_path))]
    if not image_files:
        raise ValueError(f'No frames found in {input_path}')

    clips = [ImageClip(path).set_duration(1 / fps) for path in image_files]
    video = concatenate_videoclips(clips, method="compose")
    audio = AudioFileClip(audio_path)
    try:
        final_video = video.set_audio(audio)
        final_video.write_videofile(output_path, fps=fps, codec='libx264', audio_codec='aac')
    finally:
        # Release the readers held by the clips even when encoding fails.
        audio.close()
        video.close()
53
+
54
def load_image(filename, size):
    """Load an image as a channel-first array scaled to [0, 1].

    Args:
        filename: Path of the image file.
        size: Side length; the image is resized to ``size x size``.

    Returns:
        Array of shape ``(3, size, size)`` with values in [0, 1].
    """
    # The context manager closes the underlying file handle; convert() loads
    # the pixel data, so the converted copy remains usable afterwards.
    with Image.open(filename) as source:
        img = source.convert('RGB').resize((size, size))
    img = np.asarray(img)
    img = np.transpose(img, (2, 0, 1))  # H x W x 3 -> 3 x H x W
    return img / 255.0
60
+
61
def img_preprocessing(img_path, size):
    """Load an image and return it as a batched float tensor in [-1, 1].

    Args:
        img_path: Path of the image file.
        size: Side length passed on to ``load_image``.

    Returns:
        Tensor with a leading batch dimension of 1, rescaled from the [0, 1]
        range produced by ``load_image`` to [-1, 1].
    """
    pixels = load_image(img_path, size)  # values in [0, 1]
    batch = torch.from_numpy(pixels).unsqueeze(0).float()
    return (batch - 0.5) * 2.0  # [0, 1] -> [-1, 1]
66
+
67
def saved_image(img_tensor, img_path):
    """Write an image tensor to ``img_path``.

    Args:
        img_tensor: Image tensor; it is detached, moved to the CPU, and
            ``squeeze(0)`` is applied before conversion.
        img_path: Destination file path.
    """
    # squeeze(0) removes the batch dimension before conversion.
    frame = img_tensor.detach().cpu().squeeze(0)
    transforms.ToPILImage()(frame).save(img_path)
71
+
72
# Stage-2 configuration flags per inference type:
# (face_location, face_scale, mfcc).
_INFER_TYPE_FLAGS = {
    'mfcc_full_control': (True, True, True),
    'mfcc_pose_only': (False, False, True),
    'hubert_pose_only': (False, False, False),
    'hubert_audio_only': (False, False, False),
    'hubert_full_control': (True, True, False),
}


def main(args):
    """Generate a talking-head video from one reference image and one audio file.

    Pipeline: load the stage-1 model (LIA) and encode the reference image,
    load the stage-2 model, extract audio features (MFCC or HuBERT), build the
    pose / face-location / face-scale control signals, run ``model.render`` to
    obtain per-frame motion codes, render each frame with ``lia.render``, mux
    the frames with the audio, and optionally run GFPGAN super-resolution.

    Args:
        args: Namespace providing ``infer_type``, ``test_image_path``,
            ``test_audio_path``, ``test_hubert_path``, ``result_path``,
            ``stage1_checkpoint_path``, ``stage2_checkpoint_path``, ``seed``,
            ``control_flag``, ``pose_yaw``, ``pose_pitch``, ``pose_roll``,
            ``face_location``, ``face_scale``, ``pose_driven_path``,
            ``step_T``, ``image_size``, ``device``, ``motion_dim``,
            ``decoder_layers`` and ``face_sr``.

    Returns:
        Tuple ``(video_256_path, video_512_path)``. When no super-resolved
        video was produced (not requested, or gfpgan is not installed) both
        entries are the 256 path.

    Raises:
        ValueError: Unknown ``infer_type``, malformed pose file, or audio too
            short to yield a single video frame.
        FileNotFoundError: Missing input image, input audio, or HuBERT weights.
        ImportError: HuBERT features must be extracted but ``transformers``
            is not installed.
    """
    # Validate the cheap preconditions before any model is loaded. These used
    # to call exit(0), which raises SystemExit and slips past callers that
    # catch Exception; raising lets the caller report the problem.
    if args.infer_type not in _INFER_TYPE_FLAGS:
        raise ValueError('Type NOT Found!')
    if not os.path.exists(args.test_image_path):
        raise FileNotFoundError(f'{args.test_image_path} does not exist!')
    if not os.path.exists(args.test_audio_path):
        raise FileNotFoundError(f'{args.test_audio_path} does not exist!')

    frames_result_saved_path = os.path.join(args.result_path, 'frames')
    # Frames left behind by an earlier failed run would otherwise be picked up
    # by frames_to_video and end up in this video.
    shutil.rmtree(frames_result_saved_path, ignore_errors=True)
    os.makedirs(frames_result_saved_path, exist_ok=True)

    test_image_name = os.path.splitext(os.path.basename(args.test_image_path))[0]
    audio_name = os.path.splitext(os.path.basename(args.test_audio_path))[0]
    predicted_video_256_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}.mp4')
    predicted_video_512_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}_SR.mp4')

    # ====== Loading Stage 1 model ======
    lia = LIA_Model(motion_dim=args.motion_dim, fusion_type='weighted_sum')
    lia.load_lightning_model(args.stage1_checkpoint_path)
    lia.to(args.device)

    conf = ffhq256_autoenc()
    conf.seed = args.seed
    conf.decoder_layers = args.decoder_layers
    conf.infer_type = args.infer_type
    conf.motion_dim = args.motion_dim
    conf.face_location, conf.face_scale, conf.mfcc = _INFER_TYPE_FLAGS[args.infer_type]

    img_source = img_preprocessing(args.test_image_path, args.image_size).to(args.device)
    one_shot_lia_start, one_shot_lia_direction, feats = lia.get_start_direction_code(
        img_source, img_source, img_source, img_source)

    # ====== Loading Stage 2 model ======
    model = LitModel(conf)
    state = torch.load(args.stage2_checkpoint_path, map_location='cpu')
    model.load_state_dict(state, strict=True)
    model.ema_model.eval()
    model.ema_model.to(args.device)

    # ====== Audio input ======
    if conf.infer_type.startswith('mfcc'):
        # MFCC features plus first- and second-order deltas.
        wav, sr = librosa.load(args.test_audio_path, sr=16000)
        input_values = python_speech_features.mfcc(
            signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01)
        d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
        d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
        audio_driven_obj = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))

        # The video frame is fixed to 25 hz and the audio is fixed to 100 hz,
        # so four audio steps map to one video frame.
        frame_end = audio_driven_obj.shape[0] // 4
        audio_driven = torch.Tensor(audio_driven_obj[:frame_end * 4, :]).unsqueeze(0).float().to(args.device)
    else:
        # Every remaining validated inference type is a 'hubert_*' one.
        if not os.path.exists(args.test_hubert_path):
            if not check_package_installed('transformers'):
                raise ImportError('Please install transformers module first.')
            hubert_model_path = './ckpts/chinese-hubert-large'
            if not os.path.exists(hubert_model_path):
                raise FileNotFoundError('Please download the hubert weight into the ckpts path first.')
            print('You did not extract the audio features in advance, extracting online now, which will increase processing delay')

            start_time = time.time()

            # Imported lazily: only needed when features are extracted online.
            from transformers import Wav2Vec2FeatureExtractor, HubertModel
            audio_model = HubertModel.from_pretrained(hubert_model_path).to(args.device)
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_path)
            audio_model.feature_extractor._freeze_parameters()
            audio_model.eval()

            # HuBERT forward pass; the hidden states of every layer are kept.
            audio, _ = librosa.load(args.test_audio_path, sr=16000)
            input_values = feature_extractor(
                audio, sampling_rate=16000, padding=True, do_normalize=True, return_tensors="pt").input_values
            input_values = input_values.to(args.device)
            with torch.no_grad():
                outputs = audio_model(input_values, output_hidden_states=True)
                ws_feats = [hidden.detach().cpu().numpy() for hidden in outputs.hidden_states]
            ws_feat_obj = np.array(ws_feats)
            ws_feat_obj = np.squeeze(ws_feat_obj, 1)
            # Align the audio length with the video frames.
            ws_feat_obj = np.pad(ws_feat_obj, ((0, 0), (0, 1), (0, 0)), 'edge')

            execution_time = time.time() - start_time
            print(f"Extraction Audio Feature: {execution_time:.2f} Seconds")

            audio_driven_obj = ws_feat_obj
        else:
            print(f'Using audio feature from path: {args.test_hubert_path}')
            audio_driven_obj = np.load(args.test_hubert_path)

        # The video frame is fixed to 25 hz and the audio is fixed to 50 hz,
        # so two audio steps map to one video frame.
        frame_end = audio_driven_obj.shape[1] // 2
        audio_driven = torch.Tensor(audio_driven_obj[:, :frame_end * 2, :]).unsqueeze(0).float().to(args.device)

    if frame_end == 0:
        raise ValueError(f'{args.test_audio_path} is too short to produce a video frame.')

    # Diffusion noise
    noisyT = torch.randn((1, frame_end, args.motion_dim)).to(args.device)

    # ====== Inputs for attribute control ======
    if os.path.exists(args.pose_driven_path):
        pose_obj = np.load(args.pose_driven_path)

        if pose_obj.ndim != 2 or pose_obj.shape[1] != 3:
            raise ValueError('please check your pose information. The shape must be like (T, 3).')

        if pose_obj.shape[0] >= frame_end:
            pose_obj = pose_obj[:frame_end, :]
        else:
            # Repeat the last pose to cover the remaining frames.
            padding = np.tile(pose_obj[-1, :], (frame_end - pose_obj.shape[0], 1))
            pose_obj = np.vstack((pose_obj, padding))

        pose_signal = torch.Tensor(pose_obj).unsqueeze(0).to(args.device) / 90  # 90 is for normalization here
    else:
        # No pose file: hold the requested yaw / pitch / roll for every frame.
        yaw_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_yaw
        pitch_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_pitch
        roll_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_roll
        pose_signal = torch.cat((yaw_signal, pitch_signal, roll_signal), dim=-1)

    pose_signal = torch.clamp(pose_signal, -1, 1)

    face_location_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_location
    face_scale_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_scale

    # ====== Diffusion denoising process ======
    start_time = time.time()
    generated_directions = model.render(
        one_shot_lia_start, one_shot_lia_direction, audio_driven,
        face_location_signal, face_scale_signal, pose_signal,
        noisyT, args.step_T, control_flag=args.control_flag)
    execution_time = time.time() - start_time
    print(f"Motion Diffusion Model: {execution_time:.2f} Seconds")

    generated_directions = generated_directions.detach().cpu().numpy()

    # ====== Rendering images frame-by-frame ======
    start_time = time.time()
    for pred_index in tqdm(range(generated_directions.shape[1])):
        direction = torch.Tensor(generated_directions[:, pred_index, :]).to(args.device)
        ori_img_recon = lia.render(one_shot_lia_start, direction, feats)
        ori_img_recon = ori_img_recon.clamp(-1, 1)
        frame = (ori_img_recon.detach() + 1) / 2  # [-1, 1] -> [0, 1]
        saved_image(frame, os.path.join(frames_result_saved_path, "%06d.png" % (pred_index)))
    execution_time = time.time() - start_time
    print(f"Renderer Model: {execution_time:.2f} Seconds")

    frames_to_video(frames_result_saved_path, args.test_audio_path, predicted_video_256_path)

    shutil.rmtree(frames_result_saved_path)

    # ====== Enhancer (optional super-resolution) ======
    sr_generated = False
    if args.face_sr and check_package_installed('gfpgan'):
        from face_sr.face_enhancer import enhancer_list
        import imageio

        tmp_sr_path = predicted_video_512_path + '.tmp.mp4'

        # Super-resolution
        imageio.mimsave(
            tmp_sr_path,
            enhancer_list(predicted_video_256_path, method='gfpgan', bg_upsampler=None),
            fps=float(25))

        # Merge audio and video; close the clips so the temporary file is
        # released before it is deleted.
        video_clip = VideoFileClip(tmp_sr_path)
        try:
            audio_clip = AudioFileClip(predicted_video_256_path)
            try:
                final_clip = video_clip.set_audio(audio_clip)
                final_clip.write_videofile(predicted_video_512_path, codec='libx264', audio_codec='aac')
            finally:
                audio_clip.close()
        finally:
            video_clip.close()

        os.remove(tmp_sr_path)
        sr_generated = True

    # Only advertise the 512 video when it was actually written; previously a
    # missing gfpgan install still returned the (non-existent) 512 path.
    if sr_generated:
        return predicted_video_256_path, predicted_video_512_path
    return predicted_video_256_path, predicted_video_256_path
277
+
278
def generate_video(uploaded_img, uploaded_audio, infer_type,
                   pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed):
    """Gradio callback: run the generation pipeline for the uploaded inputs.

    Args:
        uploaded_img: File path of the reference image, or None.
        uploaded_audio: File path of the driving audio, or None.
        infer_type: Inference type; selects the stage-2 checkpoint.
        pose_yaw, pose_pitch, pose_roll: Constant head-pose controls.
        face_location, face_scale: Constant face placement controls.
        step_T: Number of denoising steps passed to the pipeline.
        device: Device string handed to the pipeline (e.g. "cuda" or "cpu").
        face_sr: Whether to request the super-resolved output.
        seed: Random seed forwarded to the model configuration.

    Returns:
        Always a 3-tuple matching the three output components:
        ``(video_256, video_512, message)``. Entries that are not available
        are None.
    """
    if uploaded_img is None or uploaded_audio is None:
        # Three values are required: the click handler is wired to three outputs.
        return None, None, gr.Markdown("Error: Input image or audio file is empty. Please check and upload both files.")

    model_mapping = {
        "mfcc_pose_only": "./ckpts/stage2_pose_only_mfcc.ckpt",
        "mfcc_full_control": "./ckpts/stage2_more_controllable_mfcc.ckpt",
        "hubert_audio_only": "./ckpts/stage2_audio_only_hubert.ckpt",
        "hubert_pose_only": "./ckpts/stage2_pose_only_hubert.ckpt",
        "hubert_full_control": "./ckpts/stage2_full_control_hubert.ckpt",
    }

    stage2_checkpoint_path = model_mapping.get(infer_type)
    if stage2_checkpoint_path is None:
        # Fail with a clear message instead of falling back to a checkpoint
        # path that does not exist.
        return None, None, gr.Markdown(f"Error: Unsupported inference type - {infer_type}")

    try:
        args = argparse.Namespace(
            infer_type=infer_type,
            test_image_path=uploaded_img,
            test_audio_path=uploaded_audio,
            test_hubert_path='',
            result_path='./outputs/',
            stage1_checkpoint_path='./ckpts/stage1.ckpt',
            stage2_checkpoint_path=stage2_checkpoint_path,
            # UI number/slider components may deliver floats; both values are
            # integral quantities.
            seed=int(seed),
            control_flag=True,
            pose_yaw=pose_yaw,
            pose_pitch=pose_pitch,
            pose_roll=pose_roll,
            face_location=face_location,
            pose_driven_path='not_supported_in_this_mode',
            face_scale=face_scale,
            step_T=int(step_T),
            image_size=256,
            device=device,
            motion_dim=20,
            decoder_layers=2,
            face_sr=face_sr,
        )

        # Run the main function
        output_256_video_path, output_512_video_path = main(args)

        # Check if the output video file exists
        if not os.path.exists(output_256_video_path):
            return None, None, gr.Markdown("Error: Video generation failed. Please check your inputs and try again.")
        if output_256_video_path == output_512_video_path:
            return gr.Video(value=output_256_video_path), None, gr.Markdown("Video (256*256 only) generated successfully!")
        return gr.Video(value=output_256_video_path), gr.Video(value=output_512_video_path), gr.Markdown("Video generated successfully!")

    except (Exception, SystemExit) as e:
        # SystemExit is caught on purpose: the pipeline may signal invalid
        # input via exit(), which must not tear down the serving worker.
        return None, None, gr.Markdown(f"Error: An unexpected error occurred - {str(e)}")
340
+
341
# Initial values for the configuration widgets of the UI.
default_values = {
    "pose_yaw": 0,
    "pose_pitch": 0,
    "pose_roll": 0,
    "face_location": 0.5,
    "face_scale": 0.5,
    "step_T": 50,
    "seed": 0,
    # Pre-select the GPU only when one is usable, so a CPU-only host does not
    # start with a default that is guaranteed to fail.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
351
+
352
# Gradio UI: inputs on the left, generated videos and status on the right,
# followed by the generate button and the configuration panel.
# NOTE(review): the nesting below is reconstructed; confirm the layout
# (in particular where output_message sits) against the running app.
with gr.Blocks() as demo:
    gr.Markdown('# AniTalker')
    gr.Markdown('![]()')
    with gr.Row():
        with gr.Column():
            uploaded_img = gr.Image(type="filepath", label="Reference Image")
            uploaded_audio = gr.Audio(type="filepath", label="Input Audio")
        with gr.Column():
            output_video_256 = gr.Video(label="Generated Video (256)")
            output_video_512 = gr.Video(label="Generated Video (512)")
            output_message = gr.Markdown()

    generate_button = gr.Button("Generate Video")

    with gr.Accordion("Configuration", open=True):
        # NOTE(review): 'hubert_full_control' is handled by the backend but is
        # not offered here -- confirm whether that omission is intentional.
        infer_type = gr.Dropdown(
            label="Inference Type",
            choices=['mfcc_pose_only', 'mfcc_full_control', 'hubert_audio_only', 'hubert_pose_only'],
            value='hubert_audio_only'
        )
        face_sr = gr.Checkbox(label="Enable Face Super-Resolution (512*512)", value=False)
        # precision=0 makes the component deliver an integer seed.
        seed = gr.Number(label="Seed", value=default_values["seed"], precision=0)
        pose_yaw = gr.Slider(label="pose_yaw", minimum=-1, maximum=1, value=default_values["pose_yaw"])
        pose_pitch = gr.Slider(label="pose_pitch", minimum=-1, maximum=1, value=default_values["pose_pitch"])
        pose_roll = gr.Slider(label="pose_roll", minimum=-1, maximum=1, value=default_values["pose_roll"])
        face_location = gr.Slider(label="face_location", minimum=0, maximum=1, value=default_values["face_location"])
        face_scale = gr.Slider(label="face_scale", minimum=0, maximum=1, value=default_values["face_scale"])
        step_T = gr.Slider(label="step_T", minimum=1, maximum=100, step=1, value=default_values["step_T"])
        device = gr.Radio(label="Device", choices=["cuda", "cpu"], value=default_values["device"])

    # The input order must match the parameter order of generate_video.
    generate_button.click(
        generate_video,
        inputs=[
            uploaded_img, uploaded_audio, infer_type,
            pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed
        ],
        outputs=[output_video_256, output_video_512, output_message]
    )
395
 
396
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='AniTalker')
    # Both options default to None so that, when they are not given, Gradio
    # falls back to its own defaults -- the behavior this script had while the
    # parsed values were being ignored.
    parser.add_argument('--server_name', type=str, default=None,
                        help='Server name (default: chosen by Gradio)')
    parser.add_argument('--server_port', type=int, default=None,
                        help='Server port (default: chosen by Gradio)')
    args = parser.parse_args()

    # Forward the parsed options; previously they were parsed but never used.
    demo.launch(server_name=args.server_name, server_port=args.server_port)