vinthony commited on
Commit
cff9535
·
1 Parent(s): 2299694

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -381
app.py CHANGED
@@ -1,387 +1,118 @@
1
- import pickle
2
- import time
3
- import numpy as np
4
- import scipy, cv2, os, sys, argparse
5
- from tqdm import tqdm
6
- import torch
7
- import librosa
8
- from networks import define_G
9
- from pcavs.config.AudioConfig import AudioConfig
10
-
11
- sys.path.append('spectre')
12
- from config import cfg as spectre_cfg
13
- from src.spectre import SPECTRE
14
-
15
- from audio2mesh_helper import *
16
- from pcavs.models import create_model, networks
17
-
18
- torch.manual_seed(0)
19
- from scipy.signal import savgol_filter
20
-
21
-
22
- class SimpleWrapperV2(nn.Module):
23
- def __init__(self, cfg, use_ref=True, exp_dim=53, noload=False) -> None:
24
- super().__init__()
25
-
26
- self.audio_encoder = networks.define_A_sync(cfg)
27
-
28
- self.mapping1 = nn.Linear(512+exp_dim, exp_dim)
29
- nn.init.constant_(self.mapping1.weight, 0.)
30
- nn.init.constant_(self.mapping1.bias, 0.)
31
- self.use_ref = use_ref
32
-
33
- def forward(self, x, ref, use_tanh=False):
34
- x = self.audio_encoder.forward_feature(x).view(x.size(0), -1)
35
- ref_reshape = ref.reshape(x.size(0), -1) #20, -1
36
-
37
- y = self.mapping1(torch.cat([x, ref_reshape], dim=1))
38
-
39
- if self.use_ref:
40
- out = y.reshape(ref.shape[0], ref.shape[1], -1) + ref # resudial
41
- else:
42
- out = y.reshape(ref.shape[0], ref.shape[1], -1)
43
-
44
- if use_tanh:
45
- out[:, :50] = torch.tanh(out[:, :50]) * 3
46
-
47
- return out
48
-
49
- class Audio2Mesh(object):
50
- def __init__(self, args) -> None:
51
- self.args = args
52
-
53
- spectre_cfg.model.use_tex = True
54
- spectre_cfg.model.mask_type = args.mask_type
55
- spectre_cfg.debug = self.args.debug
56
- spectre_cfg.model.netA_sync = 'ressesync'
57
- spectre_cfg.model.gpu_ids = [0]
58
-
59
- self.spectre = SPECTRE(spectre_cfg)
60
- self.spectre.eval()
61
- self.face_tracker = None #FaceTrackerV2() # face landmark detection
62
- self.mel_step_size = 16
63
- self.fps = args.fps
64
- self.Nw = args.tframes
65
- self.device = self.args.device
66
- self.image_size = self.args.image_size
67
-
68
- ### only audio
69
- args.netA_sync = 'ressesync'
70
- args.gpu_ids = [0]
71
- args.exp_dim = 53
72
- args.use_tanh = False
73
- args.K = 20
74
-
75
- self.audio2exp = 'pcavs'
76
-
77
- #
78
- self.avmodel = SimpleWrapperV2(args, exp_dim=args.exp_dim).cuda()
79
- self.avmodel.load_state_dict(torch.load('../packages/pretrained/audio2expression_v2_model.tar')['opt'])
80
-
81
- # 5, 160 = 25fps
82
- self.audio = AudioConfig(frame_rate=args.fps, num_frames_per_clip=5, hop_size=160)
83
-
84
- with open(os.path.join(args.source_dir, 'deca_infos.pkl'), 'rb') as f: # ?
85
- self.fitting_coeffs = pickle.load(f, encoding='bytes')
86
-
87
- self.coeffs_dict = { key: torch.Tensor(self.fitting_coeffs[key]).cuda().squeeze(1) for key in ['cam', 'pose', 'light', 'tex', 'shape', 'exp']}
88
-
89
- #### find the close month
90
- exp_tensors = torch.sum(self.coeffs_dict['exp'], dim=1)
91
- ssss, sorted_indices = torch.sort(exp_tensors)
92
- self.exp_id = sorted_indices[0].item()
93
-
94
- if '.ts' in args.render_path:
95
- self.render = torch.jit.load(args.render_path).cuda()
96
- self.trt = True
97
- else:
98
- self.render = define_G(self.Nw*6, 3, args.ngf, args.netR).eval().cuda()
99
- self.render.load_state_dict(torch.load(args.render_path))
100
- self.trt = False
101
-
102
- print('loaded cached images...')
103
-
104
- @torch.no_grad()
105
- def cg2real(self, rendedimages, start_frame=0):
106
-
107
- ## load original image and the mask
108
- self.source_images = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_frame'),\
109
- resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]
110
- self.source_masks = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_mask'),\
111
- resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:]
112
-
113
- self.source_masks = torch.FloatTensor(np.transpose(self.source_masks,(0,3,1,2))/255.)
114
- self.padded_real_tensor = torch.FloatTensor(np.transpose(self.source_images,(0,3,1,2))/255.)
115
-
116
- ## padding the rended_imgs
117
- paded_tensor = torch.cat([rendedimages[0:1]]* (self.Nw // 2) + [rendedimages] + [rendedimages[-1:]]* (self.Nw // 2)).contiguous()
118
- paded_mask_tensor = torch.cat([self.source_masks[0:1]]* (self.Nw // 2) + [self.source_masks] + [self.source_masks[-1:]]* (self.Nw // 2)).contiguous()
119
- paded_real_tensor = torch.cat([self.padded_real_tensor[0:1]]* (self.Nw // 2) + [self.padded_real_tensor] + [self.padded_real_tensor[-1:]]* (self.Nw // 2)).contiguous()
120
-
121
- # paded_mask_tensor = maskErosion(paded_mask_tensor, offY=self.args.mask)
122
- padded_input = ((paded_real_tensor-0.5)*2 ) # *(1-paded_mask_tensor)
123
- padded_input = torch.nn.functional.interpolate(padded_input, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
124
- paded_tensor = torch.nn.functional.interpolate(paded_tensor, (self.image_size, self.image_size), mode='bilinear', align_corners=False)
125
- paded_tensor = (paded_tensor-0.5)*2
126
-
127
- result = []
128
- for index in tqdm(range(0, len(rendedimages), self.args.renderbs), desc='CG2REAL:'):
129
- list_A = []
130
- list_R = []
131
- list_M = []
132
- for i in range(self.args.renderbs):
133
- idx = index + i
134
- if idx+self.Nw > len(padded_input):
135
- list_A.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0))
136
- list_R.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0))
137
- list_M.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0))
138
- else:
139
- list_A.append(padded_input[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
140
- list_R.append(paded_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
141
- list_M.append(paded_mask_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0))
142
-
143
- list_A = torch.cat(list_A)
144
- list_R = torch.cat(list_R)
145
- list_M = torch.cat(list_M)
146
-
147
- idx = (self.Nw//2) * 3
148
- mask = list_M[:, idx:idx+3]
149
-
150
- # list_A = padded_input
151
- mask = maskErosion(mask, offY=self.args.mask)
152
- list_A = list_A * (1 - mask[:,0:1])
153
- A = torch.cat([list_A, list_R], 1)
154
-
155
- if self.trt:
156
- B = self.render(A.half().cuda())
157
- elif self.args.netR == 'unet_256':
158
- # import pdb; pdb.set_trace()
159
- idx = (self.Nw//2) * 3
160
- mask = list_M[:, idx:idx+3].cuda()
161
- mask = maskErosion(mask, offY=self.args.mask)
162
- B0 = list_A[:, idx:idx+3].cuda()
163
- B = self.render(A.cuda()) * mask[:,0:1] + (1 - mask[:,0:1]) * B0
164
- elif self.args.netR == 's2am':
165
- # import pdb; pdb.set_trace()
166
- idx = (self.Nw//2) * 3
167
- mask = list_M[:, idx:idx+3].cuda()
168
- mask = maskErosion(mask, offY=self.args.mask)
169
- B0 = list_A[:, idx:idx+3].cuda()
170
- B = self.render(A.cuda(), mask[:,0:1] ) * mask[:,0:1] + (1 - mask[:,0:1]) * B0
171
- else:
172
- B = self.render(A.cuda())
173
-
174
- result.append((B.cpu() + 1) * 0.5) # -1,1 -> 0,1
175
-
176
- return torch.cat(result)[:len(rendedimages)]
177
-
178
- @torch.no_grad()
179
- def coeffs_to_img(self, vertices, coeffs, zero_pose=False, XK = 20):
180
-
181
- xlen = vertices.shape[0]
182
- all_shape_images = []
183
- landmark2d = []
184
-
185
- #### find the most larger pose 51 in the coeffs.
186
- max_pose_51 = torch.max(self.coeffs_dict['pose'][..., 3:4].squeeze(-1))
187
-
188
- for i in tqdm(range(0, xlen, XK)):
189
-
190
- if i + XK > xlen:
191
- XK = xlen - i
192
-
193
- codedictdecoder = {}
194
- codedictdecoder['shape'] = torch.zeros_like(self.coeffs_dict['shape'][i:i+XK].cuda())
195
- codedictdecoder['tex'] = self.coeffs_dict['tex'][i:i+XK].cuda()
196
- codedictdecoder['exp'] = torch.zeros_like(self.coeffs_dict['exp'][i:i+XK].cuda()) # all_exps[i:i+XK, :50].cuda() # # # vid_exps[i:i+1].cuda() i:i+XK
197
- codedictdecoder['pose'] = self.coeffs_dict['pose'][i:i+XK] # vid_poses[i:i+1].cuda()
198
- codedictdecoder['cam'] = self.coeffs_dict['cam'][i:i+XK].cuda() # vid_poses[i:i+1].cuda()
199
- codedictdecoder['light'] = self.coeffs_dict['light'][i:i+XK].cuda() # vid_poses[i:i+1].cuda()
200
- codedictdecoder['images'] = torch.zeros((XK,3,256,256)).cuda()
201
-
202
- codedictdecoder['pose'][..., 3:4] = torch.clip(coeffs[i:i+XK, 50:51], 0, max_pose_51*0.9) # torch.zeros_like(self.coeffs_dict['pose'][i:i+XK, 3:])
203
- codedictdecoder['pose'][..., 4:6] = 0 # coeffs[i:i+XK, 50:]*( - 0.25) # torch.zeros_like(self.coeffs_dict['pose'][i:i+XK, 3:])
204
-
205
- sub_vertices = vertices[i:i+XK].cuda()
206
-
207
- opdict = self.spectre.decode_verts(codedictdecoder, sub_vertices, rendering=True, vis_lmk=False, return_vis=False)
208
-
209
- landmark2d.append(opdict['landmarks2d'].cpu())
210
-
211
- all_shape_images.append(opdict['rendered_images'].cpu())
212
-
213
- rendedimages = torch.cat(all_shape_images)
214
-
215
- lmk2d = torch.cat(landmark2d)
216
-
217
- return rendedimages, lmk2d
218
-
219
-
220
- @torch.no_grad()
221
- def run_spectre_v3(self, wav=None, ds_features=None, L=20):
222
-
223
- wav = audio_normalize(wav)
224
- all_mel = self.audio.melspectrogram(wav).astype(np.float32).T
225
- frames_from_audio = np.arange(2, len(all_mel) // self.audio.num_bins_per_frame - 2) # 2,[]mmmmmmmmmmmmmmmmmmmmmmmmmmmm
226
- audio_inds = frame2audio_indexs(frames_from_audio, self.audio.num_frames_per_clip, self.audio.num_bins_per_frame)
227
-
228
- vid_exps = self.coeffs_dict['exp'][self.exp_id:self.exp_id+1]
229
- vid_poses = self.coeffs_dict['pose'][self.exp_id:self.exp_id+1]
230
 
231
- ref = torch.cat([vid_exps.view(1, 50), vid_poses[:, 3:].view(1, 3)], dim=-1)
232
- ref = ref[...,:self.args.exp_dim]
233
-
234
- K = 20
235
- xlens = len(audio_inds) # len(self.coeffs_dict['exp'])
236
-
237
- exps = []
238
- for i in tqdm(range(0, xlens, K), desc='S2 DECODER:'+ str(xlens) + ' '):
239
-
240
- mels = []
241
- for j in range(K):
242
- if i + j < xlens:
243
- idx = i+j # //3 * 3
244
- mel = load_spectrogram(all_mel, audio_inds[idx], self.audio.num_frames_per_clip * self.audio.num_bins_per_frame).cuda()
245
- mel = mel.view(-1, 1, 80, self.audio.num_frames_per_clip * self.audio.num_bins_per_frame)
246
- mels.append(mel)
247
- else:
248
- break
249
-
250
- mels = torch.cat(mels, dim=0)
251
- new_exp = self.avmodel(mels, ref.repeat(mels.shape[0], 1, 1).cuda(), self.args.use_tanh) # exp 53
252
- exps+= [new_exp.view(-1, 53)]
253
-
254
- all_exps = torch.cat(exps,axis=0)
255
-
256
- return all_exps
257
-
258
- @torch.no_grad()
259
- def test_model(self, wav_path):
260
-
261
- sys.path.append('../FaceFormer')
262
- from faceformer import Faceformer
263
- from transformers import Wav2Vec2FeatureExtractor,Wav2Vec2Processor
264
- from faceformer import PeriodicPositionalEncoding, init_biased_mask
265
-
266
- #build model
267
- self.args.train_subjects = " ".join(["A"]*8) # suitable for pre-trained faceformer checkpoint
268
- model = Faceformer(self.args)
269
- model.load_state_dict(torch.load('/apdcephfs/private_shadowcun/Avatar2dFF/medias/videos/c8/mask5000_l2/6_model.pth')) # ../packages/pretrained/28_ff_model.pth
270
- model = model.to(torch.device(self.device))
271
- model.eval()
272
-
273
- # hacking for long audio generation
274
- model.PPE = PeriodicPositionalEncoding(self.args.feature_dim, period = self.args.period, max_seq_len=6000).cuda()
275
- model.biased_mask = init_biased_mask(n_head = 4, max_seq_len = 6000, period=self.args.period).cuda()
276
-
277
- train_subjects_list = ["A"] * 8
278
-
279
- one_hot_labels = np.eye(len(train_subjects_list))
280
- one_hot = one_hot_labels[0]
281
- one_hot = np.reshape(one_hot,(-1,one_hot.shape[0]))
282
- one_hot = torch.FloatTensor(one_hot).to(device=self.device)
283
-
284
- vertices_npy = np.load(self.args.source_dir + '/mesh_pose0.npy')
285
- vertices_npy = np.array(vertices_npy).reshape(-1, 5023*3)
286
-
287
- temp = vertices_npy[33] # 829
288
-
289
- template = temp.reshape((-1))
290
- template = np.reshape(template,(-1,template.shape[0]))
291
- template = torch.FloatTensor(template).to(device=self.device)
292
-
293
- speech_array, sampling_rate = librosa.load(os.path.join(wav_path), sr=16000)
294
- processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
295
- audio_feature = np.squeeze(processor(speech_array,sampling_rate=16000).input_values)
296
- audio_feature = np.reshape(audio_feature,(-1,audio_feature.shape[0]))
297
- audio_feature = torch.FloatTensor(audio_feature).to(device=self.device)
 
 
 
 
 
 
 
 
 
 
298
 
299
- prediction = model.predict(audio_feature, template, one_hot, 1.0) # (1, seq_len, V*3)
300
-
301
- return prediction.squeeze()
302
-
303
- @torch.no_grad()
304
- def run(self, face, audio, start_frame=0):
305
-
306
- wav, sr = librosa.load(audio, sr=16000) # 16*80 ? 20*80
307
- wav_tensor = torch.FloatTensor(wav).unsqueeze(0) if len(wav.shape) == 1 else torch.FloatTensor(wav)
308
- _, frames = parse_audio_length(wav_tensor.shape[1], 16000, self.args.fps)
309
-
310
- ##### audio-guided, only use the jaw movement
311
- all_exps = self.run_spectre_v3(wav)
312
-
313
- # #### temp. interpolation
314
- all_exps = torch.nn.functional.interpolate(all_exps.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear')
315
- all_exps = all_exps.permute([0,2,1]).squeeze(0)
316
-
317
- # run faceformer for face mesh generation
318
- predicted_vertices = self.test_model(audio)
319
- predicted_vertices = predicted_vertices.view(-1, 5023*3)
320
-
321
- #### temp. interpolation
322
- predicted_vertices = torch.nn.functional.interpolate(predicted_vertices.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear')
323
- predicted_vertices = predicted_vertices.permute([0,2,1]).squeeze(0).view(-1, 5023, 3)
324
-
325
- all_exps = torch.Tensor(savgol_filter(all_exps.cpu().numpy(), 5, 3, axis=0)).cpu() # smooth GT
326
-
327
- rendedimages, lm2d = self.coeffs_to_img(predicted_vertices, all_exps, zero_pose=True)
328
- debug_video_gen(rendedimages, self.args.result_dir+"/debug_before_ff.mp4", wav_tensor, self.args.fps, sr)
329
-
330
- # cg2real
331
- debug_video_gen(self.cg2real(rendedimages, start_frame=start_frame), self.args.result_dir+"/debug_cg2real_raw.mp4", wav_tensor, self.args.fps, sr)
332
-
333
- exit()
334
-
335
-
336
-
337
- if __name__ == '__main__':
338
- parser = argparse.ArgumentParser(description='Stylization and Seamless Video Dubbing')
339
- parser.add_argument('--face', default='examples', type=str, help='')
340
- parser.add_argument('--audio', default='examples', type=str, help='')
341
- parser.add_argument('--source_dir', default='examples', type=str,help='TODO')
342
- parser.add_argument('--result_dir', default='examples', type=str,help='TODO')
343
- parser.add_argument('--backend', default='wav2lip', type=str,help='wav2lip or pcavs')
344
- parser.add_argument('--result_tag', default='result', type=str,help='TODO')
345
- parser.add_argument('--netR', default='unet_256', type=str,help='TODO')
346
- parser.add_argument('--render_path', default='', type=str,help='TODO')
347
- parser.add_argument('--ngf', default=16, type=int,help='TODO')
348
- parser.add_argument('--fps', default=20, type=int,help='TODO')
349
- parser.add_argument('--mask', default=100, type=int,help='TODO')
350
- parser.add_argument('--mask_type', default='v3', type=str,help='TODO')
351
- parser.add_argument('--image_size', default=256, type=int,help='TODO')
352
- parser.add_argument('--input_nc', default=21, type=int,help='TODO')
353
- parser.add_argument('--output_nc', default=3, type=int,help='TODO')
354
- parser.add_argument('--renderbs', default=16, type=int,help='TODO')
355
- parser.add_argument('--tframes', default=1, type=int,help='TODO')
356
- parser.add_argument('--debug', action='store_true')
357
- parser.add_argument('--enhance', action='store_true')
358
- parser.add_argument('--phone', action='store_true')
359
-
360
- #### faceformer
361
- parser.add_argument("--model_name", type=str, default="VOCA")
362
- parser.add_argument("--dataset", type=str, default="vocaset", help='vocaset or BIWI')
363
- parser.add_argument("--feature_dim", type=int, default=64, help='64 for vocaset; 128 for BIWI')
364
- parser.add_argument("--period", type=int, default=30, help='period in PPE - 30 for vocaset; 25 for BIWI')
365
- parser.add_argument("--vertice_dim", type=int, default=5023*3, help='number of vertices - 5023*3 for vocaset; 23370*3 for BIWI')
366
- parser.add_argument("--device", type=str, default="cuda")
367
- parser.add_argument("--train_subjects", type=str, default="FaceTalk_170728_03272_TA ")
368
- parser.add_argument("--test_subjects", type=str, default="FaceTalk_170809_00138_TA FaceTalk_170731_00024_TA")
369
- parser.add_argument("--condition", type=str, default="FaceTalk_170904_00128_TA", help='select a conditioning subject from train_subjects')
370
- parser.add_argument("--subject", type=str, default="FaceTalk_170731_00024_TA", help='select a subject from test_subjects or train_subjects')
371
- parser.add_argument("--background_black", type=bool, default=True, help='whether to use black background')
372
- parser.add_argument("--template_path", type=str, default="templates.pkl", help='path of the personalized templates')
373
- parser.add_argument("--render_template_path", type=str, default="templates", help='path of the mesh in BIWI/FLAME topology')
374
 
375
- opt = parser.parse_args()
376
 
377
- opt.img_size = 96
378
- opt.static = True
379
- opt.device = torch.device("cuda")
 
 
380
 
381
- a2m = Audio2Mesh(opt)
382
 
383
- print('link start!')
384
- t = time.time()
385
- # 02780
386
- a2m.run(opt.face, opt.audio, 0)
387
- print(time.time() - t)
 
1
+ import os, sys
2
+ import tempfile
3
+ import gradio as gr
4
+ from modules.text2speech import text2speech
5
+ from modules.gfpgan_inference import gfpgan
6
+ from modules.sadtalker_test import SadTalker
7
+
8
+ def get_driven_audio(audio):
9
+ if os.path.isfile(audio):
10
+ return audio
11
+ else:
12
+ save_path = tempfile.NamedTemporaryFile(
13
+ delete=False,
14
+ suffix=("." + "wav"),
15
+ )
16
+ gen_audio = text2speech(audio, save_path.name)
17
+ return gen_audio, gen_audio
18
+
19
+ def get_source_image(image):
20
+ return image
21
+
22
+ def sadtalker_demo(result_dir):
23
+
24
+ sad_talker = SadTalker()
25
+ with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
26
+ gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \
27
+ <a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
28
+ <a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
29
+ <a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </div>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ with gr.Row().style(equal_height=False):
32
+ with gr.Column(variant='panel'):
33
+ with gr.Tabs(elem_id="sadtalker_source_image"):
34
+ with gr.TabItem('Upload image'):
35
+ with gr.Row():
36
+ source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)
37
+
38
+ with gr.Tabs(elem_id="sadtalker_driven_audio"):
39
+ with gr.TabItem('Upload audio'):
40
+ with gr.Column(variant='panel'):
41
+ driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
42
+ # submit_audio_1 = gr.Button('Submit', variant='primary')
43
+ # submit_audio_1.click(fn=get_driven_audio, inputs=input_audio1, outputs=driven_audio)
44
+
45
+
46
+ with gr.Column(variant='panel'):
47
+ with gr.Tabs(elem_id="sadtalker_checkbox"):
48
+ with gr.TabItem('Settings'):
49
+ with gr.Column(variant='panel'):
50
+ is_still_mode = gr.Checkbox(label="w/ Still Mode (fewer hand motion)")
51
+ enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer")
52
+ submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
53
+
54
+ with gr.Tabs(elem_id="sadtalker_genearted"):
55
+ gen_video = gr.Video(label="Generated video", format="mp4").style(height=256,width=256)
56
+ gen_text = gr.Textbox(visible=False)
57
+
58
+
59
+ with gr.Row():
60
+ examples = [
61
+ [
62
+ 'examples/source_image/art_10.png',
63
+ 'examples/driven_audio/deyu.wav',
64
+ True,
65
+ False
66
+ ],
67
+ [
68
+ 'examples/source_image/art_1.png',
69
+ 'examples/driven_audio/chinese_poem1.wav',
70
+ True,
71
+ False
72
+ ],
73
+ [
74
+ 'examples/source_image/art_13.png',
75
+ 'examples/driven_audio/fayu.wav',
76
+ True,
77
+ False
78
+ ],
79
+ [
80
+ 'examples/source_image/art_5.png',
81
+ 'examples/driven_audio/chinese_news.wav',
82
+ True,
83
+ False
84
+ ],
85
+ ]
86
+ gr.Examples(examples=examples,
87
+ inputs=[
88
+ source_image,
89
+ driven_audio,
90
+ is_still_mode,
91
+ enhancer,
92
+ gr.Textbox(value=result_dir, visible=False)],
93
+ outputs=[gen_video, gen_text],
94
+ fn=sad_talker.test,
95
+ cache_examples=os.getenv('SYSTEM') == 'spaces')
96
+
97
+ submit.click(
98
+ fn=sad_talker.test,
99
+ inputs=[source_image,
100
+ driven_audio,
101
+ is_still_mode,
102
+ enhancer,
103
+ gr.Textbox(value=result_dir, visible=False)],
104
+ outputs=[gen_video, gen_text]
105
+ )
106
+
107
+ return sadtalker_interface
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ if __name__ == "__main__":
111
 
112
+ current_code_path = sys.argv[0]
113
+ current_root_dir = os.path.split(current_code_path)[0]
114
+ sadtalker_result_dir = os.path.join(current_root_dir, 'results', 'sadtalker')
115
+ demo = sadtalker_demo(sadtalker_result_dir)
116
+ demo.launch()
117
 
 
118