kevinwang676 commited on
Commit
f297234
·
verified ·
1 Parent(s): 4fe39bc

Upload GPT_SoVITS_inference_webui.py

Browse files
Files changed (1) hide show
  1. GPT_SoVITS_inference_webui.py +690 -0
GPT_SoVITS_inference_webui.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ 按中英混合识别
3
+ 按日英混合识别
4
+ 多语种启动切分识别语种
5
+ 全部按中文识别
6
+ 全部按英文识别
7
+ 全部按日文识别
8
+ '''
9
+
10
+ # OpenVoice
11
+
12
+ import os
13
+ import torch
14
+ from openvoice import se_extractor
15
+ from openvoice.api import BaseSpeakerTTS, ToneColorConverter
16
+
17
+ if torch.cuda.is_available():
18
+ device = "cuda"
19
+ else:
20
+ device = "cpu"
21
+
22
+ ckpt_base = 'checkpoints/base_speakers/EN'
23
+ ckpt_converter = 'checkpoints/converter'
24
+ base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)
25
+ base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')
26
+
27
+ tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
28
+ tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
29
+
30
+ #source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)
31
+ #source_se_style = torch.load(f'{ckpt_base}/en_style_se.pth').to(device)
32
+
33
+ def vc_en(audio_ref, style_mode):
34
+ text = "We have always tried to be at the intersection of technology and liberal arts, to be able to get the best of both, to make extremely advanced products from a technology point of view."
35
+ if style_mode=="default":
36
+ source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)
37
+ reference_speaker = audio_ref
38
+ target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
39
+ save_path = "output.wav"
40
+
41
+ # Run the base speaker tts
42
+ src_path = "tmp.wav"
43
+ base_speaker_tts.tts(text, src_path, speaker='default', language='English', speed=1.0)
44
+
45
+ # Run the tone color converter
46
+ encode_message = "@MyShell"
47
+ tone_color_converter.convert(
48
+ audio_src_path=src_path,
49
+ src_se=source_se,
50
+ tgt_se=target_se,
51
+ output_path=save_path,
52
+ message=encode_message)
53
+
54
+ else:
55
+ source_se = torch.load(f'{ckpt_base}/en_style_se.pth').to(device)
56
+ reference_speaker = audio_ref
57
+ target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
58
+
59
+ save_path = "output.wav"
60
+
61
+ # Run the base speaker tts
62
+ src_path = "tmp.wav"
63
+ base_speaker_tts.tts(text, src_path, speaker=style_mode, language='English', speed=1.0)
64
+
65
+ # Run the tone color converter
66
+ encode_message = "@MyShell"
67
+ tone_color_converter.convert(
68
+ audio_src_path=src_path,
69
+ src_se=source_se,
70
+ tgt_se=target_se,
71
+ output_path=save_path,
72
+ message=encode_message)
73
+
74
+ return "output.wav"
75
+
76
+ # End
77
+
78
+ import re, logging
79
+ import LangSegment
80
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
81
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
82
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
83
+ logging.getLogger("httpx").setLevel(logging.ERROR)
84
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
85
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
86
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
87
+ import pdb
88
+
89
+ if os.path.exists("./gweight.txt"):
90
+ with open("./gweight.txt", 'r', encoding="utf-8") as file:
91
+ gweight_data = file.read()
92
+ gpt_path = os.environ.get(
93
+ "gpt_path", gweight_data)
94
+ else:
95
+ gpt_path = os.environ.get(
96
+ "gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
97
+
98
+ if os.path.exists("./sweight.txt"):
99
+ with open("./sweight.txt", 'r', encoding="utf-8") as file:
100
+ sweight_data = file.read()
101
+ sovits_path = os.environ.get("sovits_path", sweight_data)
102
+ else:
103
+ sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth")
104
+ # gpt_path = os.environ.get(
105
+ # "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
106
+ # )
107
+ # sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
108
+ cnhubert_base_path = os.environ.get(
109
+ "cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
110
+ )
111
+ bert_path = os.environ.get(
112
+ "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
113
+ )
114
+ infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
115
+ infer_ttswebui = int(infer_ttswebui)
116
+ is_share = os.environ.get("is_share", "False")
117
+ is_share = eval(is_share)
118
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
119
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
120
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
121
+ import gradio as gr
122
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
123
+ import numpy as np
124
+ import librosa
125
+ from feature_extractor import cnhubert
126
+
127
+ cnhubert.cnhubert_base_path = cnhubert_base_path
128
+
129
+ from module.models import SynthesizerTrn
130
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
131
+ from text import cleaned_text_to_sequence
132
+ from text.cleaner import clean_text
133
+ from time import time as ttime
134
+ from module.mel_processing import spectrogram_torch
135
+ from my_utils import load_audio
136
+ from tools.i18n.i18n import I18nAuto
137
+
138
+ i18n = I18nAuto()
139
+
140
+ # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
141
+
142
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
143
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
144
+ if is_half == True:
145
+ bert_model = bert_model.half().to(device)
146
+ else:
147
+ bert_model = bert_model.to(device)
148
+
149
+
150
+ def get_bert_feature(text, word2ph):
151
+ with torch.no_grad():
152
+ inputs = tokenizer(text, return_tensors="pt")
153
+ for i in inputs:
154
+ inputs[i] = inputs[i].to(device)
155
+ res = bert_model(**inputs, output_hidden_states=True)
156
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
157
+ assert len(word2ph) == len(text)
158
+ phone_level_feature = []
159
+ for i in range(len(word2ph)):
160
+ repeat_feature = res[i].repeat(word2ph[i], 1)
161
+ phone_level_feature.append(repeat_feature)
162
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
163
+ return phone_level_feature.T
164
+
165
+
166
+ class DictToAttrRecursive(dict):
167
+ def __init__(self, input_dict):
168
+ super().__init__(input_dict)
169
+ for key, value in input_dict.items():
170
+ if isinstance(value, dict):
171
+ value = DictToAttrRecursive(value)
172
+ self[key] = value
173
+ setattr(self, key, value)
174
+
175
+ def __getattr__(self, item):
176
+ try:
177
+ return self[item]
178
+ except KeyError:
179
+ raise AttributeError(f"Attribute {item} not found")
180
+
181
+ def __setattr__(self, key, value):
182
+ if isinstance(value, dict):
183
+ value = DictToAttrRecursive(value)
184
+ super(DictToAttrRecursive, self).__setitem__(key, value)
185
+ super().__setattr__(key, value)
186
+
187
+ def __delattr__(self, item):
188
+ try:
189
+ del self[item]
190
+ except KeyError:
191
+ raise AttributeError(f"Attribute {item} not found")
192
+
193
+
194
+ ssl_model = cnhubert.get_model()
195
+ if is_half == True:
196
+ ssl_model = ssl_model.half().to(device)
197
+ else:
198
+ ssl_model = ssl_model.to(device)
199
+
200
+
201
+ def change_sovits_weights(sovits_path):
202
+ global vq_model, hps
203
+ dict_s2 = torch.load(sovits_path, map_location="cpu")
204
+ hps = dict_s2["config"]
205
+ hps = DictToAttrRecursive(hps)
206
+ hps.model.semantic_frame_rate = "25hz"
207
+ vq_model = SynthesizerTrn(
208
+ hps.data.filter_length // 2 + 1,
209
+ hps.train.segment_size // hps.data.hop_length,
210
+ n_speakers=hps.data.n_speakers,
211
+ **hps.model
212
+ )
213
+ if ("pretrained" not in sovits_path):
214
+ del vq_model.enc_q
215
+ if is_half == True:
216
+ vq_model = vq_model.half().to(device)
217
+ else:
218
+ vq_model = vq_model.to(device)
219
+ vq_model.eval()
220
+ print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
221
+ with open("./sweight.txt", "w", encoding="utf-8") as f:
222
+ f.write(sovits_path)
223
+
224
+
225
+ change_sovits_weights(sovits_path)
226
+
227
+
228
+ def change_gpt_weights(gpt_path):
229
+ global hz, max_sec, t2s_model, config
230
+ hz = 50
231
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
232
+ config = dict_s1["config"]
233
+ max_sec = config["data"]["max_sec"]
234
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
235
+ t2s_model.load_state_dict(dict_s1["weight"])
236
+ if is_half == True:
237
+ t2s_model = t2s_model.half()
238
+ t2s_model = t2s_model.to(device)
239
+ t2s_model.eval()
240
+ total = sum([param.nelement() for param in t2s_model.parameters()])
241
+ print("Number of parameter: %.2fM" % (total / 1e6))
242
+ with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
243
+
244
+
245
+ change_gpt_weights(gpt_path)
246
+
247
+
248
+ def get_spepc(hps, filename):
249
+ audio = load_audio(filename, int(hps.data.sampling_rate))
250
+ audio = torch.FloatTensor(audio)
251
+ audio_norm = audio
252
+ audio_norm = audio_norm.unsqueeze(0)
253
+ spec = spectrogram_torch(
254
+ audio_norm,
255
+ hps.data.filter_length,
256
+ hps.data.sampling_rate,
257
+ hps.data.hop_length,
258
+ hps.data.win_length,
259
+ center=False,
260
+ )
261
+ return spec
262
+
263
+
264
+ dict_language = {
265
+ i18n("中文"): "all_zh",#全部按中文识别
266
+ i18n("英文"): "en",#全部按英文识别#######不变
267
+ i18n("日文"): "all_ja",#全部按日文识别
268
+ i18n("中英混合"): "zh",#按中英混合识别####不变
269
+ i18n("日英混合"): "ja",#按日英混合识别####不变
270
+ i18n("多语种混合"): "auto",#多语种启动切分识别语种
271
+ }
272
+
273
+
274
+ def clean_text_inf(text, language):
275
+ phones, word2ph, norm_text = clean_text(text, language)
276
+ phones = cleaned_text_to_sequence(phones)
277
+ return phones, word2ph, norm_text
278
+
279
+ dtype=torch.float16 if is_half == True else torch.float32
280
+ def get_bert_inf(phones, word2ph, norm_text, language):
281
+ language=language.replace("all_","")
282
+ if language == "zh":
283
+ bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
284
+ else:
285
+ bert = torch.zeros(
286
+ (1024, len(phones)),
287
+ dtype=torch.float16 if is_half == True else torch.float32,
288
+ ).to(device)
289
+
290
+ return bert
291
+
292
+
293
+ splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "��", }
294
+
295
+
296
+ def get_first(text):
297
+ pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
298
+ text = re.split(pattern, text)[0].strip()
299
+ return text
300
+
301
+
302
+ def get_phones_and_bert(text,language):
303
+ if language in {"en","all_zh","all_ja"}:
304
+ language = language.replace("all_","")
305
+ if language == "en":
306
+ LangSegment.setfilters(["en"])
307
+ formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
308
+ else:
309
+ # 因无法区别中日文汉字,以用户输入为准
310
+ formattext = text
311
+ while " " in formattext:
312
+ formattext = formattext.replace(" ", " ")
313
+ phones, word2ph, norm_text = clean_text_inf(formattext, language)
314
+ if language == "zh":
315
+ bert = get_bert_feature(norm_text, word2ph).to(device)
316
+ else:
317
+ bert = torch.zeros(
318
+ (1024, len(phones)),
319
+ dtype=torch.float16 if is_half == True else torch.float32,
320
+ ).to(device)
321
+ elif language in {"zh", "ja","auto"}:
322
+ textlist=[]
323
+ langlist=[]
324
+ LangSegment.setfilters(["zh","ja","en","ko"])
325
+ if language == "auto":
326
+ for tmp in LangSegment.getTexts(text):
327
+ if tmp["lang"] == "ko":
328
+ langlist.append("zh")
329
+ textlist.append(tmp["text"])
330
+ else:
331
+ langlist.append(tmp["lang"])
332
+ textlist.append(tmp["text"])
333
+ else:
334
+ for tmp in LangSegment.getTexts(text):
335
+ if tmp["lang"] == "en":
336
+ langlist.append(tmp["lang"])
337
+ else:
338
+ # 因无法区别中日文汉字,以用户输入为准
339
+ langlist.append(language)
340
+ textlist.append(tmp["text"])
341
+ print(textlist)
342
+ print(langlist)
343
+ phones_list = []
344
+ bert_list = []
345
+ norm_text_list = []
346
+ for i in range(len(textlist)):
347
+ lang = langlist[i]
348
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
349
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
350
+ phones_list.append(phones)
351
+ norm_text_list.append(norm_text)
352
+ bert_list.append(bert)
353
+ bert = torch.cat(bert_list, dim=1)
354
+ phones = sum(phones_list, [])
355
+ norm_text = ''.join(norm_text_list)
356
+
357
+ return phones,bert.to(dtype),norm_text
358
+
359
+
360
+ def merge_short_text_in_array(texts, threshold):
361
+ if (len(texts)) < 2:
362
+ return texts
363
+ result = []
364
+ text = ""
365
+ for ele in texts:
366
+ text += ele
367
+ if len(text) >= threshold:
368
+ result.append(text)
369
+ text = ""
370
+ if (len(text) > 0):
371
+ if len(result) == 0:
372
+ result.append(text)
373
+ else:
374
+ result[len(result) - 1] += text
375
+ return result
376
+
377
+ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
378
+ if prompt_text is None or len(prompt_text) == 0:
379
+ ref_free = True
380
+ t0 = ttime()
381
+ prompt_language = dict_language[prompt_language]
382
+ text_language = dict_language[text_language]
383
+ if not ref_free:
384
+ prompt_text = prompt_text.strip("\n")
385
+ if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
386
+ print(i18n("实际输入的参考文本:"), prompt_text)
387
+ text = text.strip("\n")
388
+ if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
389
+
390
+ print(i18n("实际输入的目标文本:"), text)
391
+ zero_wav = np.zeros(
392
+ int(hps.data.sampling_rate * 0.3),
393
+ dtype=np.float16 if is_half == True else np.float32,
394
+ )
395
+ with torch.no_grad():
396
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
397
+ if (wav16k.shape[0] > 240000 or wav16k.shape[0] < 48000):
398
+ raise OSError(i18n("参考音频在3~15秒范围外,请更换!"))
399
+ wav16k = torch.from_numpy(wav16k)
400
+ zero_wav_torch = torch.from_numpy(zero_wav)
401
+ if is_half == True:
402
+ wav16k = wav16k.half().to(device)
403
+ zero_wav_torch = zero_wav_torch.half().to(device)
404
+ else:
405
+ wav16k = wav16k.to(device)
406
+ zero_wav_torch = zero_wav_torch.to(device)
407
+ wav16k = torch.cat([wav16k, zero_wav_torch])
408
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
409
+ "last_hidden_state"
410
+ ].transpose(
411
+ 1, 2
412
+ ) # .float()
413
+ codes = vq_model.extract_latent(ssl_content)
414
+
415
+ prompt_semantic = codes[0, 0]
416
+ t1 = ttime()
417
+
418
+ if (how_to_cut == i18n("凑四句一切")):
419
+ text = cut1(text)
420
+ elif (how_to_cut == i18n("凑50字一切")):
421
+ text = cut2(text)
422
+ elif (how_to_cut == i18n("按中文句号。切")):
423
+ text = cut3(text)
424
+ elif (how_to_cut == i18n("按英文句号.切")):
425
+ text = cut4(text)
426
+ elif (how_to_cut == i18n("按标点符号切")):
427
+ text = cut5(text)
428
+ while "\n\n" in text:
429
+ text = text.replace("\n\n", "\n")
430
+ print(i18n("实际输入的目标文本(切句后):"), text)
431
+ texts = text.split("\n")
432
+ texts = merge_short_text_in_array(texts, 5)
433
+ audio_opt = []
434
+ if not ref_free:
435
+ phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
436
+
437
+ for text in texts:
438
+ # 解决输入目标文本的空行导致报错的问题
439
+ if (len(text.strip()) == 0):
440
+ continue
441
+ if (text[-1] not in splits): text += "。" if text_language != "en" else "."
442
+ print(i18n("实际输入的目标文本(每句):"), text)
443
+ phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
444
+ print(i18n("前端处理后的文本(每句):"), norm_text2)
445
+ if not ref_free:
446
+ bert = torch.cat([bert1, bert2], 1)
447
+ all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
448
+ else:
449
+ bert = bert2
450
+ all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
451
+
452
+ bert = bert.to(device).unsqueeze(0)
453
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
454
+ prompt = prompt_semantic.unsqueeze(0).to(device)
455
+ t2 = ttime()
456
+ with torch.no_grad():
457
+ # pred_semantic = t2s_model.model.infer(
458
+ pred_semantic, idx = t2s_model.model.infer_panel(
459
+ all_phoneme_ids,
460
+ all_phoneme_len,
461
+ None if ref_free else prompt,
462
+ bert,
463
+ # prompt_phone_len=ph_offset,
464
+ top_k=top_k,
465
+ top_p=top_p,
466
+ temperature=temperature,
467
+ early_stop_num=hz * max_sec,
468
+ )
469
+ t3 = ttime()
470
+ # print(pred_semantic.shape,idx)
471
+ pred_semantic = pred_semantic[:, -idx:].unsqueeze(
472
+ 0
473
+ ) # .unsqueeze(0)#mq要多unsqueeze一次
474
+ refer = get_spepc(hps, ref_wav_path) # .to(device)
475
+ if is_half == True:
476
+ refer = refer.half().to(device)
477
+ else:
478
+ refer = refer.to(device)
479
+ # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
480
+ audio = (
481
+ vq_model.decode(
482
+ pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
483
+ )
484
+ .detach()
485
+ .cpu()
486
+ .numpy()[0, 0]
487
+ ) ###试试重建不带上prompt部分
488
+ max_audio=np.abs(audio).max()#简单防止16bit爆音
489
+ if max_audio>1:audio/=max_audio
490
+ audio_opt.append(audio)
491
+ audio_opt.append(zero_wav)
492
+ t4 = ttime()
493
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
494
+ yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
495
+ np.int16
496
+ )
497
+
498
+
499
+ def split(todo_text):
500
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
501
+ if todo_text[-1] not in splits:
502
+ todo_text += "。"
503
+ i_split_head = i_split_tail = 0
504
+ len_text = len(todo_text)
505
+ todo_texts = []
506
+ while 1:
507
+ if i_split_head >= len_text:
508
+ break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
509
+ if todo_text[i_split_head] in splits:
510
+ i_split_head += 1
511
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
512
+ i_split_tail = i_split_head
513
+ else:
514
+ i_split_head += 1
515
+ return todo_texts
516
+
517
+
518
+ def cut1(inp):
519
+ inp = inp.strip("\n")
520
+ inps = split(inp)
521
+ split_idx = list(range(0, len(inps), 4))
522
+ split_idx[-1] = None
523
+ if len(split_idx) > 1:
524
+ opts = []
525
+ for idx in range(len(split_idx) - 1):
526
+ opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
527
+ else:
528
+ opts = [inp]
529
+ return "\n".join(opts)
530
+
531
+
532
+ def cut2(inp):
533
+ inp = inp.strip("\n")
534
+ inps = split(inp)
535
+ if len(inps) < 2:
536
+ return inp
537
+ opts = []
538
+ summ = 0
539
+ tmp_str = ""
540
+ for i in range(len(inps)):
541
+ summ += len(inps[i])
542
+ tmp_str += inps[i]
543
+ if summ > 50:
544
+ summ = 0
545
+ opts.append(tmp_str)
546
+ tmp_str = ""
547
+ if tmp_str != "":
548
+ opts.append(tmp_str)
549
+ # print(opts)
550
+ if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
551
+ opts[-2] = opts[-2] + opts[-1]
552
+ opts = opts[:-1]
553
+ return "\n".join(opts)
554
+
555
+
556
+ def cut3(inp):
557
+ inp = inp.strip("\n")
558
+ return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
559
+
560
+
561
+ def cut4(inp):
562
+ inp = inp.strip("\n")
563
+ return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
564
+
565
+
566
+ # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
567
+ def cut5(inp):
568
+ # if not re.search(r'[^\w\s]', inp[-1]):
569
+ # inp += '。'
570
+ inp = inp.strip("\n")
571
+ punds = r'[,.;?!、,。?!;:…]'
572
+ items = re.split(f'({punds})', inp)
573
+ mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
574
+ # 在句子不存在符号或句尾无符号的时候保证文本完整
575
+ if len(items)%2 == 1:
576
+ mergeitems.append(items[-1])
577
+ opt = "\n".join(mergeitems)
578
+ return opt
579
+
580
+
581
+ def custom_sort_key(s):
582
+ # 使用正则表达式提取字符串中的数字部分和非数字部分
583
+ parts = re.split('(\d+)', s)
584
+ # 将数字部分转换为整数,非数字部分保持不变
585
+ parts = [int(part) if part.isdigit() else part for part in parts]
586
+ return parts
587
+
588
+
589
+ def change_choices():
590
+ SoVITS_names, GPT_names = get_weights_names()
591
+ return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
592
+
593
+
594
+ pretrained_sovits_name = "GPT_SoVITS/pretrained_models/s2G488k.pth"
595
+ pretrained_gpt_name = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
596
+ SoVITS_weight_root = "SoVITS_weights"
597
+ GPT_weight_root = "GPT_weights"
598
+ os.makedirs(SoVITS_weight_root, exist_ok=True)
599
+ os.makedirs(GPT_weight_root, exist_ok=True)
600
+
601
+
602
+ def get_weights_names():
603
+ SoVITS_names = [pretrained_sovits_name]
604
+ for name in os.listdir(SoVITS_weight_root):
605
+ if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (SoVITS_weight_root, name))
606
+ GPT_names = [pretrained_gpt_name]
607
+ for name in os.listdir(GPT_weight_root):
608
+ if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (GPT_weight_root, name))
609
+ return SoVITS_names, GPT_names
610
+
611
+
612
+ SoVITS_names, GPT_names = get_weights_names()
613
+
614
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
615
+ gr.Markdown(
616
+ value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
617
+ )
618
+ with gr.Group():
619
+ gr.Markdown(value=i18n("模型切换"))
620
+ with gr.Row():
621
+ GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
622
+ SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
623
+ refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
624
+ refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
625
+ SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
626
+ GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
627
+ gr.Markdown(value=i18n("*请上传并填写参考信息"))
628
+ with gr.Row():
629
+ inp_training_audio = gr.Audio(label="请上传您完整的1分钟训练音频", type="filepath")
630
+ style_control = gr.Dropdown(label="请选择一种语音情感", info="🙂default😊friendly🤫whispering😄cheerful😱terrified😡angry😢sad", choices=["default", "friendly", "whispering", "cheerful", "terrified", "angry", "sad"], value="default")
631
+ btn_style = gr.Button("一键生成情感参考音频吧💕", variant="primary")
632
+ out_ref_audio = gr.Audio(label="为您生成的情感参考音频", type="filepath", interactive=False)
633
+ inp_ref = out_ref_audio
634
+ with gr.Column():
635
+ ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=False, show_label=True)
636
+ gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
637
+ prompt_text = gr.Textbox(label=i18n("参考音频的文本"), interactive=False, value="We have always tried to be at the intersection of technology and liberal arts, to be able to get the best of both, to make extremely advanced products from a technology point of view.")
638
+ prompt_language = gr.Dropdown(
639
+ label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("英文"), interactive=False
640
+ )
641
+ gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
642
+ with gr.Row():
643
+ text = gr.Textbox(label=i18n("需要合成的文本"), value="")
644
+ text_language = gr.Dropdown(
645
+ label=i18n("需要合成的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
646
+ )
647
+ how_to_cut = gr.Radio(
648
+ label=i18n("怎么切"),
649
+ choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
650
+ value=i18n("凑四句一切"),
651
+ interactive=True,
652
+ )
653
+ with gr.Row():
654
+ gr.Markdown(value=i18n("gpt采样参数(无参考文本时不要太低):"))
655
+ top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
656
+ top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
657
+ temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
658
+ inference_button = gr.Button(i18n("合成语音"), variant="primary")
659
+ output = gr.Audio(label=i18n("输出的语音"))
660
+
661
+ inference_button.click(
662
+ get_tts_wav,
663
+ [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
664
+ [output],
665
+ )
666
+
667
+ gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
668
+ with gr.Row():
669
+ text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
670
+ button1 = gr.Button(i18n("凑四句一切"), variant="primary")
671
+ button2 = gr.Button(i18n("凑50字一切"), variant="primary")
672
+ button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
673
+ button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
674
+ button5 = gr.Button(i18n("按标点符号切"), variant="primary")
675
+ text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
676
+ button1.click(cut1, [text_inp], [text_opt])
677
+ button2.click(cut2, [text_inp], [text_opt])
678
+ button3.click(cut3, [text_inp], [text_opt])
679
+ button4.click(cut4, [text_inp], [text_opt])
680
+ button5.click(cut5, [text_inp], [text_opt])
681
+ btn_style.click(vc_en, [inp_training_audio, style_control], [out_ref_audio])
682
+ gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
683
+
684
+ app.queue(concurrency_count=511, max_size=1022).launch(
685
+ server_name="0.0.0.0",
686
+ inbrowser=True,
687
+ share=True,
688
+ server_port=infer_ttswebui,
689
+ quiet=True,
690
+ )