# vits-chatbot / audiobook.py
import datetime
import os
import re
import time

import numpy as np
import torch
import soundfile as sf
from scipy import signal

import gradio as gr
import romajitable

import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence
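# Audiobook generator: load a VITS checkpoint, split an uploaded novel into
# sentences, synthesize each sentence, and write chunked WAV files together
# with matching SRT subtitle files.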
class VitsGradio:
def __init__(self):
        self.lan = ["中文", "日文", "自动"]  # Chinese / Japanese / auto-detect
        # Every subdirectory under checkpoints/ is offered as a selectable model.
        self.modelPaths = []
        for root, dirs, files in os.walk("checkpoints"):
            for model_dir in dirs:
                self.modelPaths.append(model_dir)
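        # Assumed checkpoint layout (inferred from the paths used in create_tts_fn):
        #   checkpoints/<model_name>/config.json
        #   checkpoints/<model_name>/model.pth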
        with gr.Blocks() as self.Vits:
            with gr.Tab("小说合成"):  # "Novel synthesis"
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                self.Text = gr.File(label="Text")
                                self.audio_path = gr.TextArea(label="音频路径", lines=1, value='audiobook/chapter.wav')  # output audio path
                                btnbook = gr.Button("小说合成")
                                btnbook.click(self.tts_fn, inputs=[self.Text, self.audio_path])
            with gr.Tab("TTS设定"):  # "TTS settings"
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                self.input1 = gr.Dropdown(label="模型", choices=self.modelPaths, value=self.modelPaths[0] if self.modelPaths else None, type="value")  # model
                                self.input2 = gr.Dropdown(label="Language", choices=self.lan, value="自动", interactive=True)
                                self.input3 = gr.Dropdown(label="Speaker", choices=list(range(1001)), value=0, interactive=True)
                                self.input4 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声比例(noise scale),以控制情感", value=0.6)  # noise scale (controls emotion)
                                self.input5 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声偏差(noise scale w),以控制音素长短", value=0.667)  # noise scale w (controls phoneme length)
                                self.input6 = gr.Slider(minimum=0.1, maximum=10, label="duration", value=1)  # length scale
                                statusa = gr.TextArea()
                                btnVC = gr.Button("完成vits TTS端设定")  # "Apply TTS settings"
                                btnVC.click(self.create_tts_fn, inputs=[self.input1, self.input2, self.input3, self.input4, self.input5, self.input6], outputs=[statusa])
    def is_japanese(self, string):
        # Checks the kana blocks: hiragana (U+3040-U+309F) and katakana (U+30A0-U+30FF).
        for ch in string:
            if 0x3040 <= ord(ch) <= 0x30FF:
                return True
        return False
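    # Example: is_japanese("こんにちは") -> True; is_japanese("你好") -> False
    # (only the kana blocks are checked, not kanji).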
    def is_english(self, string):
        # True only if the string is entirely ASCII letters, digits and basic punctuation.
        pattern = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$')
        return bool(pattern.fullmatch(string))
    def get_text(self, text, hps, cleaned=False):
        # Convert text to a LongTensor of symbol ids, optionally interspersing blanks.
        if cleaned:
            text_norm = text_to_sequence(text, hps.symbols, [])
        else:
            text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        return text_norm
    def get_label(self, text, label):
if f'[{label}]' in text:
return True, text.replace(f'[{label}]', '')
else:
return False, text
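    # Example: get_label("[ZH]你好", "ZH") -> (True, "你好")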
    def sle(self, language, text):
        # Wrap the text in the language tags expected by the text cleaners.
        text = text.replace('\n', '。').replace(' ', ',')
        if language == "中文":
            return f"[ZH]{text}[ZH]"
        elif language == "日文":
            return f"[JA]{text}[JA]"
        elif language == "自动":
            return f"[JA]{text}[JA]" if self.is_japanese(text) else f"[ZH]{text}[ZH]"
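    # Example: sle("自动", "こんにちは") -> "[JA]こんにちは[JA]" (auto-detected as Japanese)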
    def create_tts_fn(self, path, input2, input3, n_scale=0.667, n_scale_w=0.8, l_scale=1):
self.language = input2
self.speaker_id = int(input3)
self.n_scale = n_scale
self.n_scale_w = n_scale_w
self.l_scale = l_scale
self.dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.hps_ms = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
self.n_speakers = self.hps_ms.data.n_speakers if 'n_speakers' in self.hps_ms.data.keys() else 0
self.n_symbols = len(self.hps_ms.symbols) if 'symbols' in self.hps_ms.keys() else 0
self.net_g_ms = SynthesizerTrn(
self.n_symbols,
self.hps_ms.data.filter_length // 2 + 1,
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
n_speakers=self.n_speakers,
**self.hps_ms.model).to(self.dev)
_ = self.net_g_ms.eval()
_ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", self.net_g_ms)
return 'success'
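    # Minimal sketch of the loading sequence above (assuming a checkpoint folder
    # named 'mymodel'; the folder name is hypothetical):
    #   hps = utils.get_hparams_from_file("checkpoints/mymodel/config.json")
    #   net = SynthesizerTrn(len(hps.symbols), hps.data.filter_length // 2 + 1,
    #                        hps.train.segment_size // hps.data.hop_length,
    #                        n_speakers=hps.data.n_speakers, **hps.model)
    #   utils.load_checkpoint("checkpoints/mymodel/model.pth", net)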
    def transfer(self, text):
        # Strip bracketed annotations (rewritten to <...> by tts_fn) and split into sentences.
        text = re.sub("<[^>]*>", "", text)
        result_list = re.split(r'\n', text)
        final_list = []
        for j in result_list:
            result_list2 = re.split(r'。|!|!|——|:|;|……', j)
            for i in result_list2:
                if self.is_english(i):
                    i = romajitable.to_kana(i).katakana
                # Normalize punctuation in one pass (str.replace already replaces all
                # occurrences); '还'→'孩' is a pronunciation workaround kept from the original.
                i = (i.replace('\n', '').replace(' ', '')
                     .replace('……', '。').replace('…', '。')
                     .replace('还', '孩').replace('!', '。')
                     .replace('“', '').replace('”', '')
                     .replace('」', '').replace('「', ''))
                # Maximum length of a single sentence: 50 characters.
                if len(i) > 1:
                    if len(i) > 50:
                        for piece in re.split(r'。|!|——|,|:', i):
                            if len(piece) > 1:
                                final_list.append(piece + '。')
                    else:
                        final_list.append(i)
        final_list = [x for x in final_list if x != '']
        return final_list
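    # Example: transfer("<注释>第一章。今天天气很好!") -> ["第一章", "今天天气很好"]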
    def tts_fn(self, text, audio_path):
        with open(text.name, "r", encoding="utf-8") as f:
            text = f.read()
        # Map every bracket style to <...> so transfer() can strip bracketed annotations.
        a = ['【', '[', '(', '(', '〔']
        b = ['】', ']', ')', ')', '〕']
        for i in a:
            text = text.replace(i, '<')
        for i in b:
            text = text.replace(i, '>')
        final_list = self.transfer(text)
        # Process at most 1000 sentences per output chunk.
        split_list = []
        while len(final_list) > 0:
            split_list.append(final_list[:1000])
            final_list = final_list[1000:]

        def srt_time(td):
            # Format a timedelta as an SRT timestamp (H:MM:SS,mmm), zero-padding milliseconds.
            return str(td).split(".")[0] + "," + f"{td.microseconds // 1000:03d}"

        c0 = 0
        for lists in split_list:
            audio_fin = []
            t = datetime.timedelta(seconds=0)
            c = 0
            with open(audio_path.replace('.wav', str(c0) + ".srt"), 'w', encoding='utf-8') as f1:
                for sentence in lists:
                    try:
                        c += 1
                        with torch.no_grad():
                            stn_tst = self.get_text(self.sle(self.language, sentence), self.hps_ms, cleaned=False)
                            x_tst = stn_tst.unsqueeze(0).to(self.dev)
                            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(self.dev)
                            sid = torch.LongTensor([self.speaker_id]).to(self.dev)
                            t1 = time.time()
                            audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=self.n_scale, noise_scale_w=self.n_scale_w, length_scale=self.l_scale)[0][0, 0].data.cpu().float().numpy()
                            t2 = time.time()
                            print("第" + str(c) + "句的推理时间为:" + str(t2 - t1) + "s")  # inference time for sentence c
                            time_start = srt_time(t)
                            # The model is assumed to output 22050 Hz audio; see the resampling below.
                            t += datetime.timedelta(seconds=len(audio) / 22050)
                            time_end = srt_time(t)
                            # SRT cue indices start at 1.
                            f1.write(str(c) + '\n' + time_start + ' --> ' + time_end + '\n' + sentence.replace('。', '') + '\n\n')
                            # Upsample 22050 Hz -> 44100 Hz for the output file.
                            resampled_audio_data = signal.resample(audio, len(audio) * 2)
                            audio_fin.append(resampled_audio_data)
                    except Exception as e:
                        print(e)
            if audio_fin:
                sf.write(audio_path.replace('.wav', str(c0) + '.wav'), np.concatenate(audio_fin), 44100, 'PCM_24')
            c0 += 1
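    # Output naming: with audio_path 'audiobook/chapter.wav', chunk 0 is written
    # to 'audiobook/chapter0.wav' with subtitles in 'audiobook/chapter0.srt'.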
if __name__ == '__main__':
print("开始部署")
grVits = VitsGradio()
grVits.Vits.launch()
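    # Pass share=True to launch() to expose a temporary public Gradio link.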