|
import torch
|
|
import json
|
|
import re
|
|
import unicodedata
|
|
import numpy as np
|
|
import regex
|
|
|
|
from types import SimpleNamespace
|
|
from LOAD.models import DurationNet, SynthesizerTrn
|
|
|
|
import scipy.io.wavfile as wav
|
|
|
|
|
|
# Locations of the model configuration, the two checkpoints and the phone
# inventory. All paths are resolved relative to the current working directory.
config_file = "config.json"
duration_model_path = "duration_model.pth"
lightspeed_model_path = "gen_630k.pth"
phone_set_file = "phone_set.json"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyper-parameters. Every JSON object becomes a SimpleNamespace so nested
# values are reachable as attributes (e.g. ``hps.data.sampling_rate``).
with open(config_file, "rb") as f:
    hps = json.load(f, object_hook=lambda x: SimpleNamespace(**x))

# The phone inventory holds non-ASCII symbols, so decode it explicitly as
# UTF-8 instead of relying on the platform's default encoding.
with open(phone_set_file, "r", encoding="utf-8") as f:
    phone_set = json.load(f)

# The text front-end emits index 0 for word separators and looks up the
# index of "sil" for pauses, so both entries must be where it expects them.
assert phone_set[0][1:-1] == "SEP", "phone_set[0] must be the SEP symbol"
assert "sil" in phone_set, 'phone_set must contain "sil"'
sil_idx = phone_set.index("sil")
|
|
|
|
# --- Lookup tables and pre-compiled patterns for the text front-end ---

# One or more whitespace characters.
space_re = regex.compile(r"\s+")

# A run of ASCII digits (captured).
number_re = regex.compile(r"([0-9]+)")

# Vietnamese words for the digits 0-9, indexed by digit value.
digits = [
    "không",
    "một",
    "hai",
    "ba",
    "bốn",
    "năm",
    "sáu",
    "bảy",
    "tám",
    "chín",
]

# A numeric token: any mix of digits, "." and "," that ends in a digit.
num_re = regex.compile(r"([0-9.,]*[0-9])")

# Lowercase letters (tone-marked vowels included) that the patterns below
# treat as regular text.
alphabet = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"

# Any character that is NOT whitespace, a letter of `alphabet`, ".", "," or
# an ASCII digit.
keep_text_and_num_re = regex.compile(r"[^\s" + alphabet + r".,0-9]")

