Update model/utils.py
Browse files- model/utils.py +14 -28
model/utils.py
CHANGED
@@ -557,38 +557,24 @@ def repetition_found(text, length = 2, tolerance = 10):
|
|
557 |
|
558 |
# load model checkpoint for inference
|
559 |
|
560 |
-
def load_checkpoint(model, ckpt_path, device,
|
561 |
-
|
562 |
-
dtype = (
|
563 |
-
torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
564 |
-
)
|
565 |
-
model = model.to(dtype)
|
566 |
|
567 |
ckpt_type = ckpt_path.split(".")[-1]
|
568 |
if ckpt_type == "safetensors":
|
569 |
from safetensors.torch import load_file
|
570 |
-
|
571 |
-
checkpoint = load_file(ckpt_path)
|
572 |
else:
|
573 |
-
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=
|
574 |
-
|
|
|
|
|
575 |
if ckpt_type == "safetensors":
|
576 |
-
checkpoint
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
if k not in ["initted", "step"]
|
581 |
-
}
|
582 |
-
|
583 |
-
# patch for backward compatibility, 305e3ea
|
584 |
-
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
|
585 |
-
if key in checkpoint["model_state_dict"]:
|
586 |
-
del checkpoint["model_state_dict"][key]
|
587 |
-
|
588 |
-
model.load_state_dict(checkpoint["model_state_dict"])
|
589 |
else:
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
return model.to(device)
|
|
|
557 |
|
558 |
# load model checkpoint for inference
|
559 |
|
560 |
+
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
561 |
+
from ema_pytorch import EMA
|
|
|
|
|
|
|
|
|
562 |
|
563 |
ckpt_type = ckpt_path.split(".")[-1]
|
564 |
if ckpt_type == "safetensors":
|
565 |
from safetensors.torch import load_file
|
566 |
+
checkpoint = load_file(ckpt_path, device=device)
|
|
|
567 |
else:
|
568 |
+
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
|
569 |
+
|
570 |
+
if use_ema == True:
|
571 |
+
ema_model = EMA(model, include_online_model = False).to(device)
|
572 |
if ckpt_type == "safetensors":
|
573 |
+
ema_model.load_state_dict(checkpoint)
|
574 |
+
else:
|
575 |
+
ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
576 |
+
ema_model.copy_params_from_ema_to_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
577 |
else:
|
578 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
579 |
+
|
580 |
+
return model
|
|
|
|