insanecoder69 committed on
Commit c6df9b7 · verified · 1 Parent(s): 07b8206

Update app.py

Files changed (1):
  1. app.py +294 -51
app.py CHANGED
@@ -1,55 +1,298 @@
  import gradio as gr
- import subprocess
  import os

- def run_talkshow_model(audio_file):
-     # Path to the TalkSHOW demo script
-     demo_script = 'scripts/demo.py'
-
-     # Configuration and model parameters
-     config_file = './config/LS3DCG.json'
-     body_model_name = 's2g_LS3DCG'
-     body_model_path = 'experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth'
-
-     # Path of the uploaded audio file
-     audio_file_path = audio_file
-
-     # Path where the output .mp4 video will be saved
-     output_video_path = './output_video/result.mp4'
-
-     # Run the demo.py script with the necessary arguments
-     command = [
-         'python', demo_script,
-         '--config_file', config_file,
-         '--infer',
-         '--audio_file', audio_file_path,
-         '--body_model_name', body_model_name,
-         '--body_model_path', body_model_path,
-         '--id', '0',
-         '--output', output_video_path  # Assuming demo.py has an argument to specify output
-     ]
-
-     try:
-         # Run the subprocess and capture output
-         subprocess.run(command, check=True, capture_output=True, text=True)
-
-         # Check if the .mp4 file is generated
-         if os.path.exists(output_video_path):
-             return output_video_path  # Return the path of the generated video
          else:
-             return "Error: Output video not generated."
-
-     except subprocess.CalledProcessError as e:
-         return f"Error running the model: {e.stderr}"  # Return the error message
-
- # Set up the Gradio interface
- interface = gr.Interface(
-     fn=run_talkshow_model,
-     inputs=gr.Audio(source="upload", type="filepath"),
-     outputs=gr.Video(),  # Use gr.Video to output the generated .mp4 video
-     title="TalkSHOW: Audio to Mesh"
- )
-
- # Launch the interface
- if __name__ == "__main__":
-     interface.launch()
  import gradio as gr
  import os
+ import sys
+ # import OpenGL.GL as gl
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
+ os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
+ os.system('pip install /home/user/app/pyrender')
+ sys.path.append('/home/user/app/pyrender')
+ # os.system(r"apt-get install -y python-opengl libosmesa6")
+ sys.path.append(os.getcwd())
+ # os.system(r"cd mesh-master")
+ # os.system(r"tar -jxvf boost_1_79_0.tar.bz2")
+ # os.system(r"mv boost_1_79_0 boost")
+ # os.system(r"CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/home/user/app/boost")
+ # os.system(r"export LIBRARY_PATH=$LIBRARY_PATH:/home/user/app/boost/stage/lib")
+ # os.system(r"apt-get update")
+ # os.system(r"apt-get install sudo")
+ #
+ # os.system(r"apt-get install libboost-dev")
+ # # os.system(r"sudo apt-get install gcc")
+ # # os.system(r"sudo apt-get install g++")
+ # os.system(r"make -C ./mesh-master all")
+ # os.system(r"cd ..")
+ # os.system("pip install --no-deps --verbose --no-cache-dir /home/user/app/mesh-fix-MSVC_compilation")

+ from transformers import Wav2Vec2Processor
+
+ import numpy as np
+ import json
+ import smplx as smpl
+
+ from nets import *
+ from trainer.options import parse_args
+ from data_utils import torch_data
+ from trainer.config import load_JsonConfig
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils import data
+ from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
+ from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
+ from visualise.rendering import RenderTool
+
+ global device
+ is_cuda = torch.cuda.is_available()
+ device = torch.device("cuda" if is_cuda else "cpu")
+
+
+ def init_model(model_name, model_path, args, config):
+     if model_name == 's2g_face':
+         generator = s2g_face(
+             args,
+             config,
+         )
+     elif model_name == 's2g_body_vq':
+         generator = s2g_body_vq(
+             args,
+             config,
+         )
+     elif model_name == 's2g_body_pixel':
+         generator = s2g_body_pixel(
+             args,
+             config,
+         )
+     elif model_name == 's2g_LS3DCG':
+         generator = LS3DCG(
+             args,
+             config,
+         )
+     else:
+         raise NotImplementedError
+
+     model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+     if model_name == 'smplx_S2G':
+         generator.generator.load_state_dict(model_ckpt['generator']['generator'])
+
+     elif 'generator' in list(model_ckpt.keys()):
+         generator.load_state_dict(model_ckpt['generator'])
+     else:
+         model_ckpt = {'generator': model_ckpt}
+         generator.load_state_dict(model_ckpt)
+
+     return generator
+
+
+ def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
+     vertices_list = []
+     poses_list = []
+     expression = torch.zeros([1, 100])
+
+     for i in result_list:
+         vertices = []
+         poses = []
+         for j in range(i.shape[0]):
+             output = smplx_model(betas=betas,
+                                  expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
+                                  jaw_pose=i[j][0:3].unsqueeze_(dim=0),
+                                  leye_pose=i[j][3:6].unsqueeze_(dim=0),
+                                  reye_pose=i[j][6:9].unsqueeze_(dim=0),
+                                  global_orient=i[j][9:12].unsqueeze_(dim=0),
+                                  body_pose=i[j][12:75].unsqueeze_(dim=0),
+                                  left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
+                                  right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
+                                  return_verts=True)
+             vertices.append(output.vertices.detach().cpu().numpy().squeeze())
+             # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
+             pose = output.body_pose
+             poses.append(pose.detach().cpu())
+         vertices = np.asarray(vertices)
+         vertices_list.append(vertices)
+         poses = torch.cat(poses, dim=0)
+         poses_list.append(poses)
+     if require_pose:
+         return vertices_list, poses_list
+     else:
+         return vertices_list, None
+
+
+ global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
+
+ parser = parse_args()
+ args = parser.parse_args()
+ args.gpu = device
+
+ RUN_MODE = "local"
+ if RUN_MODE != "local":
+     os.system("wget -P experiments/2022-10-15-smplx_S2G-face-3d/ "
+               "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth")
+     os.system("wget -P experiments/2022-10-31-smplx_S2G-body-vq-3d/ "
+               "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth")
+     os.system("wget -P experiments/2022-11-02-smplx_S2G-body-pixel-3d/ "
+               "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-11-02-smplx_S2G-body-pixel-3d/ckpt-99.pth")
+     os.system("wget -P visualise/smplx/ "
+               "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/smplx/SMPLX_NEUTRAL.npz")
+
+ config = load_JsonConfig("config/body_pixel.json")
+
+ face_model_name = args.face_model_name
+ face_model_path = args.face_model_path
+ body_model_name = args.body_model_name
+ body_model_path = args.body_model_path
+ smplx_path = './visualise/'
+
+ os.environ['smplx_npz_path'] = config.smplx_npz_path
+ os.environ['extra_joint_path'] = config.extra_joint_path
+ os.environ['j14_regressor_path'] = config.j14_regressor_path
+
+ print('init model...')
+ g_body = init_model(body_model_name, body_model_path, args, config)
+ generator2 = None
+ g_face = init_model(face_model_name, face_model_path, args, config)
+
+ print('init smlpx model...')
+ dtype = torch.float64
+ model_params = dict(model_path=smplx_path,
+                     model_type='smplx',
+                     create_global_orient=True,
+                     create_body_pose=True,
+                     create_betas=True,
+                     num_betas=300,
+                     create_left_hand_pose=True,
+                     create_right_hand_pose=True,
+                     use_pca=False,
+                     flat_hand_mean=False,
+                     create_expression=True,
+                     num_expression_coeffs=100,
+                     num_pca_comps=12,
+                     create_jaw_pose=True,
+                     create_leye_pose=True,
+                     create_reye_pose=True,
+                     create_transl=False,
+                     # gender='ne',
+                     dtype=dtype, )
+ smplx_model = smpl.create(**model_params).to(device)
+ print('init rendertool...')
+ rendertool = RenderTool('visualise/video/' + config.Log.name)
+
+
+ def infer(wav, identity, pose):
+     betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
+     am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
+     am_sr = 16000
+     num_sample = args.num_sample
+     cur_wav_file = wav
+
+     if pose == 'Stand':
+         stand = True
+         face = False
+     elif pose == 'Sit':
+         stand = False
+         face = False
+     else:
+         stand = False
+         face = True
+
+     if face:
+         body_static = torch.zeros([1, 162], device=device)
+         body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
+
+     if identity == 'Oliver':
+         id = 0
+     elif identity == 'Chemistry':
+         id = 1
+     elif identity == 'Seth':
+         id = 2
+     elif identity == 'Conan':
+         id = 3
+
+     result_list = []
+
+     pred_face = g_face.infer_on_audio(cur_wav_file,
+                                       initial_pose=None,
+                                       norm_stats=None,
+                                       w_pre=False,
+                                       # id=id,
+                                       frame=None,
+                                       am=am,
+                                       am_sr=am_sr
+                                       )
+     pred_face = torch.tensor(pred_face).squeeze().to(device)
+     # pred_face = torch.zeros([gt.shape[0], 105])
+
+     if config.Data.pose.convert_to_6d:
+         pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
+         pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
+         pred_face = pred_face[:, 6:]
+     else:
+         pred_jaw = pred_face[:, :3]
+         pred_face = pred_face[:, 3:]
+
+     id = torch.tensor([id], device=device)
+
+     for i in range(num_sample):
+         pred_res = g_body.infer_on_audio(cur_wav_file,
+                                          initial_pose=None,
+                                          norm_stats=None,
+                                          txgfile=None,
+                                          id=id,
+                                          var=None,
+                                          fps=30,
+                                          w_pre=False
+                                          )
+         pred = torch.tensor(pred_res).squeeze().to(device)
+
+         if pred.shape[0] < pred_face.shape[0]:
+             repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
+             pred = torch.cat([pred, repeat_frame], dim=0)
+         else:
+             pred = pred[:pred_face.shape[0], :]
+
+         body_or_face = False
+         if pred.shape[1] < 275:
+             body_or_face = True
+         if config.Data.pose.convert_to_6d:
+             pred = pred.reshape(pred.shape[0], -1, 6)
+             pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
+             pred = pred.reshape(pred.shape[0], -1)
+
+         if config.Model.model_name == 's2g_LS3DCG':
+             pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
          else:
+             pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
+
+         # pred[:, 9:12] = global_orient
+         pred = part2full(pred, stand)
+         if face:
+             pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
+         # result_list[0] = poses2pred(result_list[0], stand)
+         # if gt_0 is None:
+         #     gt_0 = gt
+         # pred = pred2poses(pred, gt_0)
+         # result_list[0] = poses2poses(result_list[0], gt_0)
+
+         result_list.append(pred)
+
+
+     vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
+
+     result_list = [res.to('cpu') for res in result_list]
+     dict = np.concatenate(result_list[:], axis=0)
+
+     rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)
+     return "result.mp4"
+
+ def main():
+
+     iface = gr.Interface(fn=infer, inputs=["audio",
+                                            gr.Radio(["Oliver", "Chemistry", "Seth", "Conan"]),
+                                            gr.Radio(["Stand", "Sit", "Only Face"]),
+                                            ],
+                          outputs="video",
+                          examples=[[os.path.join(os.path.dirname(__file__), "demo_audio/style.wav"), "Oliver", "Sit"]])
+     iface.launch(debug=True)
+
+
+ if __name__ == '__main__':
+     main()
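
For anyone reviewing this change, a minimal local smoke test is sketched below. It is not part of the commit: it only launches the new Gradio entry point and assumes the TalkSHOW checkpoints, config files, and ./visualise assets referenced in the diff are already present in the working directory (with RUN_MODE left at "local", nothing is downloaded automatically).

# Hypothetical smoke test for this commit (not part of app.py):
# start the updated entry point locally and exercise it through the browser UI.
import subprocess

# app.py runs main() -> gr.Interface(...).launch(debug=True), which serves the
# interface (audio upload, "Oliver/Chemistry/Seth/Conan" identity radio,
# "Stand/Sit/Only Face" pose radio) until interrupted.
subprocess.run(["python", "app.py"], check=True)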