Spaces:
Running
on
Zero
Running
on
Zero
Update core/test_xportrait.py
Browse files- core/test_xportrait.py +2 -2
core/test_xportrait.py
CHANGED
@@ -221,7 +221,7 @@ def x_portrait_data_prep(source_image_path, driving_video_path, device, best_fra
|
|
221 |
# You can now use the modified state_dict without the deleted keys
|
222 |
def load_state_dict(model, ckpt_path, reinit_hint_block=False, strict=True, map_location="cpu"):
|
223 |
print(f"Loading model state dict from {ckpt_path} ...")
|
224 |
-
state_dict = torch.load(ckpt_path, map_location=map_location)
|
225 |
state_dict = state_dict.get('state_dict', state_dict)
|
226 |
if reinit_hint_block:
|
227 |
print("Ignoring hint block parameters from checkpoint!")
|
@@ -341,7 +341,7 @@ def visualize_mm(args, name, batch_data, infer_model, nSample, local_image_dir,
|
|
341 |
|
342 |
noise = pre_noise.to(c_cross.device)
|
343 |
|
344 |
-
with torch.
|
345 |
infer_model.to(args.device)
|
346 |
infer_model.eval()
|
347 |
|
|
|
221 |
# You can now use the modified state_dict without the deleted keys
|
222 |
def load_state_dict(model, ckpt_path, reinit_hint_block=False, strict=True, map_location="cpu"):
|
223 |
print(f"Loading model state dict from {ckpt_path} ...")
|
224 |
+
state_dict = torch.load(ckpt_path, map_location=map_location, weights_only=True)
|
225 |
state_dict = state_dict.get('state_dict', state_dict)
|
226 |
if reinit_hint_block:
|
227 |
print("Ignoring hint block parameters from checkpoint!")
|
|
|
341 |
|
342 |
noise = pre_noise.to(c_cross.device)
|
343 |
|
344 |
+
with torch.amp.autocast("cuda", enabled=args.use_fp16, dtype=FP16_DTYPE):
|
345 |
infer_model.to(args.device)
|
346 |
infer_model.eval()
|
347 |
|