# Gradio demo for F5-TTS: downloads a checkpoint from the Hugging Face Hub and
# synthesizes speech from a reference audio clip plus target text, chunk by chunk.
import os
import re
import tempfile
from collections import OrderedDict
from importlib.resources import files

import click
import gradio as gr
import nltk
import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

from sentence_analyzer import SentenceAnalyzer

# Punkt sentence-tokenizer data used for sentence splitting.
nltk.download("punkt_tab")

# On Hugging Face Spaces the `spaces` package is available and GPU-bound functions
# must be wrapped with spaces.GPU; locally we fall back to a no-op decorator.
try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    if USING_SPACES:
        return spaces.GPU(func)
    else:
        return func

from f5_tts.api import F5TTS
from f5_tts.infer.utils_infer import preprocess_ref_audio_text

def load_f5tts():
    """Download the F5-TTS checkpoint from the Hugging Face Hub and build the model."""
    # Defaults point to the public F5TTS_Base checkpoint; both can be overridden via env vars.
    repo_id = os.getenv("MODEL_REPO_ID", "SWivid/F5-TTS")
    filename = os.getenv("MODEL_FILENAME", "F5TTS_Base/model_1200000.safetensors")
    token = os.getenv("HUGGINGFACE_TOKEN")

    if not token:
        raise ValueError("A variável de ambiente 'HUGGINGFACE_TOKEN' não foi definida.")

    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)

    # Architecture of F5TTS_Base, kept for reference (the F5TTS API resolves its own config).
    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

    return F5TTS(
        model_type="F5-TTS",
        ckpt_file=ckpt_path,
        vocab_file="/home/user/app/data/Emilia_ZH_EN_pinyin/vocab.txt",
        device="cuda" if torch.cuda.is_available() else "cpu",
        use_ema=True,
    )

# Download and initialize the default model once at startup (this also fails fast
# if HUGGINGFACE_TOKEN or the checkpoint is missing).
F5TTS_ema_model = load_f5tts()

# Cache of the last F5TTS instance built by infer(), keyed by checkpoint/device/EMA choice.
last_checkpoint = None
last_device = None
last_ema = None
tts_api = None
training_process = None

@gpu_decorator
def infer(
    project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema, speed, seed, remove_silence
):
    global last_checkpoint, last_device, tts_api, last_ema

    if not os.path.isfile(file_checkpoint):
        return None, None, "checkpoint not found!"

    # If a training run is active, keep inference on the CPU to avoid competing for the GPU.
    if training_process is not None:
        device_test = "cpu"
    else:
        device_test = None

    # Rebuild the F5TTS instance only when the checkpoint, device or EMA setting changed.
    if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema or tts_api is None:
        if last_checkpoint != file_checkpoint:
            last_checkpoint = file_checkpoint
        if last_device != device_test:
            last_device = device_test
        if last_ema != use_ema:
            last_ema = use_ema
        vocab_file = "/home/user/app/data/Emilia_ZH_EN_pinyin/vocab.txt"
        tts_api = F5TTS(
            model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
        )
        print("update >> ", device_test, file_checkpoint, use_ema)

    # Synthesize to a temporary wav file and hand the path back to Gradio.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        tts_api.infer(
            gen_text=gen_text.lower().strip(),
            ref_text=ref_text.lower().strip(),
            ref_file=ref_audio,
            nfe_step=nfe_step,
            file_wave=f.name,
            speed=speed,
            seed=seed,
            remove_silence=remove_silence,
        )
    return f.name, tts_api.device, str(tts_api.seed)

custom_css = """ |
|
#sentences-container { |
|
border: 1px solid #ddd; |
|
border-radius: 4px; |
|
padding: 10px; |
|
margin-bottom: 10px; |
|
} |
|
.sentence-box { |
|
border: 1px solid #eee; |
|
padding: 5px; |
|
margin-bottom: 5px; |
|
border-radius: 4px; |
|
background-color: #f9f9f9; |
|
} |
|
""" |
|
|
|
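# Gradio UI (labels and messages are in Brazilian Portuguese).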
with gr.Blocks(css=custom_css) as app:
    with gr.Tabs():
        with gr.Tab("TTS Básico"):
            gr.Markdown("# TTS Básico com F5-TTS")

            ref_audio_input = gr.Audio(label="Áudio de Referência", type="filepath")
            gen_text_input = gr.Textbox(label="Texto para Gerar", lines=10)
            generate_btn = gr.Button("Sintetizar", variant="primary")

            gr.Markdown("### Configurações Avançadas")
            with gr.Accordion("Expandir Configurações Avançadas", open=False):
                ref_text_input = gr.Textbox(
                    label="Texto de Referência",
                    info="Deixe em branco para transcrever automaticamente o áudio de referência. Se você inserir texto, ele substituirá a transcrição automática.",
                    lines=2,
                )
                remove_silence = gr.Checkbox(
                    label="Remover Silêncios",
                    info="O modelo tende a produzir silêncios, especialmente em áudios mais longos. Podemos remover manualmente os silêncios, se necessário. Isso também aumentará o tempo de geração.",
                    value=False,
                )
                speed_slider = gr.Slider(
                    label="Velocidade",
                    minimum=0.3,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    info="Ajuste a velocidade do áudio.",
                )
                cross_fade_duration_slider = gr.Slider(
                    label="Duração do Cross-fade (s)",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.15,
                    step=0.01,
                    info="Defina a duração do cross-fade entre os clipes de áudio.",
                )
                chunk_size_slider = gr.Slider(
                    label="Número de Sentenças por Chunk",
                    minimum=1,
                    maximum=10,
                    value=1,
                    step=1,
                    info="Defina quantas sentenças serão processadas em cada chunk.",
                )
                nfe_slider = gr.Slider(
                    label="NFE",
                    minimum=16,
                    maximum=64,
                    value=32,
                    step=1,
                    info="Ajuste NFE Step.",
                )
                seed_input = gr.Number(label="Seed", value=-1, minimum=-1)

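            # NOTE: sentence_analyzer is a local helper module (not shown here). It is assumed
            # to wrap NLTK's Punkt tokenizer -- hence nltk.download("punkt_tab") above -- and to
            # expose split_into_sentences(text) -> list[str].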
            analyzer = SentenceAnalyzer()

            @gpu_decorator
            def process_chunks(
                ref_audio_input,
                ref_text_input,
                gen_text_input,
                remove_silence,
                cross_fade_duration_slider,
                speed_slider,
                nfe_slider,
                chunk_size_slider,
                seed_input,
            ):
                # Split the target text into sentences and group them into chunks.
                sentences = analyzer.split_into_sentences(gen_text_input)
                chunk_size = int(chunk_size_slider)
                chunks = [
                    " ".join(sentences[i : i + chunk_size])
                    for i in range(0, len(sentences), chunk_size)
                ]

                # Synthesize each chunk with the fixed F5-TTS checkpoint.
                audio_segments = []
                seed_used = None
                for chunk in chunks:
                    audio_file, device_used, seed_used = infer(
                        "Emilia_ZH_EN_pinyin",
                        "/home/user/app/model_1200000.safetensors",
                        "F5-TTS",
                        ref_text_input,
                        ref_audio_input,
                        chunk,
                        nfe_slider,
                        True,
                        speed_slider,
                        seed_input,
                        remove_silence,
                    )
                    if audio_file is None:
                        gr.Warning("Checkpoint não encontrado.")
                        return None, gr.update(), None
                    audio_data, _ = torchaudio.load(audio_file)
                    audio_segments.append(audio_data.squeeze().cpu().numpy())

                if audio_segments:
                    # Join the chunks with a linear cross-fade of the requested duration
                    # (the model outputs 24 kHz audio).
                    sample_rate = 24000
                    final_audio_data = audio_segments[0]
                    for segment in audio_segments[1:]:
                        fade = int(cross_fade_duration_slider * sample_rate)
                        fade = min(fade, len(final_audio_data), len(segment))
                        if fade <= 0:
                            final_audio_data = np.concatenate([final_audio_data, segment])
                            continue
                        fade_out = np.linspace(1.0, 0.0, fade)
                        fade_in = np.linspace(0.0, 1.0, fade)
                        overlap = final_audio_data[-fade:] * fade_out + segment[:fade] * fade_in
                        final_audio_data = np.concatenate(
                            [final_audio_data[:-fade], overlap, segment[fade:]]
                        )
                    return (
                        (sample_rate, final_audio_data),
                        gr.update(value=ref_text_input),
                        seed_used,
                    )
                else:
                    gr.Warning("Nenhum áudio gerado.")
                    return None, gr.update(), None

gr.Markdown("### Resultados") |
|
audio_output = gr.Audio(label="Áudio Sintetizado") |
|
seed_output = gr.Text(label="Seed usada:") |
|
|
|
|
|
            generate_btn.click(
                process_chunks,
                inputs=[
                    ref_audio_input,
                    ref_text_input,
                    gen_text_input,
                    remove_silence,
                    cross_fade_duration_slider,
                    speed_slider,
                    nfe_slider,
                    chunk_size_slider,
                    seed_input,
                ],
                outputs=[
                    audio_output,
                    ref_text_input,
                    seed_output,
                ],
            )

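# Command-line entry point. Typical local usage (script name assumed to be app.py):
#   python app.py --port 7860 --host 0.0.0.0 --share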
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option(
    "--share",
    "-s",
    default=False,
    is_flag=True,
    help="Share the app via Gradio share link",
)
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
def main(port, host, share, api):
    global app
    print("Starting app...")
    app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)


if __name__ == "__main__":
    if not USING_SPACES:
        main()
    else:
        app.queue().launch()