# Captured from a Hugging Face Spaces page (Space status at capture time: "Build error").
import functools
import re
import time

import gradio as gr
import torch

import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence
# Relative paths (resolved against the working directory) to the model
# configuration JSON and the checkpoint file loaded by load_model().
config_json = "configs//multi.json"
pth_path = "model//G=728.pth"
# Language labels offered in the UI dropdown; sle() turns each label into
# a bracketed language tag around the input text.
lan = [
    "中文",  # Chinese
    "日文",  # Japanese
    "英文",  # English
    "德语",  # German
    "克罗地亚语",  # Croatian
]
def get_text(text, hps, cleaned=False):
    """Convert *text* into a 1-D ``torch.LongTensor`` of symbol ids.

    Args:
        text: Input string (in this file, the language-tagged output of sle()).
        hps: Hyper-parameter object; reads ``hps.symbols``,
            ``hps.data.text_cleaners`` and ``hps.data.add_blank``.
        cleaned: When True, the text is treated as already cleaned and an
            empty cleaner list is passed to ``text_to_sequence``.

    Returns:
        A ``torch.LongTensor`` of symbol ids. When ``hps.data.add_blank`` is
        set, the ids are first passed through ``commons.intersperse(ids, 0)``.
    """
    cleaners = [] if cleaned else hps.data.text_cleaners
    sequence = text_to_sequence(text, hps.symbols, cleaners)
    if hps.data.add_blank:
        # Interleave 0 into the id sequence (presumably the blank/pad id the
        # model was trained with when add_blank is enabled).
        sequence = commons.intersperse(sequence, 0)
    return torch.LongTensor(sequence)
def get_label(text, label):
    """Detect and strip every ``[label]`` marker in *text*.

    Returns:
        A ``(found, text)`` tuple. ``found`` is True when ``[label]`` occurs
        in *text*. The returned text has all occurrences removed; it is the
        original string unchanged when the marker is absent.
    """
    marker = f'[{label}]'
    found = marker in text
    return found, (text.replace(marker, '') if found else text)
# Dropdown label -> (language tag, delimiter that replaces newlines).
# CJK languages use the full-width full stop; the others use ".".
_LANGUAGE_TAGS = {
    "中文": ("ZH", "。"),
    "日文": ("JA", "。"),
    "英文": ("EN", "."),
    "德语": ("DE", "."),
    "克罗地亚语": ("CR", "."),
}


def sle(language, tts_input0):
    """Wrap *tts_input0* in the bracketed tag for *language*.

    Each newline in the input is replaced by the language's sentence
    delimiter, and the result is enclosed as ``[TAG]...[TAG]``.

    Args:
        language: One of the UI labels "中文", "日文", "英文", "德语", "克罗地亚语".
        tts_input0: Raw user text.

    Returns:
        The tagged string, e.g. ``sle("英文", "hi\\nyou") == "[EN]hi.you[EN]"``.

    Raises:
        ValueError: If *language* is not a supported label. The original
            implementation silently returned None here, which only failed
            later with an obscure error inside the text pipeline.
    """
    try:
        tag, delimiter = _LANGUAGE_TAGS[language]
    except KeyError:
        raise ValueError(f"Unsupported language: {language!r}") from None
    body = tts_input0.replace('\n', delimiter)
    return f"[{tag}]{body}[{tag}]"
