Gregniuki commited on
Commit
0798588
1 Parent(s): c0ed55a

Update model/utils.py

Browse files
Files changed (1) hide show
  1. model/utils.py +14 -28
model/utils.py CHANGED
@@ -557,38 +557,24 @@ def repetition_found(text, length = 2, tolerance = 10):
557
 
558
  # load model checkpoint for inference
559
 
560
- def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
561
- if dtype is None:
562
- dtype = (
563
- torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
564
- )
565
- model = model.to(dtype)
566
 
567
  ckpt_type = ckpt_path.split(".")[-1]
568
  if ckpt_type == "safetensors":
569
  from safetensors.torch import load_file
570
-
571
- checkpoint = load_file(ckpt_path)
572
  else:
573
- checkpoint = torch.load(ckpt_path, weights_only=True, map_location=torch.device('cpu'))
574
- if use_ema:
 
 
575
  if ckpt_type == "safetensors":
576
- checkpoint = {"ema_model_state_dict": checkpoint}
577
- checkpoint["model_state_dict"] = {
578
- k.replace("ema_model.", ""): v
579
- for k, v in checkpoint["ema_model_state_dict"].items()
580
- if k not in ["initted", "step"]
581
- }
582
-
583
- # patch for backward compatibility, 305e3ea
584
- for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
585
- if key in checkpoint["model_state_dict"]:
586
- del checkpoint["model_state_dict"][key]
587
-
588
- model.load_state_dict(checkpoint["model_state_dict"])
589
  else:
590
- if ckpt_type == "safetensors":
591
- checkpoint = {"model_state_dict": checkpoint}
592
- model.load_state_dict(checkpoint["model_state_dict"])
593
-
594
- return model.to(device)
 
557
 
558
  # load model checkpoint for inference
559
 
560
+ def load_checkpoint(model, ckpt_path, device, use_ema = True):
561
+ from ema_pytorch import EMA
 
 
 
 
562
 
563
  ckpt_type = ckpt_path.split(".")[-1]
564
  if ckpt_type == "safetensors":
565
  from safetensors.torch import load_file
566
+ checkpoint = load_file(ckpt_path, device=device)
 
567
  else:
568
+ checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
569
+
570
+ if use_ema == True:
571
+ ema_model = EMA(model, include_online_model = False).to(device)
572
  if ckpt_type == "safetensors":
573
+ ema_model.load_state_dict(checkpoint)
574
+ else:
575
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
576
+ ema_model.copy_params_from_ema_to_model()
 
 
 
 
 
 
 
 
 
577
  else:
578
+ model.load_state_dict(checkpoint['model_state_dict'])
579
+
580
+ return model