RUSH-miaomi committed on
Commit 38cc289 · 1 Parent(s): da03080

Upload webui.py

Files changed (1)
  1. webui.py +255 -0
webui.py ADDED
# flake8: noqa: E402
import re
import sys, os
import logging
import re_matching

logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

logging.basicConfig(
    level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
)

logger = logging.getLogger(__name__)

import torch
import argparse
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text
import gradio as gr
import webbrowser
import numpy as np

net_g = None

if sys.platform == "darwin" and torch.backends.mps.is_available():
    device = "mps"
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
else:
    # Fall back to CPU when no CUDA device is present; __main__ re-selects
    # the device anyway before the model is built.
    device = "cuda" if torch.cuda.is_available() else "cpu"

def get_text(text, language_str, hps):
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        # Intersperse a blank token between phonemes; word2ph must be scaled
        # to stay aligned with the lengthened sequence.
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1
    bert = get_bert(norm_text, word2ph, language_str, device)
    del word2ph
    assert bert.shape[-1] == len(phone), phone

    if language_str == "ZH":
        ja_bert = torch.zeros(768, len(phone))
    elif language_str == "JP":
        ja_bert = bert
        bert = torch.zeros(1024, len(phone))
    else:
        bert = torch.zeros(1024, len(phone))
        ja_bert = torch.zeros(768, len(phone))

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, phone, tone, language
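
# For context: a sketch of the interspersing assumed above. This mirrors the
# upstream VITS commons.intersperse helper (an assumption, not verified
# against this repo's commons.py): the pad item goes before, between, and
# after every element, so a length-n sequence becomes length 2*n + 1 --
# which is why word2ph is doubled and word2ph[0] gets the extra leading blank.
#
#     def intersperse(lst, item):
#         result = [item] * (len(lst) * 2 + 1)
#         result[1::2] = lst
#         return result  # intersperse([1, 2, 3], 0) -> [0, 1, 0, 2, 0, 3, 0]
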
def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
    global net_g
    bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps)
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
        torch.cuda.empty_cache()
        return audio

def generate_audio(slices, sdp_ratio, noise_scale, noise_scale_w, length_scale, speaker, language):
    audio_list = []
    silence = np.zeros(hps.data.sampling_rate // 2)
    with torch.no_grad():
        for piece in slices:
            audio = infer(
                piece,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
            )
            audio_list.append(audio)
            audio_list.append(silence)  # append half a second of silence after each piece
    return audio_list

def tts_fn(text: str, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language):
    audio_list = []
    if language == "mix":
        bool_valid, str_valid = re_matching.validate_text(text)
        if not bool_valid:
            # On invalid input, return the error message plus half a second of silence.
            return str_valid, (hps.data.sampling_rate, np.zeros(hps.data.sampling_rate // 2))
        result = re_matching.text_matching(text)
        for one in result:
            _speaker = one.pop()
            for lang, content in one:
                audio_list.extend(
                    generate_audio(
                        content.split("|"),
                        sdp_ratio,
                        noise_scale,
                        noise_scale_w,
                        length_scale,
                        _speaker + "_" + lang.lower(),
                        lang,
                    )
                )
    else:
        audio_list.extend(
            generate_audio(text.split("|"), sdp_ratio, noise_scale, noise_scale_w, length_scale, speaker, language)
        )

    audio_concat = np.concatenate(audio_list)
    return "Success", (hps.data.sampling_rate, audio_concat)
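
# Note on the re_matching contract assumed by tts_fn (inferred from the loop
# above, not from re_matching itself): text_matching appears to yield, per
# speaker block, the (lang, content) pairs with the speaker name as the final
# element, e.g.
#
#     [("ZH", "你好,世界"), ("JP", "こんにちは"), "gongzi"]
#
# one.pop() strips the speaker, and each pair is synthesized under the hybrid
# speaker id "<name>_<lang.lower()>", which must exist in hps.data.spk2id.
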
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m", "--model", default="./logs/as/G_8000.pth", help="path of your model"
    )
    parser.add_argument(
        "-c",
        "--config",
        default="./configs/config.json",
        help="path of your config file",
    )
    parser.add_argument(
        "--share", default=False, help="make link public", action="store_true"
    )
    parser.add_argument(
        "-d", "--debug", action="store_true", help="enable DEBUG-level log"
    )

    args = parser.parse_args()
    if args.debug:
        logger.info("Enable DEBUG-level log")
        # logging.basicConfig() is a no-op once logging is already configured,
        # so raise the root logger's level directly instead.
        logging.getLogger().setLevel(logging.DEBUG)
    hps = utils.get_hparams_from_file(args.config)

    device = (
        "cuda:0"
        if torch.cuda.is_available()
        else (
            "mps"
            if sys.platform == "darwin" and torch.backends.mps.is_available()
            else "cpu"
        )
    )
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).to(device)
    _ = net_g.eval()

    _ = utils.load_checkpoint(args.model, net_g, None, skip_optimizer=True)

    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    languages = ["ZH", "JP", "mix"]
    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                text = gr.TextArea(
                    label="Input text",
                    placeholder="""
If the language is set to 'mix', the input must follow this format, otherwise an error is raised.
Format example (zh = Chinese, jp = Japanese, case-insensitive; example speaker: gongzi):
[Speaker1]<zh>你好,こんにちは! <jp>こんにちは,世界。
[Speaker2]<zh>你好吗?<jp>元気ですか?
[Speaker3]<zh>谢谢。<jp>どういたしまして。
...
In every language mode, '|' can be used to split long passages into sentences generated one by one.
""",
                )
                speaker = gr.Dropdown(
                    choices=speakers, value=speakers[0], label="Speaker"
                )
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.2, step=0.1, label="SDP/DP mix ratio"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.2, step=0.1, label="Emotion"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.9, step=0.1, label="Phoneme length"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.8, step=0.1, label="Speech rate"
                )
                language = gr.Dropdown(
                    choices=languages, value=languages[0], label="Language (with the new 'mix' option)"
                )
                btn = gr.Button("Generate audio!", variant="primary")
            with gr.Column():
                text_output = gr.Textbox(label="Status")
                audio_output = gr.Audio(label="Output audio")
                explain_image = gr.Image(
                    label="Parameter explanation",
                    show_label=True,
                    show_share_button=False,
                    show_download_button=False,
                    value=os.path.abspath("./img/参数说明.png"),
                )
        btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
            ],
            outputs=[text_output, audio_output],
        )

    webbrowser.open("http://127.0.0.1:7860")
    print("Inference page launched!")
    # app.launch() blocks until the server exits, so nothing placed after it
    # would run during normal operation; the status print therefore comes first.
    app.launch(share=args.share, server_port=7860)