def load_model(config_json, pth_path):
    """Build a ``SynthesizerTrn`` from *config_json* and load *pth_path* into it.

    The network is moved to ``cuda:0`` when CUDA is available (otherwise to
    the CPU) and switched to eval mode before the checkpoint is loaded via
    ``utils.load_checkpoint``.

    Args:
        config_json: Path to the hyper-parameter JSON file.
        pth_path: Path to the checkpoint file.

    Returns:
        The constructed ``SynthesizerTrn`` instance.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    hps = utils.get_hparams_from_file(f"{config_json}")
    data_cfg = hps.data
    # Fall back to 0 when the config omits these entries.
    speaker_count = data_cfg.n_speakers if 'n_speakers' in data_cfg.keys() else 0
    symbol_count = len(hps.symbols) if 'symbols' in hps.keys() else 0
    # filter_length // 2 + 1 -- presumably the number of linear-spectrogram bins.
    spec_dim = data_cfg.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // data_cfg.hop_length
    model = SynthesizerTrn(
        symbol_count,
        spec_dim,
        segment_frames,
        n_speakers=speaker_count,
        **hps.model,
    ).to(device)
    model.eval()
    utils.load_checkpoint(pth_path, model)
    return model
# Build the synthesizer once at import time; infer() calls net_g_ms.infer on it.
net_g_ms = load_model(config_json=config_json, pth_path=pth_path)
@functools.lru_cache(maxsize=8)
def _load_hparams(path):
    """Parse the hyper-parameter file at *path* once and reuse the result."""
    return utils.get_hparams_from_file(f"{path}")


def infer(language, text, speaker_id, n_scale=0.667, n_scale_w=0.8, l_scale=1):
    """Synthesize speech for *text* with the selected speaker.

    Args:
        language: UI language label, passed to ``sle`` (e.g. "中文").
        text: Raw input text; ``sle`` turns newlines into sentence delimiters.
        speaker_id: Speaker display name used as a key into ``i_dict``
            (a name, not a numeric id, despite the parameter name).
        n_scale: Forwarded as ``noise_scale`` to ``net_g_ms.infer``.
        n_scale_w: Forwarded as ``noise_scale_w``.
        l_scale: Forwarded as ``length_scale``.

    Returns:
        ``(sampling_rate, audio)``, where ``sampling_rate`` comes from the
        config and ``audio`` is ``net_g_ms.infer(...)[0][0, 0]`` converted to
        a float32 NumPy array. This is the tuple form fed to ``gr.Audio``.

    Raises:
        KeyError: If *speaker_id* is not a key of ``i_dict``.
    """
    # The network was built once from this config, so re-parsing the JSON on
    # every request is wasted disk I/O. Cache it instead.
    hps_ms = _load_hparams(config_json)
    stn_tst = get_text(sle(language, text), hps_ms)
    speaker_id = int(i_dict[speaker_id])
    # Put inputs on whatever device load_model() moved the network to,
    # instead of re-deriving the device policy here.
    dev = next(net_g_ms.parameters()).device
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(dev)
        # perf_counter is monotonic and high-resolution; time.time() is not
        # meant for measuring durations.
        t1 = time.perf_counter()
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
        sid = torch.LongTensor([speaker_id]).to(dev)
        audio = net_g_ms.infer(
            x_tst,
            x_tst_lengths,
            sid=sid,
            noise_scale=n_scale,
            noise_scale_w=n_scale_w,
            length_scale=l_scale,
        )[0][0, 0].data.cpu().float().numpy()
        t2 = time.perf_counter()
    spending_time = f"推理时间:{t2 - t1}s"
    print(spending_time)
    return (hps_ms.data.sampling_rate, audio)
# Speaker display name -> integer speaker index that infer() passes to the
# model as ``sid``.
i_dict = {
    "ことり(JAP)": 1,
    "うみ(JAP)": 0,
    "えり(JAP)": 6,
    "小文(CHN)": 9,
    "小菊(CHN)": 10,
    "小标(CHN)": 11,
    "Helena(HRV)": 14,
    "Erika(DEU)": 19,
    "Diana(ENG)": 26,
    "Michelle(ENG)": 30,
}
# Speaker dropdown choices. They are derived from i_dict so the list can
# never drift from the id mapping; dicts keep insertion order, so the order
# matches the literal above.
idols = list(i_dict)
# Gradio UI: one tab holding the text box, language picker, sampling sliders,
# generate button, speaker picker and audio output. Gradio lays components
# out in creation order, so the statements below must stay in this order.
with gr.Blocks() as app:
    with gr.Tabs(), gr.TabItem("幻音文字转语音"):
        # Label: "Supports English, Japanese, German, Chinese, Croatian".
        tts_input1 = gr.TextArea(label="支持英语、日语、德语、中文、克罗地亚语", value="大家好")
        # Label: "Select language".
        language = gr.Dropdown(label="选择语言", choices=lan, value="中文", interactive=True)
        # Sliders feed infer()'s n_scale, n_scale_w and l_scale respectively.
        para_input1 = gr.Slider(minimum=0.01, maximum=1.0, label="更改噪声比例", value=0.667)
        para_input2 = gr.Slider(minimum=0.01, maximum=1.0, label="更改噪声偏差", value=0.8)
        para_input3 = gr.Slider(minimum=0.1, maximum=10, label="更改时间比例", value=1)
        tts_submit = gr.Button("Generate", variant="primary")
        # Label: "Select speaker".
        speaker1 = gr.Dropdown(label="选择说话人", choices=idols, value="小文(CHN)", interactive=True)
        tts_output2 = gr.Audio(label="Output")
        tts_submit.click(
            fn=infer,
            inputs=[language, tts_input1, speaker1, para_input1, para_input2, para_input3],
            outputs=[tts_output2],
        )
app.launch()