Spaces:
No application file
No application file
File size: 4,201 Bytes
3883c60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import os
import traceback
from collections import OrderedDict
import torch
def savee(ckpt, sr, if_f0, name, epoch, version):
try:
opt = OrderedDict()
opt["weight"] = {}
for key in ckpt.keys():
if "enc_q" in key:
continue
opt["weight"][key] = ckpt[key].half()
if sr == "40k":
opt["config"] = [
1025,
32,
192,
192,
768,
2,
6,
3,
0,
"1",
[3, 7, 11],
[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
[10, 10, 2, 2],
512,
[16, 16, 4, 4],
109,
256,
40000,
]
elif sr == "48k":
opt["config"] = [
1025,
32,
192,
192,
768,
2,
6,
3,
0,
"1",
[3, 7, 11],
[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
[10, 6, 2, 2, 2],
512,
[16, 16, 4, 4, 4],
109,
256,
48000,
]
elif sr == "32k":
opt["config"] = [
513,
32,
192,
192,
768,
2,
6,
3,
0,
"1",
[3, 7, 11],
[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
[10, 4, 2, 2, 2],
512,
[16, 16, 4, 4, 4],
109,
256,
32000,
]
opt["info"] = "%sepoch" % epoch
opt["sr"] = sr
opt["f0"] = if_f0
opt["version"] = version
os.makedirs(os.path.dirname(name), exist_ok=True)
torch.save(opt, name)
return "Success."
except:
return traceback.format_exc()
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(
{
"model": state_dict,
"iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items(): # 模型需要的shape
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
"shape-%s-mismatch|need-%s|get-%s"
% (k, state_dict[k].shape, saved_state_dict[k].shape)
) #
raise KeyError
except:
# logger.info(traceback.format_exc())
new_state_dict[k] = v # 模型自带的随机值
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict, strict=False)
else:
model.load_state_dict(new_state_dict, strict=False)
iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
if (
optimizer is not None and load_opt == 1
): ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch
# try:
optimizer.load_state_dict(checkpoint_dict["optimizer"])
# except:
# traceback.print_exc()
return model, optimizer, learning_rate, iteration |