# Any character that is NOT whitespace or a letter of `alphabet`.
keep_text_re = regex.compile(r"[^\s" + alphabet + r"]")
|
|
|
|
|
|
def read_number(num: str) -> str:
    """Spell out a numeric token in Vietnamese words.

    Supported forms:

    * plain integers of 1 to 6 digits ("7", "25", "305", "2000", "123456");
    * ``a,b`` where "," is the decimal mark, read as "<a> phẩy <b>";
    * ``a.bbb`` and ``a.bbb.ccc`` where "." groups thousands / millions.

    Anything else (7+ digit integers, more than one comma, more than two
    dots, non-numeric text) is returned unchanged rather than raising.

    Args:
        num: The token to read, e.g. "15" or "1.005".

    Returns:
        The spoken form, or ``num`` itself when the token is not understood.
    """
    if num.isdigit():
        length = len(num)

        if length == 1:
            return digits[int(num)]

        if length == 2:
            tens, units = int(num[0]), int(num[1])
            if tens == 0:
                # "05" is not fifteen: read a leading zero digit by digit.
                return digits[0] + " " + digits[units]
            if units == 0:
                return "mười" if tens == 1 else digits[tens] + " mươi"
            if units == 5:
                # 5 in the units place is pronounced "lăm" after a ten.
                unit_word = "lăm"
            elif units == 1 and tens > 1:
                # 1 in the units place is "mốt" from 21 upwards.
                unit_word = "mốt"
            else:
                unit_word = digits[units]
            if tens == 1:
                return "mười " + unit_word
            return digits[tens] + " mươi " + unit_word

        if length == 3:
            hundreds = digits[int(num[0])] + " trăm"
            if num[1:] == "00":
                return hundreds
            if num[1] == "0":
                # x0y -> "... trăm lẻ y"
                return hundreds + " lẻ " + digits[int(num[2])]
            return hundreds + " " + read_number(num[1:])

        if length <= 6:
            thousands = read_number(str(int(num) // 1000)) + " ngàn"
            if num[-3:] == "000":
                # Round thousands: same reading as the "x.000" form below.
                return thousands
            return thousands + " " + read_number(num[-3:])

        # 7+ digits are not supported; hand the token back untouched.
        return num

    if "," in num:
        pieces = num.split(",")
        if len(pieces) == 2:
            return read_number(pieces[0]) + " phẩy " + read_number(pieces[1])
        # More than one decimal mark is ambiguous: leave the token as is.
        return num

    if "." in num:
        groups = num.split(".")
        if len(groups) == 2:
            head, tail = groups
            if tail == "000":
                return read_number(head) + " ngàn"
            if len(tail) == 3 and tail.isdigit() and tail.startswith("00"):
                # x.00y -> "... ngàn lẻ y"
                return read_number(head) + " ngàn lẻ " + digits[int(tail[2])]
            return read_number(head) + " ngàn " + read_number(tail)
        if len(groups) == 3:
            millions, thousands, units = groups
            words = read_number(millions) + " triệu"
            # All-zero groups are silent ("1.000.000" is just "một triệu").
            if thousands != "000":
                words += " " + read_number(thousands) + " ngàn"
            if units != "000":
                words += " " + read_number(units)
            return words

    return num
|
|
|
|
|
|
def text_to_phone_idx(text):
    """Convert raw text into the list of phone indices fed to the models.

    The text is lower-cased and NFKC-normalized, punctuation is isolated,
    numeric tokens are spelled out with ``read_number`` and the result is
    mapped character by character:

    * one of ``: , . ! ? ; (`` -> ``sil_idx`` (a pause),
    * a character found in ``phone_set`` -> its index,
    * a space -> ``0`` (the separator at ``phone_set[0]``),
    * anything else is dropped.

    The sequence is padded so that it starts and ends with ``sil_idx``.

    Args:
        text: Text to convert.

    Returns:
        A list of integer phone indices. When the text contains nothing
        pronounceable the result is a short ``[sil, sep, sil]`` sequence
        instead of an IndexError.
    """
    pause_marks = ":,.!?;("

    text = unicodedata.normalize("NFKC", text.lower())

    # Surround each pause mark handled here with spaces so that it becomes
    # a token of its own (":" , "," , "." , "!" , "?" , ";" and "(").
    for mark in ".,;:!?(":
        text = text.replace(mark, " " + mark + " ")

    # NOTE(review): "." and "," were padded with spaces just above, so a
    # number such as "1.000" has already become "1 . 000" by the time num_re
    # runs and is therefore read group by group, with a pause in between.
    # Confirm whether that is intended.
    text = num_re.sub(r" \1 ", text)
    words = [read_number(w) if num_re.fullmatch(w) else w for w in text.split()]

    # read_number may return strings with stray spaces, so collapse again.
    text = re.sub(r"\s+", " ", " ".join(words)).strip()

    # Build the symbol -> index table once; setdefault keeps the first
    # occurrence, which is what list.index would return.
    phone_to_idx = {}
    for idx, phone in enumerate(phone_set):
        phone_to_idx.setdefault(phone, idx)

    tokens = []
    for c in text:
        if c in pause_marks:
            tokens.append(sil_idx)
        elif c in phone_to_idx:
            tokens.append(phone_to_idx[c])
        elif c == " ":
            # word separator
            tokens.append(0)

    if not tokens:
        # Nothing pronounceable (empty text or only unknown characters).
        return [sil_idx, 0, sil_idx]

    if tokens[0] != sil_idx:
        # make the utterance start with a pause
        tokens = [sil_idx, 0] + tokens
    if tokens[-1] != sil_idx:
        # ... and end with one
        tokens = tokens + [0, sil_idx]
    return tokens
|
|
|
|
|
|
def text_to_speech(duration_net, generator, text):
    """Synthesize one piece of text into 16-bit PCM samples.

    Args:
        duration_net: Model called as ``duration_net(phone_idx, phone_length)``;
            channel 0 of its output is used as the per-phone duration.
        generator: Model exposing ``infer(...)``; element 0 of its result is
            taken as the waveform batch.
        text: Text to speak.

    Returns:
        A 1-D ``numpy.int16`` array of audio samples.
    """
    tokens = text_to_phone_idx(text)

    # Batch of exactly one utterance.
    phone_idx = torch.tensor([tokens], dtype=torch.long, device=device)
    phone_length = torch.tensor([len(tokens)], dtype=torch.long, device=device)

    with torch.inference_mode():
        # Scaled by 1000 here and divided by 1000 again when converting to
        # frames below, i.e. durations are handled as milliseconds
        # (assumes the network predicts seconds -- TODO confirm).
        phone_duration = duration_net(phone_idx, phone_length)[:, :, 0] * 1000
        # Pauses last at least 200 (ms, under the assumption above).
        phone_duration = torch.where(
            phone_idx == sil_idx, torch.clamp_min(phone_duration, 200), phone_duration
        )
        # Index 0 (the word separator) takes no time at all.
        phone_duration = torch.where(phone_idx == 0, 0, phone_duration)

    # Turn durations into [start, end) intervals measured in frames.
    end_time = torch.cumsum(phone_duration, dim=-1)
    start_time = end_time - phone_duration
    frames_per_ms = hps.data.sampling_rate / hps.data.hop_length / 1000
    start_frame = start_time * frames_per_ms
    end_frame = end_time * frames_per_ms
    spec_length = end_frame.max(dim=-1).values

    # attn[b, t, p] is 1.0 when frame t lies inside phone p's interval.
    pos = torch.arange(0, spec_length.item(), device=device)
    attn = torch.logical_and(
        pos[None, :, None] >= start_frame[:, None, :],
        pos[None, :, None] < end_frame[:, None, :],
    ).float()

    with torch.inference_mode():
        y_hat = generator.infer(
            phone_idx, phone_length, spec_length, attn, max_len=None, noise_scale=0.0
        )[0]

    wave = y_hat[0, 0].detach().cpu().numpy()

    # Scale to the int16 range. Clip first: a sample of exactly 1.0 would
    # become 32768, which does not fit in int16 and would wrap around.
    scaled = wave * (2**15)
    return np.clip(scaled, -(2**15), 2**15 - 1).astype(np.int16)
|
|
|
|
|
|
def load_models():
    """Create the duration model and the synthesizer and load their weights.

    Both networks are moved to ``device`` and switched to eval mode.

    Returns:
        A ``(duration_net, generator)`` tuple.
    """
    # NOTE: torch.load unpickles the checkpoint files -- only point the
    # module-level paths at checkpoints you trust.

    # --- duration model ---
    duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
    duration_state = torch.load(duration_model_path, map_location=device)
    duration_net.load_state_dict(duration_state)
    duration_net = duration_net.eval()

    # --- waveform generator ---
    generator = SynthesizerTrn(
        hps.data.vocab_size,
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **vars(hps.model),
    ).to(device)
    # enc_q is removed before the weights are loaded (presumably it is only
    # needed for training -- TODO confirm).
    del generator.enc_q

    checkpoint = torch.load(lightspeed_model_path, map_location=device)
    # Drop a leading "module." from every parameter name (presumably left
    # by a DataParallel-style wrapper -- TODO confirm).
    prefix = "module."
    state = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint["net_g"].items()
    }
    # strict=False: mismatching keys between checkpoint and model are
    # tolerated instead of raising.
    generator.load_state_dict(state, strict=False)
    del checkpoint, state

    return duration_net, generator.eval()
|
|
|
|
|
|
def _split_paragraph(paragraph, max_length):
    """Yield consecutive pieces of ``paragraph`` of at most ``max_length`` chars.

    A piece ends at the last space inside the window when there is one, so
    words are not cut in half; a hard cut is used only when the window
    contains no space.
    """
    start = 0
    total = len(paragraph)
    while total - start > max_length:
        # Search from start + 1 so the space that begins this piece (left
        # over from the previous cut) is never chosen again.
        cut = paragraph.rfind(" ", start + 1, start + max_length + 1)
        if cut == -1:
            cut = start + max_length
        yield paragraph[start:cut]
        start = cut
    yield paragraph[start:]


def speak(text):
    """Synthesize ``text`` (possibly several lines) into one waveform.

    The models are loaded on every call. The text is split on newlines, each
    non-empty paragraph is cut into pieces of at most 400 characters
    (preferably at a space) and the synthesized pieces are concatenated.

    Args:
        text: Text to speak.

    Returns:
        ``(sampling_rate, samples)`` where ``samples`` is a 1-D int16 array.
        ``samples`` is empty when ``text`` contains nothing to speak.
    """
    duration_net, generator = load_models()
    max_chunk_length = 400
    clips = []

    for paragraph in text.split("\n"):
        paragraph = paragraph.strip()
        if not paragraph:
            continue
        for chunk in _split_paragraph(paragraph, max_chunk_length):
            # A piece made only of whitespace has nothing to synthesize.
            if chunk.strip():
                clips.append(text_to_speech(duration_net, generator, chunk))

    if not clips:
        # np.concatenate refuses an empty list; return silence of length 0.
        return hps.data.sampling_rate, np.zeros(0, dtype=np.int16)

    return hps.data.sampling_rate, np.concatenate(clips)
|
|
|
|
|
|
def textToMp3(text, outWAV):
    """Synthesize ``text`` and save the audio to ``outWAV``.

    Despite the function's name, the file is written with
    ``scipy.io.wavfile.write``, i.e. in WAV format.

    Args:
        text: Text to speak.
        outWAV: Destination passed to ``scipy.io.wavfile.write``.
    """
    rate, samples = speak(text)
    wav.write(outWAV, rate, samples)
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo run: only when executed as a script, so that importing this
    # module does not synthesize audio and write test.wav as a side effect.
    textToMp3('bây giờ là mấy giờ', 'test.wav')