alex16052G commited on
Commit
43489d1
verified
1 Parent(s): 23c65b5

Update chat_ai.py

Browse files
Files changed (1) hide show
  1. chat_ai.py +45 -30
chat_ai.py CHANGED
@@ -1,8 +1,5 @@
1
  # text_to_speech_ai.py
2
 
3
- # ruff: noqa: E402
4
- # Above allows ruff to ignore E402: module level import not at top of file
5
-
6
  import re
7
  import tempfile
8
  import os
@@ -21,9 +18,6 @@ try:
21
  except ImportError:
22
  USING_SPACES = False
23
 
24
- # Definir el dispositivo global
25
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
-
27
  def gpu_decorator(func):
28
  if USING_SPACES:
29
  return spaces.GPU(func)
@@ -40,19 +34,37 @@ from f5_tts.infer.utils_infer import (
40
  save_spectrogram,
41
  )
42
 
43
- # Cargar el vocoder
44
- vocoder = load_vocoder().to(DEVICE)
 
45
 
46
- # Configuraci贸n y carga del modelo F5-TTS
47
- F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
48
- F5TTS_ema_model = load_model(
49
- DiT, F5TTS_model_cfg, str(cached_path("hf://jpgallegoar/F5-Spanish/model_1200000.safetensors"))
50
- ).to(DEVICE)
51
-
52
- # Cargar el modelo Whisper para transcripci贸n
53
- whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
54
- whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(DEVICE)
55
- whisper_model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  @gpu_decorator
58
  def infer(
@@ -61,34 +73,34 @@ def infer(
61
  """Genera el audio sintetizado a partir del texto utilizando la voz de referencia."""
62
  try:
63
  with torch.no_grad():
 
 
64
  # Preprocesar el audio de referencia y el texto de referencia
65
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text)
66
-
67
- ema_model = F5TTS_ema_model
68
-
69
  # Asegurar que el texto a generar est茅 correctamente formateado
70
  if not gen_text.startswith(" "):
71
  gen_text = " " + gen_text
72
  if not gen_text.endswith(". "):
73
  gen_text += ". "
74
-
75
  # El texto ingresado por el usuario se utiliza directamente sin modificaciones
76
  input_text = gen_text
77
-
78
  print(f"Texto para generar audio: {input_text}") # Debug: Verificar el texto
79
-
80
  # Procesar la inferencia para generar el audio
81
  final_wave, final_sample_rate, combined_spectrogram = infer_process(
82
- ref_audio.to(DEVICE),
83
  ref_text,
84
  input_text,
85
- ema_model,
86
  vocoder,
87
  cross_fade_duration=cross_fade_duration,
88
  speed=speed,
89
  progress=gr.Progress(),
90
  )
91
-
92
  # Eliminar silencios si est谩 activado
93
  if remove_silence:
94
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
@@ -96,21 +108,24 @@ def infer(
96
  remove_silence_for_generated_wav(f.name)
97
  final_wave, _ = torchaudio.load(f.name)
98
  final_wave = final_wave.squeeze().cpu().numpy()
99
-
100
  # Guardar el espectrograma (opcional)
101
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
102
  spectrogram_path = tmp_spectrogram.name
103
  save_spectrogram(combined_spectrogram, spectrogram_path)
104
-
105
  return (final_sample_rate, final_wave), spectrogram_path
106
  except Exception as e:
107
  # Log del error para depuraci贸n
108
  print(f"Error en infer: {e}")
109
  return None, None
110
 
 
111
  def transcribe_audio(audio_path):
112
  """Transcribe el audio de referencia usando el modelo Whisper en espa帽ol."""
113
  try:
 
 
114
  if not os.path.exists(audio_path):
115
  raise FileNotFoundError(f"Archivo de audio no encontrado: {audio_path}")
116
 
@@ -129,7 +144,7 @@ def transcribe_audio(audio_path):
129
  # Procesar el audio con el procesador de Whisper
130
  inputs = whisper_processor(audio.cpu().numpy(), sampling_rate=16000, return_tensors="pt")
131
 
132
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
133
 
134
  # Forzar el idioma a espa帽ol (usando el nombre en ingl茅s)
135
  forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="spanish", task="transcribe")
 
1
  # text_to_speech_ai.py
2
 
 
 
 
3
  import re
4
  import tempfile
5
  import os
 
18
  except ImportError:
19
  USING_SPACES = False
20
 
 
 
 
21
  def gpu_decorator(func):
22
  if USING_SPACES:
23
  return spaces.GPU(func)
 
34
  save_spectrogram,
35
  )
