Hilley's picture
Update app.py
34de65c verified
raw
history blame
5.97 kB
import spaces
import os
import random
import argparse
import torch
import gradio as gr
import numpy as np
import ChatTTS
import se_extractor
from api import ToneColorConverter
import soundfile
# Create the shared ChatTTS engine once at import time; chat_tts() below uses
# this module-level `chat` object for both text refinement and synthesis.
print("loading ChatTTS model...")
chat = ChatTTS.Chat()
chat.load_models()
def generate_seed():
    """Draw a new random seed and package it as a component update.

    Used as the click handler of the dice buttons next to the seed fields.

    Returns:
        dict: ``{"__type__": "update", "value": <seed>}`` where ``<seed>`` is
        an integer drawn uniformly from 1 to 100000000 (both inclusive).
    """
    seed_update = {"__type__": "update"}
    seed_update["value"] = random.randint(1, 100000000)
    return seed_update
@spaces.GPU
def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
    """Synthesize speech for ``text`` with the module-level ChatTTS engine.

    Args:
        text: Text to speak.
        temperature: Sampling temperature forwarded in ``params_infer_code``.
        top_P: ``top_P`` value forwarded in ``params_infer_code``.
        top_K: ``top_K`` value forwarded in ``params_infer_code``.
        audio_seed_input: Seed applied before drawing the random speaker
            embedding.
        text_seed_input: Seed applied before the refinement/synthesis calls.
        refine_text_flag: When truthy, run a refine-only pass on ``text``
            before synthesizing.
        refine_text_input: Optional refine prompt; when empty (or when
            refinement is disabled) the built-in default prompt is used.
        output_path: When given, the audio is written to this file instead of
            being returned.

    Returns:
        ``[(sample_rate, audio_data), text_data]`` when ``output_path`` is
        ``None``; otherwise only ``text_data`` after writing the audio file.
        ``text_data`` is the first element of ``text`` if it is a list, else
        ``text`` itself.
    """
    # The speaker identity is a 768-element random vector fixed by the audio seed.
    torch.manual_seed(audio_seed_input)
    speaker_embedding = torch.randn(768)
    infer_code_params = {
        'spk_emb': speaker_embedding,
        'temperature': temperature,
        'top_P': top_P,
        'top_K': top_K,
    }

    # A custom prompt only applies when refinement is enabled and a prompt was supplied.
    custom_prompt_requested = refine_text_flag and refine_text_input
    refine_params = {
        'prompt': refine_text_input if custom_prompt_requested else '[oral_2][laugh_0][break_6]',
    }

    torch.manual_seed(text_seed_input)

    if refine_text_flag:
        # First pass: rewrite the text only, no audio is produced here.
        text = chat.infer(
            text,
            skip_refine_text=False,
            refine_text_only=True,
            params_refine_text=refine_params,
            params_infer_code=infer_code_params,
        )
        print("Text has been refined!")

    # Second pass: synthesize audio from the (possibly refined) text as-is.
    wav = chat.infer(
        text,
        skip_refine_text=True,
        params_refine_text=refine_params,
        params_infer_code=infer_code_params,
    )

    audio_data = np.array(wav[0]).flatten()
    # NOTE(review): upstream ChatTTS examples save audio at 24000 Hz; confirm
    # that 22050 is the intended rate here.
    sample_rate = 22050
    text_data = text[0] if isinstance(text, list) else text

    if output_path is not None:
        soundfile.write(output_path, audio_data, sample_rate)
        return text_data
    return [(sample_rate, audio_data), text_data]
