Spaces:
Running
Running
from typing import Callable | |
import gradio as gr | |
from fish_speech.i18n import i18n | |
from tools.inference_engine.utils import normalize_text | |
from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER | |
def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
    """Build the Gradio Blocks interface for text-to-audio generation.

    The layout has two columns: the left one holds the input text, a
    read-only preview of the (optionally normalized) text, and tabs for
    sampling parameters and reference audio; the right one holds the error
    panel, the generated audio player and the Generate button.

    Args:
        inference_fct: Callback bound to the Generate button. It receives,
            in this order, the values of: refined text, normalize flag,
            reference id, reference audio file path, reference text,
            max new tokens, chunk length, top-p, repetition penalty,
            temperature, seed, and the memory-cache choice ("on"/"off").
            Its outputs are routed to the audio and error components.
        theme: Value written to the ``__theme`` URL query parameter when the
            page is loaded without one.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(theme=gr.themes.Base()) as app:
        gr.Markdown(HEADER_MD)

        # Use light theme by default: if the URL has no `__theme` parameter,
        # add one and reassign `window.location.search` with the new query.
        app.load(
            None,
            None,
            js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
            % theme,
        )

        # Inference
        with gr.Row():
            # Left column: text input, preview and configuration tabs.
            with gr.Column(scale=3):
                text = gr.Textbox(
                    label=i18n("Input Text"),
                    placeholder=TEXTBOX_PLACEHOLDER,
                    lines=10,
                )
                # Read-only preview; this is the text actually submitted.
                refined_text = gr.Textbox(
                    label=i18n("Realtime Transform Text"),
                    placeholder=i18n(
                        "Normalization Result Preview (Currently Only Chinese)"
                    ),
                    lines=5,
                    interactive=False,
                )

                with gr.Row():
                    normalize = gr.Checkbox(
                        label=i18n("Text Normalization"),
                        value=False,
                    )

                with gr.Row():
                    with gr.Column():
                        with gr.Tab(label=i18n("Advanced Config")):
                            with gr.Row():
                                chunk_length = gr.Slider(
                                    label=i18n("Iterative Prompt Length, 0 means off"),
                                    minimum=0,
                                    maximum=300,
                                    value=200,
                                    step=8,
                                )

                                max_new_tokens = gr.Slider(
                                    label=i18n(
                                        "Maximum tokens per batch, 0 means no limit"
                                    ),
                                    minimum=0,
                                    maximum=2048,
                                    value=0,
                                    step=8,
                                )

                            with gr.Row():
                                top_p = gr.Slider(
                                    label="Top-P",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )

                                repetition_penalty = gr.Slider(
                                    label=i18n("Repetition Penalty"),
                                    minimum=1,
                                    maximum=1.5,
                                    value=1.2,
                                    step=0.01,
                                )

                            with gr.Row():
                                temperature = gr.Slider(
                                    label="Temperature",
                                    minimum=0.6,
                                    maximum=0.9,
                                    value=0.7,
                                    step=0.01,
                                )
                                seed = gr.Number(
                                    label="Seed",
                                    info="0 means randomized inference, otherwise deterministic",
                                    value=0,
                                )

                        with gr.Tab(label=i18n("Reference Audio")):
                            with gr.Row():
                                gr.Markdown(
                                    i18n(
                                        "5 to 10 seconds of reference audio, useful for specifying speaker."
                                    )
                                )

                            with gr.Row():
                                reference_id = gr.Textbox(
                                    label=i18n("Reference ID"),
                                    placeholder="Leave empty to use uploaded references",
                                )

                            with gr.Row():
                                use_memory_cache = gr.Radio(
                                    label=i18n("Use Memory Cache"),
                                    choices=["on", "off"],
                                    value="on",
                                )

                            with gr.Row():
                                reference_audio = gr.Audio(
                                    label=i18n("Reference Audio"),
                                    type="filepath",
                                )

                            with gr.Row():
                                reference_text = gr.Textbox(
                                    label=i18n("Reference Text"),
                                    lines=1,
                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                                    value="",
                                )

            # Right column: error panel, generated audio and the trigger.
            with gr.Column(scale=3):
                with gr.Row():
                    error = gr.HTML(
                        label=i18n("Error Message"),
                        visible=True,
                    )

                with gr.Row():
                    audio = gr.Audio(
                        label=i18n("Generated Audio"),
                        type="numpy",
                        interactive=False,
                        visible=True,
                    )

                with gr.Row():
                    with gr.Column(scale=3):
                        generate = gr.Button(
                            value="\U0001F3A7 " + i18n("Generate"),
                            variant="primary",
                        )

        # Keep the preview in sync with BOTH of its inputs. Refreshing only
        # on text edits left `refined_text` stale when the checkbox was
        # toggled afterwards, and that stale value is what gets submitted.
        preview_binding = dict(
            fn=normalize_text,
            inputs=[text, normalize],
            outputs=[refined_text],
        )
        text.input(**preview_binding)
        normalize.change(**preview_binding)

        # Submit
        generate.click(
            inference_fct,
            [
                refined_text,
                normalize,
                reference_id,
                reference_audio,
                reference_text,
                max_new_tokens,
                chunk_length,
                top_p,
                repetition_penalty,
                temperature,
                seed,
                use_memory_cache,
            ],
            [audio, error],
            # Only one generation request runs at a time for this event.
            concurrency_limit=1,
        )

    return app