36
 
37
+ # Definir el dispositivo global (se usa solo dentro de las funciones)
38
+ def get_device():
39
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
 
41
+ @gpu_decorator
42
+ def load_models():
43
+ """Carga y devuelve los modelos necesarios."""
44
+ device = get_device()
45
+
46
+ # Cargar el vocoder y moverlo al dispositivo
47
+ vocoder = load_vocoder().to(device)
48
+
49
+ # Configuraci贸n y carga del modelo F5-TTS
50
+ F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
51
+ F5TTS_ema_model = load_model(
52
+ DiT, F5TTS_model_cfg, str(cached_path("hf://jpgallegoar/F5-Spanish/model_1200000.safetensors"))
53
+ ).to(device)
54
+
55
+ # Cargar el modelo Whisper para transcripci贸n
56
+ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
57
+ whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device)
58
+ whisper_model.eval()
59
+
60
+ return vocoder, F5TTS_ema_model, whisper_processor, whisper_model, device
61
+
62
+ # Cargar modelos una sola vez y almacenarlos en variables globales dentro de la funci贸n
63
+ # Esto se logra usando atributos de funci贸n
64
+ def get_models():
65
+ if not hasattr(get_models, "vocoder"):
66
+ get_models.vocoder, get_models.F5TTS_ema_model, get_models.whisper_processor, get_models.whisper_model, get_models.device = load_models()
67
+ return get_models.vocoder, get_models.F5TTS_ema_model, get_models.whisper_processor, get_models.whisper_model, get_models.device
68
 
69
  @gpu_decorator
70
  def infer(
 
73
  """Genera el audio sintetizado a partir del texto utilizando la voz de referencia."""
74
  try:
75
  with torch.no_grad():
76
+ vocoder, F5TTS_ema_model, _, _, device = get_models()
77
+
78
  # Preprocesar el audio de referencia y el texto de referencia
79
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text)
80
+
 
 
81
  # Asegurar que el texto a generar est茅 correctamente formateado
82
  if not gen_text.startswith(" "):
83
  gen_text = " " + gen_text
84
  if not gen_text.endswith(". "):
85
  gen_text += ". "
86
+
87
  # El texto ingresado por el usuario se utiliza directamente sin modificaciones
88
  input_text = gen_text
89
+
90
  print(f"Texto para generar audio: {input_text}") # Debug: Verificar el texto
91
+
92
  # Procesar la inferencia para generar el audio
93
  final_wave, final_sample_rate, combined_spectrogram = infer_process(
94
+ ref_audio.to(device),
95
  ref_text,
96
  input_text,
97
+ F5TTS_ema_model,
98
  vocoder,
99
  cross_fade_duration=cross_fade_duration,
100
  speed=speed,
101
  progress=gr.Progress(),
102
  )
103
+
104
  # Eliminar silencios si est谩 activado
105
  if remove_silence:
106
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
 
108
  remove_silence_for_generated_wav(f.name)
109
  final_wave, _ = torchaudio.load(f.name)
110
  final_wave = final_wave.squeeze().cpu().numpy()
111
+
112
  # Guardar el espectrograma (opcional)
113
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
114
  spectrogram_path = tmp_spectrogram.name
115
  save_spectrogram(combined_spectrogram, spectrogram_path)
116
+
117
  return (final_sample_rate, final_wave), spectrogram_path
118
  except Exception as e:
119
  # Log del error para depuraci贸n
120
  print(f"Error en infer: {e}")
121
  return None, None
122
 
123
+ @gpu_decorator
124
  def transcribe_audio(audio_path):
125
  """Transcribe el audio de referencia usando el modelo Whisper en espa帽ol."""
126
  try:
127
+ vocoder, F5TTS_ema_model, whisper_processor, whisper_model, device = get_models()
128
+
129
  if not os.path.exists(audio_path):
130
  raise FileNotFoundError(f"Archivo de audio no encontrado: {audio_path}")
131
 
 
144
  # Procesar el audio con el procesador de Whisper
145
  inputs = whisper_processor(audio.cpu().numpy(), sampling_rate=16000, return_tensors="pt")
146
 
147
+ inputs = {k: v.to(device) for k, v in inputs.items()}
148
 
149
  # Forzar el idioma a espa帽ol (usando el nombre en ingl茅s)
150
  forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="spanish", task="transcribe")