# OpenVoice Clone
# Build the tone color converter once at import time; generate_audio() uses it
# both to extract speaker embeddings and to run the conversion.
ckpt_converter_en = 'checkpoints/converter'
# Prefer the first CUDA device, but fall back to CPU so the app still starts
# on hosts without a GPU instead of failing while loading the converter.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
tone_color_converter = ToneColorConverter(f'{ckpt_converter_en}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter_en}/checkpoint.pth')
def generate_audio(text, audio_ref, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input):
    """Generate speech for ``text``, optionally cloned onto a reference voice.

    Args:
        text: Text to speak.
        audio_ref: File path of the reference voice recording. ``None`` or an
            empty string means "no reference": plain ChatTTS output is used.
        temperature, top_P, top_K: Sampling parameters forwarded to
            ``chat_tts``.
        audio_seed_input: Speaker seed forwarded to ``chat_tts``.
        text_seed_input: Text seed forwarded to ``chat_tts``.
        refine_text_flag: Whether ``chat_tts`` should refine the text first.
        refine_text_input: Refine prompt forwarded to ``chat_tts``.

    Returns:
        list: ``[save_path, text_data]`` where ``save_path`` is the written
        audio file (``"output.wav"``) and ``text_data`` is the text returned
        by ``chat_tts``.
    """
    # NOTE(review): the output and intermediate file names are fixed, so
    # overlapping requests would write to the same files.
    save_path = "output.wav"
    tts_args = (text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input)

    if not audio_ref:
        # No reference voice: synthesize straight into the output file.
        # The returned text must be kept, otherwise the final return below
        # fails with UnboundLocalError on this path.
        text_data = chat_tts(*tts_args, save_path)
    else:
        # Run the base speaker TTS into an intermediate file first.
        src_path = "tmp.wav"
        text_data = chat_tts(*tts_args, src_path)
        print("Ready for voice cloning!")

        # Extract the speaker embeddings of the generated and reference audio.
        source_se, _ = se_extractor.get_se(src_path, tone_color_converter, target_dir='processed', vad=True)
        target_se, _ = se_extractor.get_se(audio_ref, tone_color_converter, target_dir='processed', vad=True)
        print("Get voices segment!")

        # Run the tone color converter from the intermediate file to the output.
        tone_color_converter.convert(
            audio_src_path=src_path,
            src_se=source_se,
            tgt_se=target_se,
            output_path=save_path)

    print("Finished!")
    return [save_path, text_data]
# Build the Gradio UI and bind it to the module-level name `demo`.
# NOTE(review): leading indentation is missing in this copy of the file, so the
# nesting of the Column/Row containers below must be confirmed against the
# original layout before re-indenting.
with gr.Blocks() as demo:
gr.Markdown("# Enjoy chatting with your ai friends on website, telegram and so on! (https://linkin.love)")
# Text to synthesize, pre-filled with a sample sentence.
default_text = "Today a man knocked on my door and asked for a small donation toward the local swimming pool. I gave him a glass of water."
text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text)
# Refinement controls; the default prompt matches the fallback used in chat_tts().
default_refine_text = "[oral_2][laugh_0][break_6]"
refine_text_checkbox = gr.Checkbox(label="Refine text:'oral' means add filler words, 'laugh' means add laughter, and 'break' means add a pause. (0-10) ", value=True)
refine_text_input = gr.Textbox(label="Refine Prompt", lines=1, placeholder="Please Refine Prompt...", value=default_refine_text)
with gr.Column():
# Optional reference voice, passed to the handler as a file path.
# The label is Chinese for "Please upload a voice file you like".
voice_ref = gr.Audio(label="请上传您喜欢的语音文件", type="filepath", value="")
with gr.Row():
# Sampling parameters forwarded to chat_tts() via generate_audio().
temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature")
top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P")
top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K")
with gr.Row():
# Seed fields, each paired with a dice button that draws a new random seed.
audio_seed_input = gr.Number(value=42, label="Speaker Seed")
generate_audio_seed = gr.Button("\U0001F3B2")
text_seed_input = gr.Number(value=42, label="Text Seed")
generate_text_seed = gr.Button("\U0001F3B2")
generate_button = gr.Button("Generate")
# Outputs: the text returned by the handler and the generated audio.
text_output = gr.Textbox(label="Refined Text", interactive=False)
audio_output = gr.Audio(label="Output Audio")
# Dice buttons: replace the corresponding seed field with a fresh value.
generate_audio_seed.click(generate_seed,
inputs=[],
outputs=audio_seed_input)
generate_text_seed.click(generate_seed,
inputs=[],
outputs=text_seed_input)
# Main action: the inputs list follows generate_audio()'s parameter order, and
# its [path, text] return value maps onto [audio_output, text_output].
generate_button.click(generate_audio,
inputs=[text_input, voice_ref, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
outputs=[audio_output,text_output])
# Command-line options for the demo. They are parsed at import time, but the
# launch call below does not forward them: passing server_name/server_port
# (and inbrowser=True) to demo.launch() is currently disabled, so Gradio's
# own defaults are used.
parser = argparse.ArgumentParser(description='ChatTTS demo Launch')
parser.add_argument(
    '--server_name',
    default='0.0.0.0',
    type=str,
    help='Server name',
)
parser.add_argument(
    '--server_port',
    default=8080,
    type=int,
    help='Server port',
)
args = parser.parse_args()

# Start the web UI only when this file is executed as a script.
if __name__ == '__main__':
    demo.launch()