fffiloni commited on
Commit
854f2ef
·
verified ·
1 Parent(s): 4d3f941

Update core/test_xportrait.py

Browse files
Files changed (1) hide show
  1. core/test_xportrait.py +2 -2
core/test_xportrait.py CHANGED
@@ -221,7 +221,7 @@ def x_portrait_data_prep(source_image_path, driving_video_path, device, best_fra
221
  # You can now use the modified state_dict without the deleted keys
222
  def load_state_dict(model, ckpt_path, reinit_hint_block=False, strict=True, map_location="cpu"):
223
  print(f"Loading model state dict from {ckpt_path} ...")
224
- state_dict = torch.load(ckpt_path, map_location=map_location)
225
  state_dict = state_dict.get('state_dict', state_dict)
226
  if reinit_hint_block:
227
  print("Ignoring hint block parameters from checkpoint!")
@@ -341,7 +341,7 @@ def visualize_mm(args, name, batch_data, infer_model, nSample, local_image_dir,
341
 
342
  noise = pre_noise.to(c_cross.device)
343
 
344
- with torch.cuda.amp.autocast(enabled=args.use_fp16, dtype=FP16_DTYPE):
345
  infer_model.to(args.device)
346
  infer_model.eval()
347
 
 
221
  # You can now use the modified state_dict without the deleted keys
222
  def load_state_dict(model, ckpt_path, reinit_hint_block=False, strict=True, map_location="cpu"):
223
  print(f"Loading model state dict from {ckpt_path} ...")
224
+ state_dict = torch.load(ckpt_path, map_location=map_location, weights_only=True)
225
  state_dict = state_dict.get('state_dict', state_dict)
226
  if reinit_hint_block:
227
  print("Ignoring hint block parameters from checkpoint!")
 
341
 
342
  noise = pre_noise.to(c_cross.device)
343
 
344
+ with torch.amp.autocast("cuda", enabled=args.use_fp16, dtype=FP16_DTYPE):
345
  infer_model.to(args.device)
346
  infer_model.eval()
347