locseed / app.py
M4xjunior's picture
Update app.py
1384004 verified
raw
history blame
18.6 kB
import nltk
nltk.download('punkt_tab')
from sentence_analyzer import SentenceAnalyzer
import re
import tempfile
from collections import OrderedDict
from importlib.resources import files
import click
import gradio as gr
import numpy as np
import soundfile as sf
import torchaudio
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
    import spaces
except ImportError:
    USING_SPACES = False
else:
    USING_SPACES = True


def gpu_decorator(func):
    """Return ``func`` wrapped with ``spaces.GPU`` when running on Hugging Face Spaces.

    Outside Spaces (the ``spaces`` package is not importable) the function is
    returned unchanged, so the decorator is a no-op locally.
    """
    return spaces.GPU(func) if USING_SPACES else func
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
)
# Load the vocoder once at import time; `infer` passes this same instance to every synthesis call.
vocoder = load_vocoder()
import os
from huggingface_hub import hf_hub_download
def _split_repo_path(repo_path):
    """Split ``"namespace/repo[/folder/...]"`` into ``(repo_id, subfolder)``.

    Hugging Face Hub repo ids have the form ``namespace/name``; any further
    path segments name a folder inside the repo. ``subfolder`` is ``None``
    when the path has no extra segments.
    """
    parts = repo_path.strip("/").split("/")
    repo_id = "/".join(parts[:2])
    subfolder = "/".join(parts[2:]) or None
    return repo_id, subfolder


def load_f5tts():
    """Download the F5-TTS checkpoint from the Hugging Face Hub and load it.

    Configuration is read from environment variables:

    * ``MODEL_REPO_ID`` -- repo id, optionally followed by a folder inside the
      repo (default ``"SWivid/F5-TTS/F5TTS_Base"``: repo ``SWivid/F5-TTS``,
      folder ``F5TTS_Base``).
    * ``MODEL_FILENAME`` -- checkpoint file inside that folder
      (default ``"model_1200000.safetensors"``).
    * ``HUGGINGFACE_TOKEN`` -- access token; required because the repo may be private.

    Returns:
        The result of ``load_model`` for the DiT architecture with the
        downloaded checkpoint.

    Raises:
        ValueError: If ``HUGGINGFACE_TOKEN`` is unset or empty.
    """
    repo_path = os.getenv("MODEL_REPO_ID", "SWivid/F5-TTS/F5TTS_Base")
    filename = os.getenv("MODEL_FILENAME", "model_1200000.safetensors")
    token = os.getenv("HUGGINGFACE_TOKEN")

    # Fail fast with a clear message instead of a 401 from the Hub.
    if not token:
        raise ValueError("A variável de ambiente 'HUGGINGFACE_TOKEN' não foi definida.")

    # A repo id with more than one "/" is rejected by the Hub client, so the extra
    # segments (as in the default) are passed separately as `subfolder`.
    repo_id, subfolder = _split_repo_path(repo_path)
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, token=token)

    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    return load_model(DiT, F5TTS_model_cfg, ckpt_path)
# Load the F5-TTS model once at import time; `infer` reuses this instance for every request.
F5TTS_ema_model = load_f5tts()
def _strip_silence(wave, sample_rate):
    """Return ``wave`` with silences removed, via a temp WAV that is always deleted."""
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # only the path is needed; closing first also keeps this portable to Windows
    try:
        sf.write(wav_path, wave, sample_rate)
        remove_silence_for_generated_wav(wav_path)  # rewrites the file in place
        trimmed, _ = torchaudio.load(wav_path)
    finally:
        os.remove(wav_path)
    return trimmed.squeeze().cpu().numpy()


@gpu_decorator
def infer(
    ref_audio_orig, ref_text, gen_text, remove_silence, cross_fade_duration=0.15, speed=1, show_info=gr.Info
):
    """Synthesize ``gen_text`` in the voice of a reference recording.

    Args:
        ref_audio_orig: Path to the reference audio file.
        ref_text: Transcript of the reference audio. It is handed to
            ``preprocess_ref_audio_text``; per the UI help text, leaving it
            empty lets it be transcribed automatically.
        gen_text: Text to synthesize.
        remove_silence: If truthy, silences are stripped from the output.
        cross_fade_duration: Cross-fade length in seconds between generated clips.
        speed: Speed factor forwarded to ``infer_process``.
        show_info: Callable used to surface status messages (``gr.Info`` by default).

    Returns:
        ``((sample_rate, waveform), spectrogram_png_path, ref_text)``, where
        ``ref_text`` is the reference text actually used, so callers can reuse it.
    """
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text.lower().strip(),
        gen_text.lower().strip(),
        F5TTS_ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    if remove_silence:
        final_wave = _strip_silence(final_wave, final_sample_rate)

    # The PNG path is returned to the caller for display, so it is intentionally kept.
    # It is written after the temp handle is closed so the writer can open it on any OS.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path, ref_text
