Spanicin commited on
Commit
49b64d3
·
verified ·
1 Parent(s): 7c42cce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -128,11 +128,22 @@ def process_chunk(audio_chunk, args):
128
  print("crop_pic_path",crop_pic_path)
129
  print("crop_info",crop_info)
130
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
131
  batch = get_data(first_coeff_path, audio_chunk, args.device, ref_eyeblink_coeff_path=None, still=args.still)
132
  audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
133
  audio2exp_checkpoint, audio2exp_yaml_path,
134
  wav2lip_checkpoint, args.device)
135
- coeff_path = audio_to_coeff.generate(batch, args.result_dir, args.pose_style, ref_pose_coeff_path=None)
136
 
137
  # Further processing with animate_from_coeff using the coeff_path
138
  animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
 
128
  print("crop_pic_path",crop_pic_path)
129
  print("crop_info",crop_info)
130
  torch.cuda.empty_cache()
131
+
132
+ if args.ref_pose is not None:
133
+ ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
134
+ ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
135
+ os.makedirs(ref_pose_frame_dir, exist_ok=True)
136
+ ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir)
137
+ print('ref_eyeblink_coeff_path',ref_pose_coeff_path)
138
+ else:
139
+ ref_pose_coeff_path = None
140
+ print('ref_eyeblink_coeff_path',ref_pose_coeff_path)
141
+
142
  batch = get_data(first_coeff_path, audio_chunk, args.device, ref_eyeblink_coeff_path=None, still=args.still)
143
  audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
144
  audio2exp_checkpoint, audio2exp_yaml_path,
145
  wav2lip_checkpoint, args.device)
146
+ coeff_path = audio_to_coeff.generate(batch, args.result_dir, args.pose_style, ref_pose_coeff_path)
147
 
148
  # Further processing with animate_from_coeff using the coeff_path
149
  animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,