insanecoder69 commited on
Commit
616cd75
·
verified ·
1 Parent(s): 34e1665

Update scripts/demo.py

Browse files
Files changed (1) hide show
  1. scripts/demo.py +12 -13
scripts/demo.py CHANGED
@@ -1,9 +1,9 @@
1
  import os
2
  import sys
3
- # os.environ["PYOPENGL_PLATFORM"] = "egl"
4
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5
  sys.path.append(os.getcwd())
6
- os.environ["PYOPENGL_PLATFORM"] = 'osmesa'
7
  from transformers import Wav2Vec2Processor
8
  from glob import glob
9
 
@@ -24,8 +24,8 @@ from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis
24
  from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
25
  from visualise.rendering import RenderTool
26
 
27
- global device
28
- device = 'cpu'
29
 
30
  def init_model(model_name, model_path, args, config):
31
  if model_name == 's2g_face':
@@ -156,7 +156,7 @@ global_orient = torch.tensor([3.0747, -0.0158, -0.0152])
156
 
157
 
158
  def infer(g_body, g_face, smplx_model, rendertool, config, args):
159
- betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
160
  am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
161
  am_sr = 16000
162
  num_sample = args.num_sample
@@ -165,7 +165,7 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
165
  face = args.only_face
166
  stand = args.stand
167
  if face:
168
- body_static = torch.zeros([1, 162], device=device)
169
  body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
170
 
171
  result_list = []
@@ -179,7 +179,7 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
179
  am=am,
180
  am_sr=am_sr
181
  )
182
- pred_face = torch.tensor(pred_face).squeeze().to(device)
183
  # pred_face = torch.zeros([gt.shape[0], 105])
184
 
185
  if config.Data.pose.convert_to_6d:
@@ -190,7 +190,7 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
190
  pred_jaw = pred_face[:, :3]
191
  pred_face = pred_face[:, 3:]
192
 
193
- id = torch.tensor([id], device=device)
194
 
195
  for i in range(num_sample):
196
  pred_res = g_body.infer_on_audio(cur_wav_file,
@@ -202,7 +202,7 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
202
  fps=30,
203
  w_pre=False
204
  )
205
- pred = torch.tensor(pred_res).squeeze().to(device)
206
 
207
  if pred.shape[0] < pred_face.shape[0]:
208
  repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
@@ -250,9 +250,8 @@ def infer(g_body, g_face, smplx_model, rendertool, config, args):
250
  def main():
251
  parser = parse_args()
252
  args = parser.parse_args()
253
- # device = torch.device(args.gpu)
254
- # torch.cuda.set_device(device)
255
-
256
 
257
  config = load_JsonConfig(args.config_file)
258
 
@@ -292,7 +291,7 @@ def main():
292
  create_transl=False,
293
  # gender='ne',
294
  dtype=dtype, )
295
- smplx_model = smpl.create(**model_params).to(device)
296
  print('init rendertool...')
297
  rendertool = RenderTool('visualise/video/' + config.Log.name)
298
 
 
1
  import os
2
  import sys
3
+
4
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5
  sys.path.append(os.getcwd())
6
+
7
  from transformers import Wav2Vec2Processor
8
  from glob import glob
9
 
 
24
  from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
25
  from visualise.rendering import RenderTool
26
 
27
+ import time
28
+
29
 
30
  def init_model(model_name, model_path, args, config):
31
  if model_name == 's2g_face':
 
156
 
157
 
158
  def infer(g_body, g_face, smplx_model, rendertool, config, args):
159
+ betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
160
  am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
161
  am_sr = 16000
162
  num_sample = args.num_sample
 
165
  face = args.only_face
166
  stand = args.stand
167
  if face:
168
+ body_static = torch.zeros([1, 162], device='cuda')
169
  body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
170
 
171
  result_list = []
 
179
  am=am,
180
  am_sr=am_sr
181
  )
182
+ pred_face = torch.tensor(pred_face).squeeze().to('cuda')
183
  # pred_face = torch.zeros([gt.shape[0], 105])
184
 
185
  if config.Data.pose.convert_to_6d:
 
190
  pred_jaw = pred_face[:, :3]
191
  pred_face = pred_face[:, 3:]
192
 
193
+ id = torch.tensor([id], device='cuda')
194
 
195
  for i in range(num_sample):
196
  pred_res = g_body.infer_on_audio(cur_wav_file,
 
202
  fps=30,
203
  w_pre=False
204
  )
205
+ pred = torch.tensor(pred_res).squeeze().to('cuda')
206
 
207
  if pred.shape[0] < pred_face.shape[0]:
208
  repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
 
250
  def main():
251
  parser = parse_args()
252
  args = parser.parse_args()
253
+ device = torch.device(args.gpu)
254
+ torch.cuda.set_device(device)
 
255
 
256
  config = load_JsonConfig(args.config_file)
257
 
 
291
  create_transl=False,
292
  # gender='ne',
293
  dtype=dtype, )
294
+ smplx_model = smpl.create(**model_params).to('cuda')
295
  print('init rendertool...')
296
  rendertool = RenderTool('visualise/video/' + config.Log.name)
297