Update model/utils.py
Browse files- model/utils.py +2 -2
model/utils.py
CHANGED
@@ -562,8 +562,8 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
|
562 |
|
563 |
ckpt_type = ckpt_path.split(".")[-1]
|
564 |
if ckpt_type == "safetensors":
|
565 |
-
|
566 |
-
|
567 |
else:
|
568 |
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
569 |
|
|
|
562 |
|
563 |
ckpt_type = ckpt_path.split(".")[-1]
|
564 |
if ckpt_type == "safetensors":
|
565 |
+
from safetensors.torch import load_file
|
566 |
+
checkpoint = load_file(ckpt_path, device=device)
|
567 |
else:
|
568 |
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
569 |
|