alex16052G commited on
Commit
23c65b5
verified
1 Parent(s): 019a00a

Update chat_ai.py

Browse files
Files changed (1) hide show
  1. chat_ai.py +56 -64
chat_ai.py CHANGED
@@ -21,6 +21,9 @@ try:
21
  except ImportError:
22
  USING_SPACES = False
23
 
 
 
 
24
  def gpu_decorator(func):
25
  if USING_SPACES:
26
  return spaces.GPU(func)
@@ -38,26 +41,18 @@ from f5_tts.infer.utils_infer import (
38
  )
39
 
40
  # Cargar el vocoder
41
- vocoder = load_vocoder()
42
 
43
  # Configuraci贸n y carga del modelo F5-TTS
44
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
45
  F5TTS_ema_model = load_model(
46
  DiT, F5TTS_model_cfg, str(cached_path("hf://jpgallegoar/F5-Spanish/model_1200000.safetensors"))
47
- )
48
-
49
- # Eliminamos la carga global de WhisperProcessor y WhisperForConditionalGeneration
50
- # Estos se cargar谩n dentro de la funci贸n de transcripci贸n
51
 
52
- @gr.Caching.cache # Ajusta seg煤n tu versi贸n de Gradio
53
- def get_whisper_models():
54
- """Carga y retorna los modelos Whisper y el procesador."""
55
- processor = WhisperProcessor.from_pretrained("openai/whisper-base")
56
- model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
57
- model.eval()
58
- if torch.cuda.is_available():
59
- model.to("cuda")
60
- return processor, model
61
 
62
  @gpu_decorator
63
  def infer(
@@ -65,54 +60,54 @@ def infer(
65
  ):
66
  """Genera el audio sintetizado a partir del texto utilizando la voz de referencia."""
67
  try:
68
- # Preprocesar el audio de referencia y el texto de referencia
69
- ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text)
70
-
71
- ema_model = F5TTS_ema_model
72
-
73
- # Asegurar que el texto a generar est茅 correctamente formateado
74
- if not gen_text.startswith(" "):
75
- gen_text = " " + gen_text
76
- if not gen_text.endswith(". "):
77
- gen_text += ". "
78
-
79
- # El texto ingresado por el usuario se utiliza directamente sin modificaciones
80
- input_text = gen_text
81
-
82
- print(f"Texto para generar audio: {input_text}") # Debug: Verificar el texto
83
-
84
- # Procesar la inferencia para generar el audio
85
- final_wave, final_sample_rate, combined_spectrogram = infer_process(
86
- ref_audio,
87
- ref_text,
88
- input_text,
89
- ema_model,
90
- vocoder,
91
- cross_fade_duration=cross_fade_duration,
92
- speed=speed,
93
- progress=gr.Progress(),
94
- )
 
95
 
96
- # Eliminar silencios si est谩 activado
97
- if remove_silence:
98
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
99
- sf.write(f.name, final_wave, final_sample_rate)
100
- remove_silence_for_generated_wav(f.name)
101
- final_wave, _ = torchaudio.load(f.name)
102
- final_wave = final_wave.squeeze().cpu().numpy()
103
 
104
- # Guardar el espectrograma (opcional)
105
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
106
- spectrogram_path = tmp_spectrogram.name
107
- save_spectrogram(combined_spectrogram, spectrogram_path)
108
 
109
- return (final_sample_rate, final_wave), spectrogram_path
110
  except Exception as e:
111
  # Log del error para depuraci贸n
112
  print(f"Error en infer: {e}")
113
  return None, None
114
 
115
- @gpu_decorator
116
  def transcribe_audio(audio_path):
117
  """Transcribe el audio de referencia usando el modelo Whisper en espa帽ol."""
118
  try:
@@ -131,23 +126,20 @@ def transcribe_audio(audio_path):
131
  if audio.ndim > 1:
132
  audio = torch.mean(audio, dim=0)
133
 
134
- # Cargar los modelos Whisper
135
- whisper_processor, whisper_model = get_whisper_models()
136
-
137
  # Procesar el audio con el procesador de Whisper
138
  inputs = whisper_processor(audio.cpu().numpy(), sampling_rate=16000, return_tensors="pt")
139
 
140
- if torch.cuda.is_available():
141
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
142
 
143
  # Forzar el idioma a espa帽ol (usando el nombre en ingl茅s)
144
  forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="spanish", task="transcribe")
145
 
146
  # Generar la transcripci贸n
147
- predicted_ids = whisper_model.generate(
148
- inputs["input_features"],
149
- forced_decoder_ids=forced_decoder_ids
150
- )
 
151
  transcription = whisper_processor.decode(predicted_ids[0], skip_special_tokens=True)
152
 
