Update model/utils.py
Browse files- model/utils.py +2 -2
model/utils.py
CHANGED
@@ -437,7 +437,7 @@ def load_asr_model(lang, ckpt_dir = ""):
|
|
437 |
elif lang == "en":
|
438 |
from faster_whisper import WhisperModel
|
439 |
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
|
440 |
-
model = WhisperModel(model_size, device="
|
441 |
return model
|
442 |
|
443 |
|
@@ -565,7 +565,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
|
565 |
from safetensors.torch import load_file
|
566 |
checkpoint = load_file(ckpt_path, device=device)
|
567 |
else:
|
568 |
-
checkpoint = torch.load(ckpt_path, weights_only=
|
569 |
|
570 |
if use_ema == True:
|
571 |
ema_model = EMA(model, include_online_model = False).to(device)
|
|
|
437 |
elif lang == "en":
|
438 |
from faster_whisper import WhisperModel
|
439 |
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
|
440 |
+
model = WhisperModel(model_size, device="cpu", compute_type="float16")
|
441 |
return model
|
442 |
|
443 |
|
|
|
565 |
from safetensors.torch import load_file
|
566 |
checkpoint = load_file(ckpt_path, device=device)
|
567 |
else:
|
568 |
+
checkpoint = torch.load(ckpt_path, weights_only=False, map_location=device)
|
569 |
|
570 |
if use_ema == True:
|
571 |
ema_model = EMA(model, include_online_model = False).to(device)
|