Spaces:
Runtime error
Runtime error
Commit
·
3530d5c
1
Parent(s):
d7af1a0
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
|
3 |
+
if sys.platform == "darwin":
|
4 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
5 |
+
|
6 |
+
import logging
|
7 |
+
|
8 |
+
logging.getLogger("numba").setLevel(logging.WARNING)
|
9 |
+
logging.getLogger("markdown_it").setLevel(logging.WARNING)
|
10 |
+
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
11 |
+
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
12 |
+
|
13 |
+
logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import argparse
|
19 |
+
import commons
|
20 |
+
import utils
|
21 |
+
from models import SynthesizerTrn
|
22 |
+
from text.symbols import symbols
|
23 |
+
from text import cleaned_text_to_sequence, get_bert
|
24 |
+
from text.cleaner import clean_text
|
25 |
+
import gradio as gr
|
26 |
+
import webbrowser
|
27 |
+
|
28 |
+
|
29 |
+
net_g = None
|
30 |
+
|
31 |
+
|
32 |
+
def get_text(text, language_str, hps):
|
33 |
+
norm_text, phone, tone, word2ph = clean_text(text, language_str)
|
34 |
+
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
35 |
+
|
36 |
+
if hps.data.add_blank:
|
37 |
+
phone = commons.intersperse(phone, 0)
|
38 |
+
tone = commons.intersperse(tone, 0)
|
39 |
+
language = commons.intersperse(language, 0)
|
40 |
+
for i in range(len(word2ph)):
|
41 |
+
word2ph[i] = word2ph[i] * 2
|
42 |
+
word2ph[0] += 1
|
43 |
+
bert = get_bert(norm_text, word2ph, language_str)
|
44 |
+
del word2ph
|
45 |
+
|
46 |
+
assert bert.shape[-1] == len(phone)
|
47 |
+
|
48 |
+
phone = torch.LongTensor(phone)
|
49 |
+
tone = torch.LongTensor(tone)
|
50 |
+
language = torch.LongTensor(language)
|
51 |
+
|
52 |
+
return bert, phone, tone, language
|
53 |
+
|
54 |
+
def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
|
55 |
+
global net_g
|
56 |
+
bert, phones, tones, lang_ids = get_text(text, "ZH", hps)
|
57 |
+
with torch.no_grad():
|
58 |
+
x_tst=phones.to(device).unsqueeze(0)
|
59 |
+
tones=tones.to(device).unsqueeze(0)
|
60 |
+
lang_ids=lang_ids.to(device).unsqueeze(0)
|
61 |
+
bert = bert.to(device).unsqueeze(0)
|
62 |
+
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
63 |
+
del phones
|
64 |
+
speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
|
65 |
+
audio = net_g.infer(x_tst, x_tst_lengths, speakers, tones, lang_ids, bert, sdp_ratio=sdp_ratio
|
66 |
+
, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.cpu().float().numpy()
|
67 |
+
del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
|
68 |
+
return audio
|
69 |
+
|
70 |
+
def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
|
71 |
+
with torch.no_grad():
|
72 |
+
audio = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker)
|
73 |
+
return "Success", (hps.data.sampling_rate, audio)
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
parser = argparse.ArgumentParser()
|
78 |
+
parser.add_argument("--model_dir", default="./logs/maolei/G_4800.pth", help="path of your model")
|
79 |
+
parser.add_argument("--config_dir", default="./configs/config.json", help="path of your config file")
|
80 |
+
parser.add_argument("--share", default=False, help="make link public")
|
81 |
+
parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
|
82 |
+
|
83 |
+
args = parser.parse_args()
|
84 |
+
if args.debug:
|
85 |
+
logger.info("Enable DEBUG-LEVEL log")
|
86 |
+
logging.basicConfig(level=logging.DEBUG)
|
87 |
+
hps = utils.get_hparams_from_file(args.config_dir)
|
88 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
89 |
+
'''
|
90 |
+
device = (
|
91 |
+
"cuda:0"
|
92 |
+
if torch.cuda.is_available()
|
93 |
+
else (
|
94 |
+
"mps"
|
95 |
+
if sys.platform == "darwin" and torch.backends.mps.is_available()
|
96 |
+
else "cpu"
|
97 |
+
)
|
98 |
+
)
|
99 |
+
'''
|
100 |
+
net_g = SynthesizerTrn(
|
101 |
+
len(symbols),
|
102 |
+
hps.data.filter_length // 2 + 1,
|
103 |
+
hps.train.segment_size // hps.data.hop_length,
|
104 |
+
n_speakers=hps.data.n_speakers,
|
105 |
+
**hps.model).to(device)
|
106 |
+
_ = net_g.eval()
|
107 |
+
|
108 |
+
_ = utils.load_checkpoint(args.model_dir, net_g, None, skip_optimizer=True)
|
109 |
+
|
110 |
+
speaker_ids = hps.data.spk2id
|
111 |
+
speakers = list(speaker_ids.keys())
|
112 |
+
with gr.Blocks() as app:
|
113 |
+
with gr.Row():
|
114 |
+
with gr.Column():
|
115 |
+
|
116 |
+
text = gr.TextArea(label="Text", placeholder="Input Text Here",
|
117 |
+
value="猫雷最强!")
|
118 |
+
speaker = gr.Dropdown(choices=speakers, value=speakers[0], label='Speaker')
|
119 |
+
sdp_ratio = gr.Slider(minimum=0.1, maximum=1, value=0.2, step=0.01, label='SDP/DP混合比')
|
120 |
+
noise_scale = gr.Slider(minimum=0.1, maximum=1, value=0.5, step=0.01, label='感情调节')
|
121 |
+
noise_scale_w = gr.Slider(minimum=0.1, maximum=1, value=0.9, step=0.01, label='音素长度')
|
122 |
+
length_scale = gr.Slider(minimum=0.1, maximum=2, value=1, step=0.01, label='生成长度')
|
123 |
+
language = gr.Dropdown(choices=languages, value=languages[0], label="选择语言(该模型mix有问题先别选)" )
|
124 |
+
btn = gr.Button("点击生成", variant="primary")
|
125 |
+
with gr.Column():
|
126 |
+
text_output = gr.Textbox(label="Message")
|
127 |
+
audio_output = gr.Audio(label="Output Audio")
|
128 |
+
|
129 |
+
btn.click(
|
130 |
+
tts_fn,
|
131 |
+
inputs=[
|
132 |
+
text,
|
133 |
+
speaker,
|
134 |
+
sdp_ratio,
|
135 |
+
noise_scale,
|
136 |
+
noise_scale_w,
|
137 |
+
length_scale,
|
138 |
+
language,
|
139 |
+
],
|
140 |
+
outputs=[text_output, audio_output],
|
141 |
+
)
|
142 |
+
|
143 |
+
# webbrowser.open("http://127.0.0.1:6006")
|
144 |
+
# app.launch(server_port=6006, show_error=True)
|
145 |
+
|
146 |
+
app.launch(show_error=True)
|