# Extra CSS passed to gr.Blocks below.
# NOTE(review): no component in this file sets elem_id="sentences-container" or
# elem_classes="sentence-box", so these rules appear unused -- confirm before relying on them.
custom_css = """
#sentences-container {
border: 1px solid #ddd;
border-radius: 4px;
padding: 10px;
margin-bottom: 10px;
}
.sentence-box {
border: 1px solid #eee;
padding: 5px;
margin-bottom: 5px;
border-radius: 4px;
background-color: #f9f9f9;
}
"""
def parse_speechtypes_text(gen_text):
    """Split a multi-speech script into ordered ``{"style", "text"}`` segments.

    ``{StyleName}`` markers switch the style for the text that follows them.
    Text before the first marker uses ``"Regular"``. Empty pieces are dropped.

    >>> parse_speechtypes_text("Hi {Sad} oh no")
    [{'style': 'Regular', 'text': 'Hi'}, {'style': 'Sad', 'text': 'oh no'}]
    """
    # With a capturing group, re.split alternates text (even indices) and style names (odd indices).
    tokens = re.split(r"\{(.*?)\}", gen_text or "")
    segments = []
    current_style = "Regular"
    for idx, token in enumerate(tokens):
        if idx % 2 == 0:
            text = token.strip()
            if text:
                segments.append({"style": current_style, "text": text})
        else:
            current_style = token.strip()
    return segments


