insanecoder69 committed
Commit b506075 · verified · 1 Parent(s): c6df9b7

Update app.py

Files changed (1)
  1. app.py +51 -272
app.py CHANGED
@@ -1,298 +1,77 @@
  import gradio as gr
  import os
- import sys
- # import OpenGL.GL as gl
- os.environ["PYOPENGL_PLATFORM"] = "egl"
- os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
- os.system('pip install /home/user/app/pyrender')
- sys.path.append('/home/user/app/pyrender')
- # os.system(r"apt-get install -y python-opengl libosmesa6")
- sys.path.append(os.getcwd())
- # os.system(r"cd mesh-master")
- # os.system(r"tar -jxvf boost_1_79_0.tar.bz2")
- # os.system(r"mv boost_1_79_0 boost")
- # os.system(r"CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/home/user/app/boost")
- # os.system(r"export LIBRARY_PATH=$LIBRARY_PATH:/home/user/app/boost/stage/lib")
- # os.system(r"apt-get update")
- # os.system(r"apt-get install sudo")
- #
- # os.system(r"apt-get install libboost-dev")
- # # os.system(r"sudo apt-get install gcc")
- # # os.system(r"sudo apt-get install g++")
- # os.system(r"make -C ./mesh-master all")
- # os.system(r"cd ..")
- # os.system("pip install --no-deps --verbose --no-cache-dir /home/user/app/mesh-fix-MSVC_compilation")
-
- from transformers import Wav2Vec2Processor
-
- import numpy as np
- import json
- import smplx as smpl
-
- from nets import *
- from trainer.options import parse_args
- from data_utils import torch_data
- from trainer.config import load_JsonConfig
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.utils import data
- from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
- from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
- from visualise.rendering import RenderTool
-
- global device
- is_cuda = torch.cuda.is_available()
- device = torch.device("cuda" if is_cuda else "cpu")
-
-
- def init_model(model_name, model_path, args, config):
-     if model_name == 's2g_face':
-         generator = s2g_face(
-             args,
-             config,
-         )
-     elif model_name == 's2g_body_vq':
-         generator = s2g_body_vq(
-             args,
-             config,
-         )
-     elif model_name == 's2g_body_pixel':
-         generator = s2g_body_pixel(
-             args,
-             config,
-         )
-     elif model_name == 's2g_LS3DCG':
-         generator = LS3DCG(
-             args,
-             config,
-         )
-     else:
-         raise NotImplementedError
-
-     model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
-     if model_name == 'smplx_S2G':
-         generator.generator.load_state_dict(model_ckpt['generator']['generator'])
-     elif 'generator' in list(model_ckpt.keys()):
-         generator.load_state_dict(model_ckpt['generator'])
-     else:
-         model_ckpt = {'generator': model_ckpt}
-         generator.load_state_dict(model_ckpt)
-
-     return generator
-
-
- def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
-     vertices_list = []
-     poses_list = []
-     expression = torch.zeros([1, 100])
-
-     for i in result_list:
-         vertices = []
-         poses = []
-         for j in range(i.shape[0]):
-             output = smplx_model(betas=betas,
-                                  expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
-                                  jaw_pose=i[j][0:3].unsqueeze_(dim=0),
-                                  leye_pose=i[j][3:6].unsqueeze_(dim=0),
-                                  reye_pose=i[j][6:9].unsqueeze_(dim=0),
-                                  global_orient=i[j][9:12].unsqueeze_(dim=0),
-                                  body_pose=i[j][12:75].unsqueeze_(dim=0),
-                                  left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
-                                  right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
-                                  return_verts=True)
-             vertices.append(output.vertices.detach().cpu().numpy().squeeze())
-             # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
-             pose = output.body_pose
-             poses.append(pose.detach().cpu())
-         vertices = np.asarray(vertices)
-         vertices_list.append(vertices)
-         poses = torch.cat(poses, dim=0)
-         poses_list.append(poses)
-     if require_pose:
-         return vertices_list, poses_list
-     else:
-         return vertices_list, None
-
-
- global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
-
- parser = parse_args()
- args = parser.parse_args()
- args.gpu = device
-
- RUN_MODE = "local"
- if RUN_MODE != "local":
-     os.system("wget -P experiments/2022-10-15-smplx_S2G-face-3d/ "
-               "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-10-15-smplx_S2G-face-3d/ckpt-99.pth")
-     os.system("wget -P experiments/2022-10-31-smplx_S2G-body-vq-3d/ "
-               "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth")
-     os.system("wget -P experiments/2022-11-02-smplx_S2G-body-pixel-3d/ "
-               "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/2022-11-02-smplx_S2G-body-pixel-3d/ckpt-99.pth")
-     os.system("wget -P visualise/smplx/ "
-               "https://huggingface.co/feifeifeiliu/TalkSHOW/resolve/main/smplx/SMPLX_NEUTRAL.npz")
-
- config = load_JsonConfig("config/body_pixel.json")
-
- face_model_name = args.face_model_name
- face_model_path = args.face_model_path
- body_model_name = args.body_model_name
- body_model_path = args.body_model_path
- smplx_path = './visualise/'
-
- os.environ['smplx_npz_path'] = config.smplx_npz_path
- os.environ['extra_joint_path'] = config.extra_joint_path
- os.environ['j14_regressor_path'] = config.j14_regressor_path
-
- print('init model...')
- g_body = init_model(body_model_name, body_model_path, args, config)
- generator2 = None
- g_face = init_model(face_model_name, face_model_path, args, config)
-
- print('init smlpx model...')
- dtype = torch.float64
- model_params = dict(model_path=smplx_path,
-                     model_type='smplx',
-                     create_global_orient=True,
-                     create_body_pose=True,
-                     create_betas=True,
-                     num_betas=300,
-                     create_left_hand_pose=True,
-                     create_right_hand_pose=True,
-                     use_pca=False,
-                     flat_hand_mean=False,
-                     create_expression=True,
-                     num_expression_coeffs=100,
-                     num_pca_comps=12,
-                     create_jaw_pose=True,
-                     create_leye_pose=True,
-                     create_reye_pose=True,
-                     create_transl=False,
-                     # gender='ne',
-                     dtype=dtype, )
- smplx_model = smpl.create(**model_params).to(device)
- print('init rendertool...')
- rendertool = RenderTool('visualise/video/' + config.Log.name)
-
-
- def infer(wav, identity, pose):
-     betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
-     am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
-     am_sr = 16000
-     num_sample = args.num_sample
-     cur_wav_file = wav
-
-     if pose == 'Stand':
-         stand = True
-         face = False
-     elif pose == 'Sit':
-         stand = False
-         face = False
-     else:
-         stand = False
-         face = True
-
-     if face:
-         body_static = torch.zeros([1, 162], device=device)
-         body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
-
-     if identity == 'Oliver':
-         id = 0
-     elif identity == 'Chemistry':
-         id = 1
-     elif identity == 'Seth':
-         id = 2
-     elif identity == 'Conan':
-         id = 3
-
-     result_list = []
-
-     pred_face = g_face.infer_on_audio(cur_wav_file,
-                                       initial_pose=None,
-                                       norm_stats=None,
-                                       w_pre=False,
-                                       # id=id,
-                                       frame=None,
-                                       am=am,
-                                       am_sr=am_sr
-                                       )
-     pred_face = torch.tensor(pred_face).squeeze().to(device)
-     # pred_face = torch.zeros([gt.shape[0], 105])
-
-     if config.Data.pose.convert_to_6d:
-         pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
-         pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
-         pred_face = pred_face[:, 6:]
-     else:
-         pred_jaw = pred_face[:, :3]
-         pred_face = pred_face[:, 3:]
-
-     id = torch.tensor([id], device=device)
-
-     for i in range(num_sample):
-         pred_res = g_body.infer_on_audio(cur_wav_file,
-                                          initial_pose=None,
-                                          norm_stats=None,
-                                          txgfile=None,
-                                          id=id,
-                                          var=None,
-                                          fps=30,
-                                          w_pre=False
-                                          )
-         pred = torch.tensor(pred_res).squeeze().to(device)
-
-         if pred.shape[0] < pred_face.shape[0]:
-             repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
-             pred = torch.cat([pred, repeat_frame], dim=0)
          else:
-             pred = pred[:pred_face.shape[0], :]
-
-         body_or_face = False
-         if pred.shape[1] < 275:
-             body_or_face = True
-         if config.Data.pose.convert_to_6d:
-             pred = pred.reshape(pred.shape[0], -1, 6)
-             pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
-             pred = pred.reshape(pred.shape[0], -1)
-
-         if config.Model.model_name == 's2g_LS3DCG':
-             pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
-         else:
-             pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)
-
-         # pred[:, 9:12] = global_orient
-         pred = part2full(pred, stand)
-         if face:
-             pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
-         # result_list[0] = poses2pred(result_list[0], stand)
-         # if gt_0 is None:
-         #     gt_0 = gt
-         # pred = pred2poses(pred, gt_0)
-         # result_list[0] = poses2poses(result_list[0], gt_0)
-
-         result_list.append(pred)
-
-     vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)
-
-     result_list = [res.to('cpu') for res in result_list]
-     dict = np.concatenate(result_list[:], axis=0)
-
-     rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)
-     return "result.mp4"
-
-
- def main():
-     iface = gr.Interface(fn=infer, inputs=["audio",
-                                            gr.Radio(["Oliver", "Chemistry", "Seth", "Conan"]),
-                                            gr.Radio(["Stand", "Sit", "Only Face"]),
-                                            ],
-                          outputs="video",
-                          examples=[[os.path.join(os.path.dirname(__file__), "demo_audio/style.wav"), "Oliver", "Sit"]])
-     iface.launch(debug=True)
-
-
- if __name__ == '__main__':
-     main()
 
  import gradio as gr
+ import subprocess
  import os

+ def run_talkshow_model(audio_file):
+     # Path to the TalkSHOW demo script
+     demo_script = 'scripts/demo.py'
+
+     # Configuration and model parameters
+     config_file = './config/LS3DCG.json'
+     body_model_name = 's2g_LS3DCG'
+     body_model_path = 'experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth'
+
+     # Path of the uploaded audio file
+     audio_file_path = audio_file
+
+     # Path where the output .mp4 video will be saved
+     output_video_path = './output_video/result.mp4'
+
+     # Run the demo.py script with the necessary arguments
+     command = [
+         'python', demo_script,
+         '--config_file', config_file,
+         '--infer',
+         '--audio_file', audio_file_path,
+         '--body_model_name', body_model_name,
+         '--body_model_path', body_model_path,
+         '--id', '0',
+         '--output', output_video_path  # Assuming demo.py has an argument to specify output
+     ]
+
+     try:
+         # Run the subprocess and capture output
+         subprocess.run(command, check=True, capture_output=True, text=True)
+
+         # Check if the .mp4 file is generated
+         if os.path.exists(output_video_path):
+             return output_video_path  # Return the path of the generated video
          else:
+             return "Error: Output video not generated."
+
+     except subprocess.CalledProcessError as e:
+         return f"Error running the model: {e.stderr}"  # Return the error message
+
+ # Set up the Gradio interface
+ interface = gr.Interface(
+     fn=run_talkshow_model,
+     inputs=gr.Audio(source="upload", type="filepath"),
+     outputs=gr.Video(),  # Use gr.Video to output the generated .mp4 video
+     title="TalkSHOW: Audio to Mesh"
+ )
+
+ # Launch the interface
+ if __name__ == "__main__":
+     interface.launch()
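
For a quick local check of the new wrapper without launching the Gradio UI, something like the following can be run from the repo root. This is a minimal sketch, not part of the commit: it assumes demo_audio/style.wav is still present (the sample clip the removed examples list pointed at), and it relies on the fact that importing app only builds the interface, since interface.launch() is behind the __main__ guard.

# smoke_test.py — hypothetical helper, not included in this commit
import os

from app import run_talkshow_model  # the wrapper added above

if __name__ == "__main__":
    sample = os.path.join("demo_audio", "style.wav")  # assumed sample clip
    result = run_talkshow_model(sample)
    # The wrapper returns the video path on success and an error string otherwise
    if result.endswith(".mp4") and os.path.exists(result):
        print("OK: video written to", result)
    else:
        print("Failed:", result)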