# ChatTTS2 / examples/web/webui.py
# Hugging Face file-view residue kept for provenance: uploader "zhengr",
# commit "init" (c02bdcd), links "raw" / "history" / "blame", size 9.2 kB.
import os, sys
# On macOS, switch on PyTorch's MPS fallback flag before the imports that follow.
if sys.platform == "darwin":
    os.environ.update(PYTORCH_ENABLE_MPS_FALLBACK="1")

# Put the launch directory on the module search path; the local `funcs` / `ex`
# imports below are presumably resolved from there — TODO confirm.
now_dir = os.getcwd()
sys.path.append(now_dir)
import argparse
import gradio as gr
from funcs import *
from ex import ex
def _parse_args():
    """Parse the command-line options of the ChatTTS demo.

    Returns:
        argparse.Namespace: carries ``server_name``, ``server_port``,
        ``root_path``, ``custom_path`` and ``coef``.
    """
    parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
    parser.add_argument(
        "--server_name", type=str, default="0.0.0.0", help="server name"
    )
    parser.add_argument("--server_port", type=int, default=8080, help="server port")
    parser.add_argument("--root_path", type=str, default=None, help="root path")
    parser.add_argument(
        "--custom_path", type=str, default=None, help="custom model path"
    )
    parser.add_argument(
        "--coef", type=str, default=None, help="custom dvae coefficient"
    )
    return parser.parse_args()