with gr.Blocks(css=custom_css) as app:
    with gr.Tabs():
        with gr.Tab("TTS Básico"):
            gr.Markdown("# TTS Básico com F5-TTS")
            ref_audio_input = gr.Audio(label="Áudio de Referência", type="filepath")
            gen_text_input = gr.Textbox(label="Texto para Gerar", lines=10)
            generate_btn = gr.Button("Sintetizar", variant="primary")
            with gr.Accordion("Configurações Avançadas", open=False):
                ref_text_input = gr.Textbox(
                    label="Texto de Referência",
                    info="Deixe em branco para transcrever automaticamente o áudio de referência. Se você inserir texto, ele substituirá a transcrição automática.",
                    lines=2,
                )
                remove_silence = gr.Checkbox(
                    label="Remover Silêncios",
                    info="O modelo tende a produzir silêncios, especialmente em áudios mais longos. Podemos remover manualmente os silêncios, se necessário. Isso também aumentará o tempo de geração.",
                    value=False,
                )
                speed_slider = gr.Slider(
                    label="Velocidade",
                    minimum=0.3,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    info="Ajuste a velocidade do áudio.",
                )
                cross_fade_duration_slider = gr.Slider(
                    label="Duração do Cross-fade (s)",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.15,
                    step=0.01,
                    info="Defina a duração do cross-fade entre os clipes de áudio.",
                )
                sentence_count_slider = gr.Slider(
                    label="Número de Sentenças por Vez",
                    minimum=1,
                    maximum=10,
                    value=1,
                    step=1,
                    info="Selecione quantas sentenças serão geradas por vez.",
                )
            audio_output = gr.Audio(label="Áudio Sintetizado")
            spectrogram_output = gr.Image(label="Espectrograma")

            analyzer = SentenceAnalyzer()

            @gpu_decorator
            def basic_tts(
                ref_audio_input,
                ref_text_input,
                gen_text_input,
                remove_silence,
                cross_fade_duration_slider,
                speed_slider,
                sentence_count_slider,
            ):
                """Synthesize the first N sentences of the text and concatenate them.

                N is the "sentences per run" slider value, capped at the number of
                sentences found in ``gen_text_input``.

                Returns:
                    ``((sample_rate, audio), spectrogram_path, ref_text_update)``. The
                    spectrogram is the one for the last sentence. The update fills the
                    reference-text box with the text actually used. When nothing is
                    generated, ``(None, None, gr.update())`` is returned after a warning.
                """
                sentences = analyzer.split_into_sentences(gen_text_input)
                # Slider values may arrive as floats; slicing/range need an int.
                num_sentences = min(len(sentences), int(sentence_count_slider))

                ref_text = ref_text_input
                audio_segments = []
                sample_rate = spectrogram_path = None
                for sentence in sentences[:num_sentences]:
                    # Feed back the reference text returned by `infer` so an empty box is
                    # transcribed once, not once per sentence.
                    (sample_rate, audio_data), spectrogram_path, ref_text = infer(
                        ref_audio_input,
                        ref_text,
                        sentence,
                        remove_silence,
                        cross_fade_duration_slider,
                        speed_slider,
                    )
                    audio_segments.append(audio_data)

                if not audio_segments:
                    gr.Warning("Nenhum áudio gerado.")
                    return None, None, gr.update()

                final_audio_data = np.concatenate(audio_segments)
                return (sample_rate, final_audio_data), spectrogram_path, gr.update(value=ref_text)

            generate_btn.click(
                basic_tts,
                inputs=[
                    ref_audio_input,
                    ref_text_input,
                    gen_text_input,
                    remove_silence,
                    cross_fade_duration_slider,
                    speed_slider,
                    sentence_count_slider,
                ],
                # One output per returned value; the third writes back the reference text used.
                outputs=[audio_output, spectrogram_output, ref_text_input],
            )

        with gr.Tab("Multi-Speech"):
            gr.Markdown("# Geração Multi-Speech com F5-TTS")
            with gr.Row():
                with gr.Column():
                    regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
                    regular_insert = gr.Button("Insert Label", variant="secondary")
                regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
                regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)

            # Slot 0 is the always-visible "Regular" type; slots 1..max_speech_types-1 are
            # optional extra types whose rows start hidden.
            max_speech_types = 100
            speech_type_rows = []  # one per extra slot (99)
            speech_type_names = [regular_name]  # one per slot (100)
            speech_type_audios = [regular_audio]  # one per slot (100)
            speech_type_ref_texts = [regular_ref_text]  # one per slot (100)
            speech_type_delete_btns = []  # one per extra slot (99)
            speech_type_insert_btns = [regular_insert]  # one per slot (100)

            for _ in range(max_speech_types - 1):
                with gr.Row(visible=False) as row:
                    with gr.Column():
                        type_name = gr.Textbox(label="Speech Type Name")
                        type_delete_btn = gr.Button("Delete Type", variant="secondary")
                        type_insert_btn = gr.Button("Insert Label", variant="secondary")
                    type_audio = gr.Audio(label="Reference Audio", type="filepath")
                    type_ref_text = gr.Textbox(label="Reference Text", lines=2)
                speech_type_rows.append(row)
                speech_type_names.append(type_name)
                speech_type_audios.append(type_audio)
                speech_type_ref_texts.append(type_ref_text)
                speech_type_delete_btns.append(type_delete_btn)
                speech_type_insert_btns.append(type_insert_btn)

            add_speech_type_btn = gr.Button("Add Speech Type")
            # Number of slots allocated so far (slot 0 always counts).
            speech_type_count = gr.State(value=1)

            def add_speech_type_fn(speech_type_count):
                """Reveal the next unallocated extra row and bump the slot counter.

                Only the newly allocated row is shown, so rows hidden by "Delete Type"
                stay hidden. A warning is shown once every slot is in use.
                """
                row_updates = [gr.update() for _ in speech_type_rows]
                if speech_type_count < max_speech_types:
                    # Extra slot k (1-based) lives at speech_type_rows[k - 1].
                    row_updates[speech_type_count - 1] = gr.update(visible=True)
                    speech_type_count += 1
                else:
                    gr.Warning(f"Maximum of {max_speech_types} speech types reached.")
                return [speech_type_count] + row_updates

            add_speech_type_btn.click(
                add_speech_type_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows
            )

            def delete_speech_type_fn():
                """Hide a speech-type row and clear its inputs so it no longer takes part in generation."""
                return gr.update(visible=False), gr.update(value=""), gr.update(value=None), gr.update(value="")

            # Delete button k (0-based) sits in speech_type_rows[k], which is slot k + 1.
            for row_idx, delete_btn in enumerate(speech_type_delete_btns):
                slot = row_idx + 1
                delete_btn.click(
                    delete_speech_type_fn,
                    outputs=[
                        speech_type_rows[row_idx],
                        speech_type_names[slot],
                        speech_type_audios[slot],
                        speech_type_ref_texts[slot],
                    ],
                )

            gen_text_input_multistyle = gr.Textbox(
                label="Text to Generate",
                lines=10,
                placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
            )

            def insert_speech_type_fn(current_text, speech_type_name):
                """Append a ``{name} `` label to the script (``{None}`` if the name is empty)."""
                current_text = current_text or ""
                speech_type_name = speech_type_name or "None"
                return gr.update(value=current_text + f"{{{speech_type_name}}} ")

            for insert_btn, name_box in zip(speech_type_insert_btns, speech_type_names):
                insert_btn.click(
                    insert_speech_type_fn,
                    inputs=[gen_text_input_multistyle, name_box],
                    outputs=gen_text_input_multistyle,
                )

            with gr.Accordion("Advanced Settings", open=False):
                remove_silence_multistyle = gr.Checkbox(
                    label="Remove Silences",
                    value=True,
                )

            generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
            audio_output_multistyle = gr.Audio(label="Synthesized Audio")

            @gpu_decorator
            def generate_multistyle_speech(
                gen_text,
                *args,
            ):
                """Generate speech for a ``{Style}``-tagged script, one segment at a time.

                ``args`` holds, in order, the ``max_speech_types`` slot names, slot
                reference audios, slot reference texts, and then the remove-silence flag.

                Returns:
                    ``[audio] + ref_text_updates``. The first item is the concatenated audio
                    (or ``None``). It is followed by exactly one update per slot's
                    reference-text box, so the count always matches the wired outputs.
                """
                names = args[:max_speech_types]
                audios = args[max_speech_types : 2 * max_speech_types]
                ref_texts = list(args[2 * max_speech_types : 3 * max_speech_types])
                remove_silence = args[3 * max_speech_types]

                def ref_text_updates():
                    return [gr.update(value=text) for text in ref_texts]

                # Usable style name -> slot index. A slot is usable when it has both a
                # name and reference audio. If a name repeats, the later slot wins.
                style_slots = OrderedDict()
                for slot, (name, audio) in enumerate(zip(names, audios)):
                    if name and audio:
                        style_slots[name] = slot

                if not style_slots:
                    gr.Warning("No speech type has both a name and reference audio.")
                    return [None] + ref_text_updates()

                # Unknown styles fall back to "Regular" if usable, else to the first usable slot.
                fallback_slot = style_slots.get("Regular", next(iter(style_slots.values())))

                generated_audio_segments = []
                sample_rate = None
                for segment in parse_speechtypes_text(gen_text):
                    slot = style_slots.get(segment["style"], fallback_slot)
                    audio_out, _, ref_text_out = infer(
                        audios[slot], ref_texts[slot] or "", segment["text"], remove_silence, 0, show_info=print
                    )  # show_info=print avoids pulling the page to the top while generating
                    sample_rate, audio_data = audio_out
                    generated_audio_segments.append(audio_data)
                    # Reuse the (possibly transcribed) reference text for later segments of this slot.
                    ref_texts[slot] = ref_text_out

                if not generated_audio_segments:
                    gr.Warning("No audio generated.")
                    return [None] + ref_text_updates()

                final_audio_data = np.concatenate(generated_audio_segments)
                return [(sample_rate, final_audio_data)] + ref_text_updates()

            generate_multistyle_btn.click(
                generate_multistyle_speech,
                inputs=[
                    gen_text_input_multistyle,
                ]
                + speech_type_names
                + speech_type_audios
                + speech_type_ref_texts
                + [
                    remove_silence_multistyle,
                ],
                outputs=[audio_output_multistyle] + speech_type_ref_texts,
            )

            def validate_speech_types(gen_text, regular_name, *args):
                """Enable the generate button only if every ``{Style}`` in the script is a defined name.

                This runs when the script text changes; it checks names only, not
                whether each type has reference audio.
                """
                available = {name for name in args[:max_speech_types] if name}
                if regular_name:
                    available.add(regular_name)
                used = {segment["style"] for segment in parse_speechtypes_text(gen_text)}
                return gr.update(interactive=used <= available)

            gen_text_input_multistyle.change(
                validate_speech_types,
                inputs=[gen_text_input_multistyle, regular_name] + speech_type_names,
                outputs=generate_multistyle_btn,
            )
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option(
    "--share",
    "-s",
    default=False,
    is_flag=True,
    help="Share the app via Gradio share link",
)
# An on/off pair: with a plain `is_flag=True, default=True` the API could never be disabled.
# `--api` / `-a` still enable it; `--no-api` now turns it off.
@click.option("--api/--no-api", "-a", default=True, help="Allow API access")
def main(port, host, share, api):
    """Launch the Gradio app from the command line.

    ``port``/``host`` are forwarded to ``launch`` (``None`` keeps Gradio's
    defaults). ``share`` requests a public share link. ``api`` controls both
    ``api_open`` on the queue and ``show_api`` on launch.
    """
    print("Starting app...")
    app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
if __name__ == "__main__":
    # On Spaces, launch with default settings; elsewhere, go through the click CLI in `main`.
    if USING_SPACES:
        app.queue().launch()
    else:
        main()