Update chat_ai.py
chat_ai.py (+45, -30)
@@ -1,8 +1,5 @@
 # text_to_speech_ai.py
 
-# ruff: noqa: E402
-# Above allows ruff to ignore E402: module level import not at top of file
-
 import re
 import tempfile
 import os
@@ -21,9 +18,6 @@ try:
 except ImportError:
     USING_SPACES = False
 
-# Define the global device
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 def gpu_decorator(func):
     if USING_SPACES:
         return spaces.GPU(func)
@@ -40,19 +34,37 @@ from f5_tts.infer.utils_infer import (
     save_spectrogram,
 )
 
-#
-
-
-
-
-
-
-
-
-
-
-
+# Define the global device (used only inside the functions)
+def get_device():
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+@gpu_decorator
+def load_models():
+    """Load and return the required models."""
+    device = get_device()
+
+    # Load the vocoder and move it to the device
+    vocoder = load_vocoder().to(device)
+
+    # Configure and load the F5-TTS model
+    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+    F5TTS_ema_model = load_model(
+        DiT, F5TTS_model_cfg, str(cached_path("hf://jpgallegoar/F5-Spanish/model_1200000.safetensors"))
+    ).to(device)
+
+    # Load the Whisper model for transcription
+    whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
+    whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device)
+    whisper_model.eval()
+
+    return vocoder, F5TTS_ema_model, whisper_processor, whisper_model, device
+
+# Load the models only once and store them in variables scoped to the function
+# This is achieved using function attributes
+def get_models():
+    if not hasattr(get_models, "vocoder"):
+        get_models.vocoder, get_models.F5TTS_ema_model, get_models.whisper_processor, get_models.whisper_model, get_models.device = load_models()
+    return get_models.vocoder, get_models.F5TTS_ema_model, get_models.whisper_processor, get_models.whisper_model, get_models.device
 
 @gpu_decorator
 def infer(
@@ -61,34 +73,34 @@ def infer(
     """Generate the synthesized audio from the text using the reference voice."""
     try:
         with torch.no_grad():
+            vocoder, F5TTS_ema_model, _, _, device = get_models()
+
             # Preprocess the reference audio and the reference text
             ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text)
-
-            ema_model = F5TTS_ema_model
-
+
             # Make sure the text to generate is correctly formatted
             if not gen_text.startswith(" "):
                 gen_text = " " + gen_text
             if not gen_text.endswith(". "):
                 gen_text += ". "
-
+
             # The text entered by the user is used directly, without modification
             input_text = gen_text
-
+
             print(f"Texto para generar audio: {input_text}")  # Debug: verify the text
-
+
             # Run inference to generate the audio
            final_wave, final_sample_rate, combined_spectrogram = infer_process(
-                ref_audio.to(DEVICE),
+                ref_audio.to(device),
                 ref_text,
                 input_text,
-                ema_model,
+                F5TTS_ema_model,
                 vocoder,
                 cross_fade_duration=cross_fade_duration,
                 speed=speed,
                 progress=gr.Progress(),
             )
-
+
             # Remove silences if enabled
             if remove_silence:
                 with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
@@ -96,21 +108,24 @@ def infer(
                     remove_silence_for_generated_wav(f.name)
                     final_wave, _ = torchaudio.load(f.name)
                 final_wave = final_wave.squeeze().cpu().numpy()
-
+
             # Save the spectrogram (optional)
             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
                 spectrogram_path = tmp_spectrogram.name
                 save_spectrogram(combined_spectrogram, spectrogram_path)
-
+
             return (final_sample_rate, final_wave), spectrogram_path
     except Exception as e:
         # Log the error for debugging
        print(f"Error en infer: {e}")
         return None, None
 
+@gpu_decorator
 def transcribe_audio(audio_path):
     """Transcribe the reference audio using the Whisper model in Spanish."""
     try:
+        vocoder, F5TTS_ema_model, whisper_processor, whisper_model, device = get_models()
+
         if not os.path.exists(audio_path):
             raise FileNotFoundError(f"Archivo de audio no encontrado: {audio_path}")
 
@@ -129,7 +144,7 @@ def transcribe_audio(audio_path):
         # Process the audio with the Whisper processor
         inputs = whisper_processor(audio.cpu().numpy(), sampling_rate=16000, return_tensors="pt")
 
-        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+        inputs = {k: v.to(device) for k, v in inputs.items()}
 
         # Force the language to Spanish (using the English name)
         forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="spanish", task="transcribe")
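The substance of the change: nothing touches the GPU at import time anymore. The removed module-level DEVICE global and model loads are replaced by get_device()/load_models(), and get_models() loads everything on first use, caching the results as attributes on the function object. A minimal, self-contained sketch of that function-attribute caching pattern, with a hypothetical load_expensive_model() standing in for the real loaders:

def load_expensive_model():
    # Hypothetical stand-in for loading the vocoder, F5-TTS, and Whisper.
    print("loading...")  # executes only on the first get_model() call
    return {"weights": "..."}

def get_model():
    # First call: load once and stash the result on the function object itself.
    if not hasattr(get_model, "model"):
        get_model.model = load_expensive_model()
    # Later calls return the cached instance without reloading.
    return get_model.model

m1 = get_model()  # prints "loading..."
m2 = get_model()  # silent: served from the cache
assert m1 is m2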
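The same run-once behavior is often written with functools.lru_cache instead of hand-rolled function attributes. This is an equivalent sketch, not what the commit does; load_models here is a stub standing in for the loader defined in the diff:

from functools import lru_cache

def load_models():
    # Stub for the diff's load_models(), which returns
    # (vocoder, F5TTS_ema_model, whisper_processor, whisper_model, device).
    return ("vocoder", "ema_model", "processor", "whisper", "cpu")

@lru_cache(maxsize=None)  # the body runs once; every later call returns the cached tuple
def get_models():
    return load_models()

assert get_models() is get_models()  # one load, one cached object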
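With the refactor, callers never load models explicitly: whichever of infer or transcribe_audio runs first triggers get_models() and pays the one-time loading cost. A hypothetical usage sketch; the infer signature is inferred from the parameters its body uses, and transcribe_audio is assumed to return the transcription string:

from chat_ai import infer, transcribe_audio  # module name taken from the commit's filename

ref_audio = "ref.wav"                    # hypothetical reference clip on disk
ref_text = transcribe_audio(ref_audio)   # first call loads all models once

audio, spectrogram_path = infer(
    ref_audio,
    ref_text,
    "Hola, esto es una prueba.",         # text to synthesize
    remove_silence=False,
    cross_fade_duration=0.15,
    speed=1.0,
)
if audio is not None:                    # infer returns (None, None) on error
    sample_rate, wave = audio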