153
  print(f"Transcripci贸n: {transcription}") # Debug: Verificar la transcripci贸n
@@ -155,7 +147,7 @@ def transcribe_audio(audio_path):
155
  return transcription
156
  except Exception as e:
157
  print(f"Error en transcribe_audio: {e}")
158
- return "Error al transcribir el audio de referencia."
159
 
160
  def transcribe_and_update(audio_path):
161
  """Transcribe el audio de referencia y devuelve el texto transcrito."""
 
21
  except ImportError:
22
  USING_SPACES = False
23
 
24
+ # Definir el dispositivo global
25
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
  def gpu_decorator(func):
28
  if USING_SPACES:
29
  return spaces.GPU(func)
 
41
  )
42
 
43
  # Cargar el vocoder
44
+ vocoder = load_vocoder().to(DEVICE)
45
 
46
  # Configuraci贸n y carga del modelo F5-TTS
47
  F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
48
  F5TTS_ema_model = load_model(
49
  DiT, F5TTS_model_cfg, str(cached_path("hf://jpgallegoar/F5-Spanish/model_1200000.safetensors"))
50
+ ).to(DEVICE)
 
 
 
51
 
52
+ # Cargar el modelo Whisper para transcripci贸n
53
+ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
54
+ whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(DEVICE)
55
+ whisper_model.eval()
 
 
 
 
 
56
 
57
  @gpu_decorator
58
  def infer(
 
60
  ):
61
  """Genera el audio sintetizado a partir del texto utilizando la voz de referencia."""
62
  try:
63
+ with torch.no_grad():
64
+ # Preprocesar el audio de referencia y el texto de referencia
65
+ ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text)
66
+
67
+ ema_model = F5TTS_ema_model
68
+
69
+ # Asegurar que el texto a generar est茅 correctamente formateado
70
+ if not gen_text.startswith(" "):
71
+ gen_text = " " + gen_text
72
+ if not gen_text.endswith(". "):
73
+ gen_text += ". "
74
+
75
+ # El texto ingresado por el usuario se utiliza directamente sin modificaciones
76
+ input_text = gen_text
77
+
78
+ print(f"Texto para generar audio: {input_text}") # Debug: Verificar el texto
79
+
80
+ # Procesar la inferencia para generar el audio
81
+ final_wave, final_sample_rate, combined_spectrogram = infer_process(
82
+ ref_audio.to(DEVICE),
83
+ ref_text,
84
+ input_text,
85
+ ema_model,
86
+ vocoder,
87
+ cross_fade_duration=cross_fade_duration,
88
+ speed=speed,
89
+ progress=gr.Progress(),
90
+ )
91
 
92
+ # Eliminar silencios si est谩 activado
93
+ if remove_silence:
94
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
95
+ sf.write(f.name, final_wave.cpu().numpy(), final_sample_rate)
96
+ remove_silence_for_generated_wav(f.name)
97
+ final_wave, _ = torchaudio.load(f.name)
98
+ final_wave = final_wave.squeeze().cpu().numpy()
99
 
100
+ # Guardar el espectrograma (opcional)
101
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
102
+ spectrogram_path = tmp_spectrogram.name
103
+ save_spectrogram(combined_spectrogram, spectrogram_path)
104
 
105
+ return (final_sample_rate, final_wave), spectrogram_path
106
  except Exception as e:
107
  # Log del error para depuraci贸n
108
  print(f"Error en infer: {e}")
109
  return None, None
110
 
 
111
  def transcribe_audio(audio_path):
112
  """Transcribe el audio de referencia usando el modelo Whisper en espa帽ol."""
113
  try:
 
126
  if audio.ndim > 1:
127
  audio = torch.mean(audio, dim=0)
128
 
 
 
 
129
  # Procesar el audio con el procesador de Whisper
130
  inputs = whisper_processor(audio.cpu().numpy(), sampling_rate=16000, return_tensors="pt")
131
 
132
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
 
133
 
134
  # Forzar el idioma a espa帽ol (usando el nombre en ingl茅s)
135
  forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="spanish", task="transcribe")
136
 
137
  # Generar la transcripci贸n
138
+ with torch.no_grad():
139
+ predicted_ids = whisper_model.generate(
140
+ inputs["input_features"],
141
+ forced_decoder_ids=forced_decoder_ids
142
+ )
143
  transcription = whisper_processor.decode(predicted_ids[0], skip_special_tokens=True)
144
 
145
  print(f"Transcripci贸n: {transcription}") # Debug: Verificar la transcripci贸n
 
147
  return transcription
148
  except Exception as e:
149
  print(f"Error en transcribe_audio: {e}")
150
+ return None
151
 
152
  def transcribe_and_update(audio_path):
153
  """Transcribe el audio de referencia y devuelve el texto transcrito."""