AkitoP committed
Commit 75f3959 · verified · 1 Parent(s): b7c2a16

Update app.py

Files changed (1)
  1. app.py (+390, -389)
app.py CHANGED
@@ -1,390 +1,391 @@
- import os
- import sys
- import spaces
- cnhubert_base_path = "TencentGameMate/chinese-hubert-base"
- bert_path = "hfl/chinese-roberta-wwm-ext-large"
- os.environ["version"] = 'v2'
- now_dir = os.path.dirname(os.path.abspath(__file__))
- sys.path.insert(0, now_dir)
- sys.path.insert(0, os.path.join(now_dir, "GPT_SoVITS"))
- sys.path.insert(0, os.path.join(now_dir, "GPT_SoVITS",'text'))
- import site
- site_packages_roots = []
- for site_packages_root in site_packages_roots:
- if os.path.exists(site_packages_root):
- try:
- with open("%s/users.pth" % (site_packages_root), "w") as f:
- f.write(
- "%s\n%s/tools\n%s/tools/damo_asr\n%s/GPT_SoVITS\n%s/tools/uvr5"
- % (now_dir, now_dir, now_dir, now_dir, now_dir)
- )
- break
- except PermissionError:
- pass
- import re
- import gradio as gr
- from transformers import AutoModelForMaskedLM, AutoTokenizer
- import numpy as np
- import os,librosa,torch, audiosegment
- from GPT_SoVITS.feature_extractor import cnhubert
- cnhubert.cnhubert_base_path=cnhubert_base_path
- from GPT_SoVITS.module.models import SynthesizerTrn
- from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
- from GPT_SoVITS.text import cleaned_text_to_sequence
- from GPT_SoVITS.text.cleaner import clean_text
- from time import time as ttime
- from GPT_SoVITS.module.mel_processing import spectrogram_torch
- import tempfile
- from tools.my_utils import load_audio
- import os
- import json
- # import pyopenjtalk
- # cwd = os.getcwd()
- # if os.path.exists(os.path.join(cwd,'user.dic')):
- # pyopenjtalk.update_global_jtalk_with_user_dict(os.path.join(cwd, 'user.dic'))
-
-
- import logging
- logging.getLogger('httpx').setLevel(logging.WARNING)
- logging.getLogger('httpcore').setLevel(logging.WARNING)
- logging.getLogger('multipart').setLevel(logging.WARNING)
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- #device = "cpu"
- is_half = False
-
- tokenizer = AutoTokenizer.from_pretrained(bert_path)
- bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
- if(is_half==True):bert_model=bert_model.half().to(device)
- else:bert_model=bert_model.to(device)
- # bert_model=bert_model.to(device)
- def get_bert_feature(text, word2ph): # Bert(不是HuBERT的特征计算)
- with torch.no_grad():
- inputs = tokenizer(text, return_tensors="pt")
- for i in inputs:
- inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model
- res = bert_model(**inputs, output_hidden_states=True)
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
- assert len(word2ph) == len(text)
- phone_level_feature = []
- for i in range(len(word2ph)):
- repeat_feature = res[i].repeat(word2ph[i], 1)
- phone_level_feature.append(repeat_feature)
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
- # if(is_half==True):phone_level_feature=phone_level_feature.half()
- return phone_level_feature.T
-
- loaded_sovits_model = [] # [(path, dict, model)]
- loaded_gpt_model = []
- ssl_model = cnhubert.get_model()
- if (is_half == True):
- ssl_model = ssl_model.half().to(device)
- else:
- ssl_model = ssl_model.to(device)
-
-
- def load_model(sovits_path, gpt_path):
- global ssl_model
- global loaded_sovits_model
- global loaded_gpt_model
- vq_model = None
- t2s_model = None
- dict_s2 = None
- dict_s1 = None
- hps = None
- for path, dict_s2_, model in loaded_sovits_model:
- if path == sovits_path:
- vq_model = model
- dict_s2 = dict_s2_
- break
- for path, dict_s1_, model in loaded_gpt_model:
- if path == gpt_path:
- t2s_model = model
- dict_s1 = dict_s1_
- break
-
- if dict_s2 is None:
- dict_s2 = torch.load(sovits_path, map_location="cpu")
- hps = dict_s2["config"]
-
- if dict_s1 is None:
- dict_s1 = torch.load(gpt_path, map_location="cpu")
- config = dict_s1["config"]
- class DictToAttrRecursive:
- def __init__(self, input_dict):
- for key, value in input_dict.items():
- if isinstance(value, dict):
- # 如果值是字典,递归调用构造函数
- setattr(self, key, DictToAttrRecursive(value))
- else:
- setattr(self, key, value)
-
- hps = DictToAttrRecursive(hps)
- hps.model.semantic_frame_rate = "25hz"
-
-
- if not vq_model:
- vq_model = SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model)
- if (is_half == True):
- vq_model = vq_model.half().to(device)
- else:
- vq_model = vq_model.to(device)
- vq_model.eval()
- vq_model.load_state_dict(dict_s2["weight"], strict=False)
- loaded_sovits_model.append((sovits_path, dict_s2, vq_model))
- hz = 50
- max_sec = config['data']['max_sec']
- if not t2s_model:
- t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
- t2s_model.load_state_dict(dict_s1["weight"])
- if (is_half == True): t2s_model = t2s_model.half()
- t2s_model = t2s_model.to(device)
- t2s_model.eval()
- total = sum([param.nelement() for param in t2s_model.parameters()])
- print("Number of parameter: %.2fM" % (total / 1e6))
- loaded_gpt_model.append((gpt_path, dict_s1, t2s_model))
- return vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
-
-
- def get_spepc(hps, filename):
- audio=load_audio(filename,int(hps.data.sampling_rate))
- audio = audio / np.max(np.abs(audio))
- audio=torch.FloatTensor(audio)
- print(torch.max(torch.abs(audio)))
- audio_norm = audio
- # audio_norm = audio / torch.max(torch.abs(audio))
- audio_norm = audio_norm.unsqueeze(0)
- print(torch.max(torch.abs(audio_norm)))
- spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
- return spec
-
- def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
- @spaces.GPU()
- def tts_fn(ref_wav_path, prompt_text, prompt_language, target_phone, text_language, target_text = None):
- t0 = ttime()
- prompt_text=prompt_text.strip()
- prompt_language=prompt_language
- with torch.no_grad():
- wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
- # maxx=0.95
- # tmp_max = np.abs(wav16k).max()
- # alpha=0.5
- # wav16k = (wav16k / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * wav16k
- #在这里归一化
- #print(max(np.abs(wav16k)))
- #wav16k = wav16k / np.max(np.abs(wav16k))
- #print(max(np.abs(wav16k)))
- # 添加0.3s的静音
- wav16k = np.concatenate([wav16k, np.zeros(int(hps.data.sampling_rate * 0.3)),])
- wav16k = torch.from_numpy(wav16k)
- wav16k = wav16k.float()
- if(is_half==True):wav16k=wav16k.half().to(device)
- else:wav16k=wav16k.to(device)
- print(wav16k.shape) # 读取16k音频
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
- print(ssl_content.shape)
- codes = vq_model.extract_latent(ssl_content)
- print(codes.shape)
- prompt_semantic = codes[0, 0]
- t1 = ttime()
- phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
- phones1=cleaned_text_to_sequence(phones1)
- #texts=text.split("\n")
- audio_opt = []
- zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
- phones = get_phone_from_str_list(target_phone, text_language)
- for phones2 in phones:
- if(len(phones2) == 0):
- continue
- if(len(phones2) == 1 and phones2[0] == ""):
- continue
- #phones2, word2ph2, norm_text2 = clean_text(text, text_language)
- print(phones2)
- phones2 = cleaned_text_to_sequence(phones2)
- #if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
- bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
- #if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
- bert2 = torch.zeros((1024, len(phones2))).to(bert1)
- bert = torch.cat([bert1, bert2], 1)
-
- all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
- bert = bert.to(device).unsqueeze(0)
- all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
- prompt = prompt_semantic.unsqueeze(0).to(device)
- t2 = ttime()
- idx = 0
- cnt = 0
- while idx == 0 and cnt < 2:
- with torch.no_grad():
- # pred_semantic = t2s_model.model.infer
- pred_semantic,idx = t2s_model.model.infer_panel(
- all_phoneme_ids,
- all_phoneme_len,
- prompt,
- bert,
- # prompt_phone_len=ph_offset,
- top_k=config['inference']['top_k'],
- early_stop_num=hz * max_sec)
- t3 = ttime()
- cnt+=1
- if idx == 0:
- return "Error: Generation failure: bad zero prediction.", None
- pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
- refer = get_spepc(hps, ref_wav_path)#.to(device)
- if(is_half==True):refer=refer.half().to(device)
- else:refer=refer.to(device)
- # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
- audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
- audio_opt.append(audio)
- audio_opt.append(zero_wav)
- t4 = ttime()
- print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
-
- audio = (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16))
-
- filename = tempfile.mktemp(suffix=".wav",prefix=f"{prompt_text[:8]}_{target_text[:8]}_")
- audiosegment.from_numpy_array(audio[1], framerate=audio[0]).export(filename, format="WAV")
- return "Success", (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)), filename
- return tts_fn
-
-
- def get_str_list_from_phone(text, text_language):
- # raw文本过g2p得到音素列表,再转成字符串
- # 注意,这里的text是一个段落,可能包含多个句子
- # 段落间\n分割,音素间空格分割
- texts=text.split("\n")
- phone_list = []
- for text in texts:
- phones2, word2ph2, norm_text2 = clean_text(text, text_language)
- phone_list.append(" ".join(phones2))
- return "\n".join(phone_list)
-
- def get_phone_from_str_list(str_list:str, language:str = 'ja'):
- # 从音素字符串中得到音素列表
- # 注意,这里的text是一个段落,可能包含多个句子
- # 段落间\n分割,音素间空格分割
- sentences = str_list.split("\n")
- phones = []
- for sentence in sentences:
- phones.append(sentence.split(" "))
- return phones
-
- splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
- def split(todo_text):
- todo_text = todo_text.replace("……", "。").replace("——", ",")
- if (todo_text[-1] not in splits): todo_text += "。"
- i_split_head = i_split_tail = 0
- len_text = len(todo_text)
- todo_texts = []
- while (1):
- if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
- if (todo_text[i_split_head] in splits):
- i_split_head += 1
- todo_texts.append(todo_text[i_split_tail:i_split_head])
- i_split_tail = i_split_head
- else:
- i_split_head += 1
- return todo_texts
-
-
- def change_reference_audio(prompt_text, transcripts):
- return transcripts[prompt_text]
-
-
- models = []
- models_info = json.load(open("./models/models_info.json", "r", encoding="utf-8"))
-
-
-
- for i, info in models_info.items():
- title = info['title']
- cover = info['cover']
- gpt_weight = info['gpt_weight']
- sovits_weight = info['sovits_weight']
- example_reference = info['example_reference']
- transcripts = {}
- transcript_path = info["transcript_path"]
- path = os.path.dirname(transcript_path)
- with open(transcript_path, 'r', encoding='utf-8') as file:
- for line in file:
- line = line.strip().replace("\\", "/")
- wav,_,_, t = line.split("|")
- wav = os.path.basename(wav)
- transcripts[t] = os.path.join(os.path.join(path,"reference_audio"), wav)
-
- vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
-
-
- models.append(
- (
- i,
- title,
- cover,
- transcripts,
- example_reference,
- create_tts_fn(
- vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
- )
- )
- )
- with gr.Blocks() as app:
- gr.Markdown(
- "# <center> GPT-SoVITS-V2-Gakuen Idolmaster\n"
- )
- with gr.Tabs():
- for (name, title, cover, transcripts, example_reference, tts_fn) in models:
- with gr.TabItem(name):
- with gr.Row():
- gr.Markdown(
- '<div align="center">'
- f'<a><strong>{title}</strong></a>'
- '</div>')
- with gr.Row():
- with gr.Column():
- prompt_text = gr.Dropdown(
- label="Transcript of the Reference Audio",
- value=example_reference if example_reference in transcripts else list(transcripts.keys())[0],
- choices=list(transcripts.keys())
- )
- inp_ref_audio = gr.Audio(
- label="Reference Audio",
- type="filepath",
- interactive=False,
- value=transcripts[example_reference] if example_reference in transcripts else list(transcripts.values())[0]
- )
- transcripts_state = gr.State(value=transcripts)
- prompt_text.change(
- fn=change_reference_audio,
- inputs=[prompt_text, transcripts_state],
- outputs=[inp_ref_audio]
- )
- prompt_language = gr.State(value="ja")
- with gr.Column():
- text = gr.Textbox(label="Input Text", value="こんにちは、私はあなたのAIアシスタントです。仲良くしましょうね。")
- text_language = gr.Dropdown(
- label="Language",
- choices=["ja"],
- value="ja"
- )
- clean_button = gr.Button("Clean Text", variant="primary")
- inference_button = gr.Button("Generate", variant="primary")
- cleaned_text = gr.Textbox(label="Cleaned Text")
- output = gr.Audio(label="Output Audio")
- output_file = gr.File(label="Output Audio File")
- om = gr.Textbox(label="Output Message")
- clean_button.click(
- fn=get_str_list_from_phone,
- inputs=[text, text_language],
- outputs=[cleaned_text]
- )
- inference_button.click(
- fn=tts_fn,
- inputs=[inp_ref_audio, prompt_text, prompt_language, cleaned_text, text_language, text],
- outputs=[om, output, output_file]
- )
-
+ import os
+ import sys
+ import spaces
+ cnhubert_base_path = "TencentGameMate/chinese-hubert-base"
+ bert_path = "hfl/chinese-roberta-wwm-ext-large"
+ os.environ["version"] = 'v2'
+ now_dir = os.path.dirname(os.path.abspath(__file__))
+ sys.path.insert(0, now_dir)
+ sys.path.insert(0, os.path.join(now_dir, "GPT_SoVITS"))
+ sys.path.insert(0, os.path.join(now_dir, "GPT_SoVITS",'text'))
+ import site
+ site_packages_roots = []
+ for site_packages_root in site_packages_roots:
+ if os.path.exists(site_packages_root):
+ try:
+ with open("%s/users.pth" % (site_packages_root), "w") as f:
+ f.write(
+ "%s\n%s/tools\n%s/tools/damo_asr\n%s/GPT_SoVITS\n%s/tools/uvr5"
+ % (now_dir, now_dir, now_dir, now_dir, now_dir)
+ )
+ break
+ except PermissionError:
+ pass
+ import re
+ import gradio as gr
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+ import numpy as np
+ import os,librosa,torch, audiosegment
+ from GPT_SoVITS.feature_extractor import cnhubert
+ cnhubert.cnhubert_base_path=cnhubert_base_path
+ from GPT_SoVITS.module.models import SynthesizerTrn
+ from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
+ from GPT_SoVITS.text import cleaned_text_to_sequence
+ from GPT_SoVITS.text.cleaner import clean_text
+ from time import time as ttime
+ from GPT_SoVITS.module.mel_processing import spectrogram_torch
+ import tempfile
+ from tools.my_utils import load_audio
+ import os
+ import json
+ # import pyopenjtalk
+ # cwd = os.getcwd()
+ # if os.path.exists(os.path.join(cwd,'user.dic')):
+ # pyopenjtalk.update_global_jtalk_with_user_dict(os.path.join(cwd, 'user.dic'))
+
+
+ import logging
+ logging.getLogger('httpx').setLevel(logging.WARNING)
+ logging.getLogger('httpcore').setLevel(logging.WARNING)
+ logging.getLogger('multipart').setLevel(logging.WARNING)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ #device = "cpu"
+ is_half = False
+
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
+ bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
+ if(is_half==True):bert_model=bert_model.half().to(device)
+ else:bert_model=bert_model.to(device)
+ # bert_model=bert_model.to(device)
+ def get_bert_feature(text, word2ph): # Bert(不是HuBERT的特征计算)
+ with torch.no_grad():
+ inputs = tokenizer(text, return_tensors="pt")
+ for i in inputs:
+ inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题,精度随bert_model
+ res = bert_model(**inputs, output_hidden_states=True)
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+ assert len(word2ph) == len(text)
+ phone_level_feature = []
+ for i in range(len(word2ph)):
+ repeat_feature = res[i].repeat(word2ph[i], 1)
+ phone_level_feature.append(repeat_feature)
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
+ # if(is_half==True):phone_level_feature=phone_level_feature.half()
+ return phone_level_feature.T
+
+ loaded_sovits_model = [] # [(path, dict, model)]
+ loaded_gpt_model = []
+ ssl_model = cnhubert.get_model()
+ if (is_half == True):
+ ssl_model = ssl_model.half().to(device)
+ else:
+ ssl_model = ssl_model.to(device)
+
+
+ def load_model(sovits_path, gpt_path):
+ global ssl_model
+ global loaded_sovits_model
+ global loaded_gpt_model
+ vq_model = None
+ t2s_model = None
+ dict_s2 = None
+ dict_s1 = None
+ hps = None
+ for path, dict_s2_, model in loaded_sovits_model:
+ if path == sovits_path:
+ vq_model = model
+ dict_s2 = dict_s2_
+ break
+ for path, dict_s1_, model in loaded_gpt_model:
+ if path == gpt_path:
+ t2s_model = model
+ dict_s1 = dict_s1_
+ break
+
+ if dict_s2 is None:
+ dict_s2 = torch.load(sovits_path, map_location="cpu")
+ hps = dict_s2["config"]
+
+ if dict_s1 is None:
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
+ config = dict_s1["config"]
+ class DictToAttrRecursive:
+ def __init__(self, input_dict):
+ for key, value in input_dict.items():
+ if isinstance(value, dict):
+ # 如果值是字典,递归调用构造函数
+ setattr(self, key, DictToAttrRecursive(value))
+ else:
+ setattr(self, key, value)
+
+ hps = DictToAttrRecursive(hps)
+ hps.model.semantic_frame_rate = "25hz"
+
+
+ if not vq_model:
+ vq_model = SynthesizerTrn(
+ hps.data.filter_length // 2 + 1,
+ hps.train.segment_size // hps.data.hop_length,
+ n_speakers=hps.data.n_speakers,
+ **hps.model)
+ if (is_half == True):
+ vq_model = vq_model.half().to(device)
+ else:
+ vq_model = vq_model.to(device)
+ vq_model.eval()
+ vq_model.load_state_dict(dict_s2["weight"], strict=False)
+ loaded_sovits_model.append((sovits_path, dict_s2, vq_model))
+ hz = 50
+ max_sec = config['data']['max_sec']
+ if not t2s_model:
+ t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
+ t2s_model.load_state_dict(dict_s1["weight"])
+ if (is_half == True): t2s_model = t2s_model.half()
+ t2s_model = t2s_model.to(device)
+ t2s_model.eval()
+ total = sum([param.nelement() for param in t2s_model.parameters()])
+ print("Number of parameter: %.2fM" % (total / 1e6))
+ loaded_gpt_model.append((gpt_path, dict_s1, t2s_model))
+ return vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
+
+
+ def get_spepc(hps, filename):
+ audio=load_audio(filename,int(hps.data.sampling_rate))
+ audio = audio / np.max(np.abs(audio))
+ audio=torch.FloatTensor(audio)
+ print(torch.max(torch.abs(audio)))
+ audio_norm = audio
+ # audio_norm = audio / torch.max(torch.abs(audio))
+ audio_norm = audio_norm.unsqueeze(0)
+ print(torch.max(torch.abs(audio_norm)))
+ spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
+ return spec
+
+ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
+ @spaces.GPU()
+ def tts_fn(ref_wav_path, prompt_text, prompt_language, target_phone, text_language, target_text = None):
+ t0 = ttime()
+ prompt_text=prompt_text.strip()
+ print(prompt_text)
+ prompt_language=prompt_language
+ with torch.no_grad():
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
+ # maxx=0.95
+ # tmp_max = np.abs(wav16k).max()
+ # alpha=0.5
+ # wav16k = (wav16k / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * wav16k
+ #在这里归一化
+ #print(max(np.abs(wav16k)))
+ #wav16k = wav16k / np.max(np.abs(wav16k))
+ #print(max(np.abs(wav16k)))
+ # 添加0.3s的静音
+ wav16k = np.concatenate([wav16k, np.zeros(int(hps.data.sampling_rate * 0.3)),])
+ wav16k = torch.from_numpy(wav16k)
+ wav16k = wav16k.float()
+ if(is_half==True):wav16k=wav16k.half().to(device)
+ else:wav16k=wav16k.to(device)
+ print(wav16k.shape) # 读取16k音频
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
+ print(ssl_content.shape)
+ codes = vq_model.extract_latent(ssl_content)
+ print(codes.shape)
+ prompt_semantic = codes[0, 0]
+ t1 = ttime()
+ phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
+ phones1=cleaned_text_to_sequence(phones1)
+ #texts=text.split("\n")
+ audio_opt = []
+ zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
+ phones = get_phone_from_str_list(target_phone, text_language)
+ for phones2 in phones:
+ if(len(phones2) == 0):
+ continue
+ if(len(phones2) == 1 and phones2[0] == ""):
+ continue
+ #phones2, word2ph2, norm_text2 = clean_text(text, text_language)
+ print(phones2)
+ phones2 = cleaned_text_to_sequence(phones2)
+ #if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
+ bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
+ #if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
+ bert2 = torch.zeros((1024, len(phones2))).to(bert1)
+ bert = torch.cat([bert1, bert2], 1)
+
+ all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+ bert = bert.to(device).unsqueeze(0)
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
+ prompt = prompt_semantic.unsqueeze(0).to(device)
+ t2 = ttime()
+ idx = 0
+ cnt = 0
+ while idx == 0 and cnt < 2:
+ with torch.no_grad():
+ # pred_semantic = t2s_model.model.infer
+ pred_semantic,idx = t2s_model.model.infer_panel(
+ all_phoneme_ids,
+ all_phoneme_len,
+ prompt,
+ bert,
+ # prompt_phone_len=ph_offset,
+ top_k=config['inference']['top_k'],
+ early_stop_num=hz * max_sec)
+ t3 = ttime()
+ cnt+=1
+ if idx == 0:
+ return "Error: Generation failure: bad zero prediction.", None
+ pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
+ refer = get_spepc(hps, ref_wav_path)#.to(device)
+ if(is_half==True):refer=refer.half().to(device)
+ else:refer=refer.to(device)
+ # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
+ audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
+ audio_opt.append(audio)
+ audio_opt.append(zero_wav)
+ t4 = ttime()
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+
+ audio = (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16))
+
+ filename = tempfile.mktemp(suffix=".wav",prefix=f"{prompt_text[:8]}_{target_text[:8]}_")
+ audiosegment.from_numpy_array(audio[1], framerate=audio[0]).export(filename, format="WAV")
+ return "Success", (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)), filename
+ return tts_fn
+
+
+ def get_str_list_from_phone(text, text_language):
+ # raw文本过g2p得到音素列表,再转成字符串
+ # 注意,这里的text是一个段落,可能包含多个句子
+ # 段落间\n分割,音素间空格分割
+ texts=text.split("\n")
+ phone_list = []
+ for text in texts:
+ phones2, word2ph2, norm_text2 = clean_text(text, text_language)
+ phone_list.append(" ".join(phones2))
+ return "\n".join(phone_list)
+
+ def get_phone_from_str_list(str_list:str, language:str = 'ja'):
+ # 从音素字符串中得到音素列表
+ # 注意,这里的text是一个段落,可能包含多个句子
+ # 段落间\n分割,音素间空格分割
+ sentences = str_list.split("\n")
+ phones = []
+ for sentence in sentences:
+ phones.append(sentence.split(" "))
+ return phones
+
+ splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
+ def split(todo_text):
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
+ if (todo_text[-1] not in splits): todo_text += "。"
+ i_split_head = i_split_tail = 0
+ len_text = len(todo_text)
+ todo_texts = []
+ while (1):
+ if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
+ if (todo_text[i_split_head] in splits):
+ i_split_head += 1
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
+ i_split_tail = i_split_head
+ else:
+ i_split_head += 1
+ return todo_texts
+
+
+ def change_reference_audio(prompt_text, transcripts):
+ return transcripts[prompt_text]
+
+
+ models = []
+ models_info = json.load(open("./models/models_info.json", "r", encoding="utf-8"))
+
+
+
+ for i, info in models_info.items():
+ title = info['title']
+ cover = info['cover']
+ gpt_weight = info['gpt_weight']
+ sovits_weight = info['sovits_weight']
+ example_reference = info['example_reference']
+ transcripts = {}
+ transcript_path = info["transcript_path"]
+ path = os.path.dirname(transcript_path)
+ with open(transcript_path, 'r', encoding='utf-8') as file:
+ for line in file:
+ line = line.strip().replace("\\", "/")
+ wav,_,_, t = line.split("|")
+ wav = os.path.basename(wav)
+ transcripts[t] = os.path.join(os.path.join(path,"reference_audio"), wav)
+
+ vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
+
+
+ models.append(
+ (
+ i,
+ title,
+ cover,
+ transcripts,
+ example_reference,
+ create_tts_fn(
+ vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
+ )
+ )
+ )
+ with gr.Blocks() as app:
+ gr.Markdown(
+ "# <center> GPT-SoVITS-V2-Gakuen Idolmaster\n"
+ )
+ with gr.Tabs():
+ for (name, title, cover, transcripts, example_reference, tts_fn) in models:
+ with gr.TabItem(name):
+ with gr.Row():
+ gr.Markdown(
+ '<div align="center">'
+ f'<a><strong>{title}</strong></a>'
+ '</div>')
+ with gr.Row():
+ with gr.Column():
+ prompt_text = gr.Dropdown(
+ label="Transcript of the Reference Audio",
+ value=example_reference if example_reference in transcripts else list(transcripts.keys())[0],
+ choices=list(transcripts.keys())
+ )
+ inp_ref_audio = gr.Audio(
+ label="Reference Audio",
+ type="filepath",
+ interactive=False,
+ value=transcripts[example_reference] if example_reference in transcripts else list(transcripts.values())[0]
+ )
+ transcripts_state = gr.State(value=transcripts)
+ prompt_text.change(
+ fn=change_reference_audio,
+ inputs=[prompt_text, transcripts_state],
+ outputs=[inp_ref_audio]
+ )
+ prompt_language = gr.State(value="ja")
+ with gr.Column():
+ text = gr.Textbox(label="Input Text", value="こんにちは、私はあなたのAIアシスタントです。仲良くしましょうね。")
+ text_language = gr.Dropdown(
+ label="Language",
+ choices=["ja"],
+ value="ja"
+ )
+ clean_button = gr.Button("Clean Text", variant="primary")
+ inference_button = gr.Button("Generate", variant="primary")
+ cleaned_text = gr.Textbox(label="Cleaned Text")
+ output = gr.Audio(label="Output Audio")
+ output_file = gr.File(label="Output Audio File")
+ om = gr.Textbox(label="Output Message")
+ clean_button.click(
+ fn=get_str_list_from_phone,
+ inputs=[text, text_language],
+ outputs=[cleaned_text]
+ )
+ inference_button.click(
+ fn=tts_fn,
+ inputs=[inp_ref_audio, prompt_text, prompt_language, cleaned_text, text_language, text],
+ outputs=[om, output, output_file]
+ )
+
  app.launch(share=True)