Spaces:
Sleeping
Sleeping
Update scripts/demo.py
Browse files- scripts/demo.py +12 -13
scripts/demo.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import os
|
2 |
import sys
|
3 |
-
|
4 |
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
5 |
sys.path.append(os.getcwd())
|
6 |
-
|
7 |
from transformers import Wav2Vec2Processor
|
8 |
from glob import glob
|
9 |
|
@@ -24,8 +24,8 @@ from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis
|
|
24 |
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
|
25 |
from visualise.rendering import RenderTool
|
26 |
|
27 |
-
|
28 |
-
|
29 |
|
30 |
def init_model(model_name, model_path, args, config):
|
31 |
if model_name == 's2g_face':
|
@@ -156,7 +156,7 @@ global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
|
|
156 |
|
157 |
|
158 |
def infer(g_body, g_face, smplx_model, rendertool, config, args):
|
159 |
-
betas = torch.zeros([1, 300], dtype=torch.float64).to(
|
160 |
am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
|
161 |
am_sr = 16000
|
162 |
num_sample = args.num_sample
|
@@ -165,7 +165,7 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
|
|
165 |
face = args.only_face
|
166 |
stand = args.stand
|
167 |
if face:
|
168 |
-
body_static = torch.zeros([1, 162], device=
|
169 |
body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
|
170 |
|
171 |
result_list = []
|
@@ -179,7 +179,7 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
|
|
179 |
am=am,
|
180 |
am_sr=am_sr
|
181 |
)
|
182 |
-
pred_face = torch.tensor(pred_face).squeeze().to(
|
183 |
# pred_face = torch.zeros([gt.shape[0], 105])
|
184 |
|
185 |
if config.Data.pose.convert_to_6d:
|
@@ -190,7 +190,7 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
|
|
190 |
pred_jaw = pred_face[:, :3]
|
191 |
pred_face = pred_face[:, 3:]
|
192 |
|
193 |
-
id = torch.tensor([id], device=
|
194 |
|
195 |
for i in range(num_sample):
|
196 |
pred_res = g_body.infer_on_audio(cur_wav_file,
|
@@ -202,7 +202,7 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
|
|
202 |
fps=30,
|
203 |
w_pre=False
|
204 |
)
|
205 |
-
pred = torch.tensor(pred_res).squeeze().to(
|
206 |
|
207 |
if pred.shape[0] < pred_face.shape[0]:
|
208 |
repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
|
@@ -250,9 +250,8 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
|
|
250 |
def main():
|
251 |
parser = parse_args()
|
252 |
args = parser.parse_args()
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
|
257 |
config = load_JsonConfig(args.config_file)
|
258 |
|
@@ -292,7 +291,7 @@ def main():
|
|
292 |
create_transl=False,
|
293 |
# gender='ne',
|
294 |
dtype=dtype, )
|
295 |
-
smplx_model = smpl.create(**model_params).to(
|
296 |
print('init rendertool...')
|
297 |
rendertool = RenderTool('visualise/video/' + config.Log.name)
|
298 |
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
+
|
4 |
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
5 |
sys.path.append(os.getcwd())
|
6 |
+
|
7 |
from transformers import Wav2Vec2Processor
|
8 |
from glob import glob
|
9 |
|
|
|
24 |
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
|
25 |
from visualise.rendering import RenderTool
|
26 |
|
27 |
+
import time
|
28 |
+
|
29 |
|
30 |
def init_model(model_name, model_path, args, config):
|
31 |
if model_name == 's2g_face':
|
|
|
156 |
|
157 |
|
158 |
def infer(g_body, g_face, smplx_model, rendertool, config, args):
|
159 |
+
betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
|
160 |
am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
|
161 |
am_sr = 16000
|
162 |
num_sample = args.num_sample
|
|
|
165 |
face = args.only_face
|
166 |
stand = args.stand
|
167 |
if face:
|
168 |
+
body_static = torch.zeros([1, 162], device='cuda')
|
169 |
body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
|
170 |
|
171 |
result_list = []
|
|
|
179 |
am=am,
|
180 |
am_sr=am_sr
|
181 |
)
|
182 |
+
pred_face = torch.tensor(pred_face).squeeze().to('cuda')
|
183 |
# pred_face = torch.zeros([gt.shape[0], 105])
|
184 |
|
185 |
if config.Data.pose.convert_to_6d:
|
|
|
190 |
pred_jaw = pred_face[:, :3]
|
191 |
pred_face = pred_face[:, 3:]
|
192 |
|
193 |
+
id = torch.tensor([id], device='cuda')
|
194 |
|
195 |
for i in range(num_sample):
|
196 |
pred_res = g_body.infer_on_audio(cur_wav_file,
|
|
|
202 |
fps=30,
|
203 |
w_pre=False
|
204 |
)
|
205 |
+
pred = torch.tensor(pred_res).squeeze().to('cuda')
|
206 |
|
207 |
if pred.shape[0] < pred_face.shape[0]:
|
208 |
repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
|
|
|
250 |
def main():
|
251 |
parser = parse_args()
|
252 |
args = parser.parse_args()
|
253 |
+
device = torch.device(args.gpu)
|
254 |
+
torch.cuda.set_device(device)
|
|
|
255 |
|
256 |
config = load_JsonConfig(args.config_file)
|
257 |
|
|
|
291 |
create_transl=False,
|
292 |
# gender='ne',
|
293 |
dtype=dtype, )
|
294 |
+
smplx_model = smpl.create(**model_params).to('cuda')
|
295 |
print('init rendertool...')
|
296 |
rendertool = RenderTool('visualise/video/' + config.Log.name)
|
297 |
|