def main():
    """Build the ChatTTS Gradio UI, load the model and start the web server.

    Steps, in order:
      1. Declare every widget and event binding inside a ``gr.Blocks`` context.
      2. Parse the command-line options.
      3. Load the model with ``load_chat``; exit with status 1 if it reports
         failure.
      4. Seed the speaker-embedding and DVAE-coefficient text boxes from the
         loaded model, then launch the app.

    The first row of ``ex`` supplies the initial widget values, using the same
    column order as the ``gr.Examples`` inputs: text, temperature, top_P,
    top_K, audio seed, text seed, refine-text flag.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# ChatTTS WebUI")
        gr.Markdown("- **GitHub Repo**: https://github.com/2noise/ChatTTS")
        gr.Markdown("- **HuggingFace Repo**: https://huggingface.co/2Noise/ChatTTS")

        # --- text inputs (left) and optional sample audio / code (right) ---
        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="Input Text",
                    lines=4,
                    max_lines=4,
                    placeholder="Please Input Text...",
                    value=ex[0][0],
                    interactive=True,
                )
                sample_text_input = gr.Textbox(
                    label="Sample Text",
                    lines=4,
                    max_lines=4,
                    placeholder="If Sample Audio and Sample Text are available, the Speaker Embedding will be disabled.",
                    interactive=True,
                )
            with gr.Column():
                with gr.Tab(label="Sample Audio"):
                    sample_audio_input = gr.Audio(
                        value=None,
                        type="filepath",
                        interactive=True,
                        show_label=False,
                        waveform_options=gr.WaveformOptions(
                            sample_rate=24000,
                        ),
                        scale=1,
                    )
                with gr.Tab(label="Sample Audio Code"):
                    sample_audio_code_input = gr.Textbox(
                        lines=12,
                        max_lines=12,
                        show_label=False,
                        placeholder="Paste the Code copied before after uploading Sample Audio.",
                        interactive=True,
                    )

        # --- sampling parameters ---
        with gr.Row():
            refine_text_checkbox = gr.Checkbox(
                label="Refine text", value=ex[0][6], interactive=True
            )
            temperature_slider = gr.Slider(
                minimum=0.00001,
                maximum=1.0,
                step=0.00001,
                value=ex[0][1],
                label="Audio Temperature",
                interactive=True,
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                step=0.05,
                value=ex[0][2],
                label="top_P",
                interactive=True,
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=ex[0][3],
                label="top_K",
                interactive=True,
            )

        # --- timbre and seeds ---
        with gr.Row():
            voice_selection = gr.Dropdown(
                label="Timbre",
                # Materialise the key view: Dropdown documents `choices` as a list.
                choices=list(voices.keys()),
                value="Default",
                interactive=True,
            )
            audio_seed_input = gr.Number(
                value=ex[0][4],
                label="Audio Seed",
                interactive=True,
                minimum=seed_min,
                maximum=seed_max,
            )
            generate_audio_seed = gr.Button("\U0001F3B2", interactive=True)
            text_seed_input = gr.Number(
                value=ex[0][5],
                label="Text Seed",
                interactive=True,
                minimum=seed_min,
                maximum=seed_max,
            )
            generate_text_seed = gr.Button("\U0001F3B2", interactive=True)

        # --- speaker embedding / DVAE coefficient (filled in after model load) ---
        with gr.Row():
            spk_emb_text = gr.Textbox(
                label="Speaker Embedding",
                max_lines=3,
                show_copy_button=True,
                interactive=True,
                scale=2,
            )
            dvae_coef_text = gr.Textbox(
                label="DVAE Coefficient",
                max_lines=3,
                show_copy_button=True,
                interactive=True,
                scale=2,
            )
            reload_chat_button = gr.Button("Reload", scale=1, interactive=True)

        # --- playback options and action buttons ---
        with gr.Row():
            auto_play_checkbox = gr.Checkbox(
                label="Auto Play", value=False, scale=1, interactive=True
            )
            stream_mode_checkbox = gr.Checkbox(
                label="Stream Mode",
                value=False,
                scale=1,
                interactive=True,
            )
            generate_button = gr.Button(
                "Generate", scale=2, variant="primary", interactive=True
            )
            # Starts hidden and disabled; the set_buttons_* callbacks receive
            # both buttons as outputs.
            interrupt_button = gr.Button(
                "Interrupt",
                scale=2,
                variant="stop",
                visible=False,
                interactive=False,
            )

        text_output = gr.Textbox(
            label="Output Text",
            interactive=False,
            show_copy_button=True,
        )

        # A changed sample audio is passed to on_upload_sample_audio, whose
        # result fills the "Sample Audio Code" tab; then the user is notified.
        sample_audio_input.change(
            fn=on_upload_sample_audio,
            inputs=sample_audio_input,
            outputs=sample_audio_code_input,
        ).then(fn=lambda: gr.Info("Sampled Audio Code generated at another Tab."))

        # Use Gradio callbacks to update the numeric input boxes.
        voice_selection.change(
            fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input
        )

        generate_audio_seed.click(generate_seed, outputs=audio_seed_input)

        generate_text_seed.click(generate_seed, outputs=text_seed_input)

        audio_seed_input.change(
            on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text
        )

        reload_chat_button.click(
            reload_chat, inputs=dvae_coef_text, outputs=dvae_coef_text
        )

        interrupt_button.click(interrupt_generate)

        @gr.render(inputs=[auto_play_checkbox, stream_mode_checkbox])
        def make_audio(autoplay, stream):
            """Create the output player and the generate pipeline bound to it.

            Wrapped in ``gr.render`` with the two checkboxes as inputs, so the
            player's ``autoplay`` / ``streaming`` / ``format`` follow their
            values. The click chain is registered here because it targets the
            ``audio_output`` component created in this function.
            """
            audio_output = gr.Audio(
                label="Output Audio",
                value=None,
                format="mp3" if use_mp3 and not stream else "wav",
                autoplay=autoplay,
                streaming=stream,
                interactive=False,
                show_label=True,
                waveform_options=gr.WaveformOptions(
                    sample_rate=24000,
                ),
            )
            # Pipeline: set_buttons_before_generate -> refine_text ->
            # generate_audio -> set_buttons_after_generate.
            generate_button.click(
                fn=set_buttons_before_generate,
                inputs=[generate_button, interrupt_button],
                outputs=[generate_button, interrupt_button],
            ).then(
                refine_text,
                inputs=[
                    text_input,
                    text_seed_input,
                    refine_text_checkbox,
                    temperature_slider,
                    top_p_slider,
                    top_k_slider,
                ],
                outputs=text_output,
            ).then(
                generate_audio,
                inputs=[
                    text_output,
                    temperature_slider,
                    top_p_slider,
                    top_k_slider,
                    spk_emb_text,
                    stream_mode_checkbox,
                    audio_seed_input,
                    sample_text_input,
                    sample_audio_code_input,
                ],
                outputs=audio_output,
            ).then(
                fn=set_buttons_after_generate,
                inputs=[generate_button, interrupt_button, audio_output],
                outputs=[generate_button, interrupt_button],
            )

        gr.Examples(
            examples=ex,
            inputs=[
                text_input,
                temperature_slider,
                top_p_slider,
                top_k_slider,
                audio_seed_input,
                text_seed_input,
                refine_text_checkbox,
            ],
        )

    args = _parse_args()

    logger.info("loading ChatTTS model...")

    if load_chat(args.custom_path, args.coef):
        logger.info("Models loaded successfully.")
    else:
        logger.error("Models load failed.")
        sys.exit(1)

    # Both initial values are set here, after load_chat has succeeded
    # (dvae_coef_text reads chat.coef).
    spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
    dvae_coef_text.value = chat.coef

    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        root_path=args.root_path,
        inbrowser=True,
        show_api=False,
    )


if __name__ == "__main__":
    main()