Marcos12886 commited on
Commit
40b8b4f
·
verified ·
1 Parent(s): 0f202d9

Comentarios

Browse files
Files changed (1) hide show
  1. app.py +132 -132
app.py CHANGED
@@ -1,128 +1,128 @@
1
  import os
2
  import torch
3
- import torch.nn.functional as F
4
- import gradio as gr
5
- from huggingface_hub import InferenceClient
6
- from model import predict_params, AudioDataset
7
- import torchaudio
8
 
9
- token = os.getenv("HF_TOKEN")
10
- client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  model_class, id2label_class = predict_params(
13
- model_path="A-POR-LOS-8000/distilhubert-finetuned-mixed-data2",
14
- dataset_path="data/mixed_data",
15
- filter_white_noise=True,
16
- undersample_normal=True
17
  )
18
  model_mon, id2label_mon = predict_params(
19
- model_path="A-POR-LOS-8000/distilhubert-finetuned-cry-detector",
20
- dataset_path="data/baby_cry_detection",
21
- filter_white_noise=False,
22
- undersample_normal=False
23
  )
24
 
25
  def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal=False):
26
- model.to(device)
27
- model.eval()
28
- audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise, undersample_normal)
29
- processed_audio = audio_dataset.preprocess_audio(audiopath)
30
- inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
31
- with torch.no_grad():
32
- outputs = model(**inputs)
33
- logits = outputs.logits
34
- return logits
35
 
36
  def predict(audio_path_pred):
37
- with torch.no_grad():
38
- logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True, undersample_normal=False)
39
- predicted_class_ids_class = torch.argmax(logits, dim=-1).item()
40
- label_class = id2label_class[predicted_class_ids_class]
41
- label_mapping = {0: 'Cansancio/Incomodidad', 1: 'Dolor', 2: 'Hambre', 3: 'Problemas para respirar'}
42
- label_class = label_mapping.get(predicted_class_ids_class, label_class)
43
  return f"""
44
  <div style='text-align: center; font-size: 1.5em'>
45
  <span style='display: inline-block; min-width: 300px;'>{label_class}</span>
46
  </div>
47
- """
48
 
49
  def predict_stream(audio_path_stream):
50
- with torch.no_grad():
51
- logits = call(audio_path_stream, model=model_mon, dataset_path="data/baby_cry_detection", filter_white_noise=False, undersample_normal=False)
52
- probabilities = F.softmax(logits, dim=-1)
53
- crying_probabilities = probabilities[:, 1]
54
- avg_crying_probability = crying_probabilities.mean()*100
55
- if avg_crying_probability < 15:
56
- label_class = predict(audio_path_stream)
57
- return f"Está llorando por: {label_class}"
58
  else:
59
- return "No está llorando"
60
 
61
  def decibelios(audio_path_stream):
62
- waveform, _ = torchaudio.load(audio_path_stream)
63
- rms = torch.sqrt(torch.mean(torch.square(waveform)))
64
- db_level = 20 * torch.log10(rms + 1e-6).item()
65
- min_db = -80
66
- max_db = 0
67
- scaled_db_level = (db_level - min_db) / (max_db - min_db)
68
- normalized_db_level = scaled_db_level * 100
69
- return normalized_db_level
70
 
71
  def mostrar_decibelios(audio_path_stream, visual_threshold):
72
- db_level = decibelios(audio_path_stream)
73
- if db_level > visual_threshold:
74
- status = "Prediciendo..."
75
  else:
76
- status = "Esperando..."
77
  return f"""
78
  <div style='text-align: center; font-size: 1.5em'>
79
  <span>{status}</span>
80
  <span style='display: inline-block; min-width: 120px;'>Decibelios: {db_level:.2f}</span>
81
  </div>
82
- """
83
 
84
  def predict_stream_decib(audio_path_stream, visual_threshold):
85
- db_level = decibelios(audio_path_stream)
86
- if db_level > visual_threshold:
87
- prediction = display_prediction_stream(audio_path_stream)
88
  else:
89
- prediction = ""
90
  return f"""
91
  <div style='text-align: center; font-size: 1.5em; min-height: 2em;'>
92
  <span style='display: inline-block; min-width: 300px;'>{prediction}</span>
93
  </div>
94
- """
95
 
96
  def chatbot_config(message, history: list[tuple[str, str]]):
97
- system_message = "You are a Chatbot specialized in baby health and care."
98
- max_tokens = 512
99
- temperature = 0.5
100
- top_p = 0.95
101
- messages = [{"role": "system", "content": system_message}]
102
- for val in history:
103
  if val[0]:
104
- messages.append({"role": "user", "content": val[0]})
105
  if val[1]:
106
- messages.append({"role": "assistant", "content": val[1]})
107
- messages.append({"role": "user", "content": message})
108
- response = ""
109
  for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
110
- token = message_response.choices[0].delta.content
111
- response += token
112
- yield response
113
 
114
  def cambiar_pestaña():
115
- return gr.update(visible=False), gr.update(visible=True)
116
 
117
  def display_prediction(audio, prediction_func):
118
- prediction = prediction_func(audio)
119
- return f"<h3 style='text-align: center; font-size: 1.5em;'>{prediction}</h3>"
120
 
121
  def display_prediction_wrapper(audio):
122
- return display_prediction(audio, predict)
123
 
124
  def display_prediction_stream(audio):
125
- return display_prediction(audio, predict_stream)
126
 
127
  my_theme = gr.themes.Soft(
128
  primary_hue="emerald",
@@ -200,68 +200,68 @@ with gr.Blocks(theme=my_theme) as demo:
200
  "<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no; y si está llorando, predice automáticamente la causa. Dándote la tranquilidad de saber siempre qué pasa con tu pequeño, ahorrándote tiempo y horas de sueño.</p>"
201
  )
202
  boton_inicial = gr.Button("¡Prueba nuestros modelos!")
203
- with gr.Column(visible=False) as chatbot:
204
- gr.Markdown("<h2>Asistente</h2>")
205
- gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Pregunta a nuestro asistente cualquier duda que tengas sobre el cuidado de tu bebé</h4>")
206
- gr.ChatInterface(
207
- chatbot_config,
208
- theme=my_theme,
209
- retry_btn=None,
210
- undo_btn=None,
211
- clear_btn="Limpiar 🗑️",
212
- autofocus=True,
213
- fill_height=True,
214
- )
215
- with gr.Row():
216
- with gr.Column():
217
- boton_predictor = gr.Button("Predictor")
218
- with gr.Column():
219
- boton_monitor = gr.Button("Monitor")
220
- with gr.Column(visible=False) as pag_predictor:
221
- gr.Markdown("<h2>Predictor</h2>")
222
- gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Descubre por qué tu bebé está llorando</h4>")
223
  audio_input = gr.Audio(
224
- min_length=1.0,
225
- format="wav",
226
- label="Baby recorder",
227
- type="filepath",
228
- )
229
- prediction_output = gr.Markdown()
230
  gr.Button("¿Por qué llora?").click(
231
- display_prediction_wrapper,
232
- inputs=audio_input,
233
- outputs=gr.Markdown()
234
- )
235
- gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
236
- with gr.Column(visible=False) as pag_monitor:
237
- gr.Markdown("<h2>Monitor</h2>")
238
- gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Detecta en tiempo real si tu bebé está llorando y por qué</h4>")
239
  audio_stream = gr.Audio(
240
- format="wav",
241
- label="Baby recorder",
242
- type="filepath",
243
- streaming=True
244
  )
245
  threshold_db = gr.Slider(
246
- minimum=0,
247
- maximum=120,
248
- step=1,
249
- value=20,
250
- label="Umbral de ruido para activar la predicción:"
251
  )
252
- volver = gr.Button("Volver")
253
  audio_stream.stream(
254
- mostrar_decibelios,
255
- inputs=[audio_stream, threshold_db],
256
- outputs=gr.HTML()
257
  )
258
  audio_stream.stream(
259
- predict_stream_decib,
260
- inputs=[audio_stream, threshold_db],
261
- outputs=gr.HTML()
262
  )
263
- volver.click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
264
- boton_inicial.click(cambiar_pestaña, outputs=[inicial, chatbot])
265
- boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
266
- boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
267
- demo.launch(share=True)
 
1
  import os
2
  import torch
3
+ import torch.nn.functional as F # Importa la API funcional de torch, incluyendo softmax
4
+ import gradio as gr # Gradio para crear interfaces web
5
+ from huggingface_hub import InferenceClient # Cliente de inferencia para acceder a modelos desde Hugging Face Hub
6
+ from model import predict_params, AudioDataset # Importaciones personalizadas: carga de modelo y procesamiento de audio
7
+ import torchaudio # Librería para procesamiento de audio
8
 
9
+ token = os.getenv("HF_TOKEN") # Obtiene el token de la API de Hugging Face desde las variables de entorno
10
+ client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token) # Inicializa el cliente de Hugging Face con el modelo y el token
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Verifica si hay GPU disponible, de lo contrario usa CPU
12
  model_class, id2label_class = predict_params(
13
+ model_path="A-POR-LOS-8000/distilhubert-finetuned-mixed-data2", # Ruta al modelo para la predicción de clases de llanto
14
+ dataset_path="data/mixed_data", # Ruta al dataset de audio mixto
15
+ filter_white_noise=True, # Indica que se filtrará el ruido blanco
16
+ undersample_normal=True # Activa el submuestreo para equilibrar clases
17
  )
18
  model_mon, id2label_mon = predict_params(
19
+ model_path="A-POR-LOS-8000/distilhubert-finetuned-cry-detector", # Ruta al modelo detector de llanto
20
+ dataset_path="data/baby_cry_detection", # Ruta al dataset de detección de llanto
21
+ filter_white_noise=False, # No filtrar ruido blanco en este modelo
22
+ undersample_normal=False # No submuestrear datos
23
  )
24
 
25
  def call(audiopath, model, dataset_path, filter_white_noise, undersample_normal=False):
26
+ model.to(device) # Envía el modelo a la GPU (o CPU si no hay GPU disponible)
27
+ model.eval() # Pone el modelo en modo de evaluación (desactiva dropout, batchnorm)
28
+ audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise, undersample_normal) # Carga el dataset de audio con parámetros específicos
29
+ processed_audio = audio_dataset.preprocess_audio(audiopath) # Preprocesa el audio según la configuración del dataset
30
+ inputs = {"input_values": processed_audio.to(device).unsqueeze(0)} # Prepara los datos para el modelo (envía a GPU y ajusta dimensiones)
31
+ with torch.no_grad(): # Desactiva el cálculo del gradiente para ahorrar memoria
32
+ outputs = model(**inputs) # Realiza la inferencia con el modelo
33
+ logits = outputs.logits # Obtiene las predicciones del modelo
34
+ return logits # Retorna los logits (valores sin procesar)
35
 
36
  def predict(audio_path_pred):
37
+ with torch.no_grad(): # Desactiva gradientes para la inferencia
38
+ logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True, undersample_normal=False) # Llama a la función de inferencia
39
+ predicted_class_ids_class = torch.argmax(logits, dim=-1).item() # Obtiene la clase predicha a partir de los logits
40
+ label_class = id2label_class[predicted_class_ids_class] # Convierte el ID de clase en una etiqueta de texto
41
+ label_mapping = {0: 'Cansancio/Incomodidad', 1: 'Dolor', 2: 'Hambre', 3: 'Problemas para respirar'} # Mapea las etiquetas
42
+ label_class = label_mapping.get(predicted_class_ids_class, label_class) # Si hay una etiqueta personalizada, la usa
43
  return f"""
44
  <div style='text-align: center; font-size: 1.5em'>
45
  <span style='display: inline-block; min-width: 300px;'>{label_class}</span>
46
  </div>
47
+ """ # Retorna el resultado formateado para mostrar en la interfaz
48
 
49
  def predict_stream(audio_path_stream):
50
+ with torch.no_grad(): # Desactiva gradientes durante la inferencia
51
+ logits = call(audio_path_stream, model=model_mon, dataset_path="data/baby_cry_detection", filter_white_noise=False, undersample_normal=False) # Llama al modelo de detección de llanto
52
+ probabilities = F.softmax(logits, dim=-1) # Aplica softmax para convertir los logits en probabilidades
53
+ crying_probabilities = probabilities[:, 1] # Obtiene las probabilidades asociadas al llanto
54
+ avg_crying_probability = crying_probabilities.mean()*100 # Calcula la probabilidad promedio de llanto
55
+ if avg_crying_probability < 15: # Si la probabilidad de llanto es menor a un 15%, se predice la razón
56
+ label_class = predict(audio_path_stream) # Llama a la predicción para determinar la razón del llanto
57
+ return f"Está llorando por: {label_class}" # Retorna el resultado indicando por qué llora
58
  else:
59
+ return "No está llorando" # Si la probabilidad es mayor, indica que no está llorando
60
 
61
  def decibelios(audio_path_stream):
62
+ waveform, _ = torchaudio.load(audio_path_stream) # Carga el audio y su forma de onda
63
+ rms = torch.sqrt(torch.mean(torch.square(waveform))) # Calcula el valor RMS del audio
64
+ db_level = 20 * torch.log10(rms + 1e-6).item() # Convierte el RMS en decibelios (añade un pequeño valor para evitar log(0))
65
+ min_db = -80 # Nivel mínimo de decibelios esperado
66
+ max_db = 0 # Nivel máximo de decibelios esperado
67
+ scaled_db_level = (db_level - min_db) / (max_db - min_db) # Escala el nivel de decibelios a un rango entre 0 y 1
68
+ normalized_db_level = scaled_db_level * 100 # Escala el nivel de decibelios a un porcentaje
69
+ return normalized_db_level # Retorna el nivel de decibelios normalizado
70
 
71
  def mostrar_decibelios(audio_path_stream, visual_threshold):
72
+ db_level = decibelios(audio_path_stream)# Obtiene el nivel de decibelios del audio
73
+ if db_level > visual_threshold: # Si el nivel de decibelios supera el umbral visual
74
+ status = "Prediciendo..." # Cambia el estado a "Prediciendo"
75
  else:
76
+ status = "Esperando..." # Si no supera el umbral, indica que está "Esperando"
77
  return f"""
78
  <div style='text-align: center; font-size: 1.5em'>
79
  <span>{status}</span>
80
  <span style='display: inline-block; min-width: 120px;'>Decibelios: {db_level:.2f}</span>
81
  </div>
82
+ """ # Retorna una cadena HTML con el estado y el nivel de decibelios
83
 
84
  def predict_stream_decib(audio_path_stream, visual_threshold):
85
+ db_level = decibelios(audio_path_stream) # Calcula el nivel de decibelios
86
+ if db_level > visual_threshold: # Si supera el umbral, hace una predicción
87
+ prediction = display_prediction_stream(audio_path_stream) # Llama a la función de predicción
88
  else:
89
+ prediction = "" # Si no supera el umbral, no muestra predicción
90
  return f"""
91
  <div style='text-align: center; font-size: 1.5em; min-height: 2em;'>
92
  <span style='display: inline-block; min-width: 300px;'>{prediction}</span>
93
  </div>
94
+ """ # Retorna el resultado o nada si no supera el umbral
95
 
96
  def chatbot_config(message, history: list[tuple[str, str]]):
97
+ system_message = "You are a Chatbot specialized in baby health and care." # Mensaje inicial del chatbot
98
+ max_tokens = 512 # Máximo de tokens para la respuesta
99
+ temperature = 0.5 # Controla la aleatoriedad de las respuestas
100
+ top_p = 0.95 # Top-p sampling para filtrar palabras
101
+ messages = [{"role": "system", "content": system_message}] # Configura el mensaje del sistema para el chatbot
102
+ for val in history: # Añade el historial de la conversación al mensaje
103
  if val[0]:
104
+ messages.append({"role": "user", "content": val[0]}) # Añade los mensajes del usuario
105
  if val[1]:
106
+ messages.append({"role": "assistant", "content": val[1]}) # Añade las respuestas del asistente
107
+ messages.append({"role": "user", "content": message}) # Añade el mensaje actual del usuario
108
+ response = "" # Inicializa la variable de respuesta
109
  for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
110
+ token = message_response.choices[0].delta.content # Obtiene el contenido del mensaje generado por el modelo
111
+ response += token # Acumula el contenido generado en la respuesta final
112
+ return response # Retorna la respuesta generada por el modelo
113
 
114
  def cambiar_pestaña():
115
+ return gr.update(visible=False), gr.update(visible=True) # Esta función cambia la visibilidad de las pestañas en la interfaz
116
 
117
  def display_prediction(audio, prediction_func):
118
+ prediction = prediction_func(audio) # Llama a la función de predicción para obtener el resultado
119
+ return f"<h3 style='text-align: center; font-size: 1.5em;'>{prediction}</h3>" # Retorna el resultado formateado en HTML
120
 
121
  def display_prediction_wrapper(audio):
122
+ return display_prediction(audio, predict) # Envuelve la función de predicción "predict" en la función "display_prediction"
123
 
124
  def display_prediction_stream(audio):
125
+ return display_prediction(audio, predict_stream) # Envuelve la función de predicción "predict_stream" en la función "display_prediction"
126
 
127
  my_theme = gr.themes.Soft(
128
  primary_hue="emerald",
 
200
  "<p style='text-align: left'>Monitor: Nuestro monitor no es como otros que hay en el mercado, ya que es capaz de reconocer si un sonido es un llanto del bebé o no; y si está llorando, predice automáticamente la causa. Dándote la tranquilidad de saber siempre qué pasa con tu pequeño, ahorrándote tiempo y horas de sueño.</p>"
201
  )
202
  boton_inicial = gr.Button("¡Prueba nuestros modelos!")
203
+ with gr.Column(visible=False) as chatbot: # Columna para la pestaña del chatbot
204
+ gr.Markdown("<h2>Asistente</h2>") # Título de la pestaña del chatbot
205
+ gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Pregunta a nuestro asistente cualquier duda que tengas sobre el cuidado de tu bebé</h4>") # Descripción de la pestaña del chatbot
206
+ gr.ChatInterface(
207
+ chatbot_config, # Función de configuración del chatbot
208
+ theme=my_theme, # Tema personalizado para la interfaz
209
+ retry_btn=None, # Botón de reintentar desactivado
210
+ undo_btn=None, # Botón de deshacer desactivado
211
+ clear_btn="Limpiar 🗑️", # Botón de limpiar mensajes
212
+ autofocus=True, # Enfocar automáticamente el campo de entrada de texto
213
+ fill_height=True, # Rellenar el espacio verticalmente
214
+ )
215
+ with gr.Row(): # Fila para los botones de cambio de pestaña
216
+ with gr.Column(): # Columna para el botón del predictor
217
+ boton_predictor = gr.Button("Predictor") # Botón para cambiar a la pestaña del predictor
218
+ with gr.Column(): # Columna para el botón del monitor
219
+ boton_monitor = gr.Button("Monitor") # Botón para cambiar a la pestaña del monitor
220
+ with gr.Column(visible=False) as pag_predictor: # Columna para la pestaña del predictor
221
+ gr.Markdown("<h2>Predictor</h2>") # Título de la pestaña del predictor
222
+ gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Descubre por qué tu bebé está llorando</h4>") # Descripción de la pestaña del predictor
223
  audio_input = gr.Audio(
224
+ min_length=1.0, # Duración mínima del audio requerida
225
+ format="wav", # Formato de audio admitido
226
+ label="Baby recorder", # Etiqueta del campo de entrada de audio
227
+ type="filepath", # Tipo de entrada de audio (archivo)
228
+ )
229
+ prediction_output = gr.Markdown() # Salida para mostrar la predicción
230
  gr.Button("¿Por qué llora?").click(
231
+ display_prediction_wrapper, # Función de predicción para el botón
232
+ inputs=audio_input, # Entrada de audio para la función de predicción
233
+ outputs=gr.Markdown() # Salida para mostrar la predicción
234
+ )
235
+ gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_predictor, chatbot]) # Botón para volver a la pestaña del chatbot
236
+ with gr.Column(visible=False) as pag_monitor: # Columna para la pestaña del monitor
237
+ gr.Markdown("<h2>Monitor</h2>") # Título de la pestaña del monitor
238
+ gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Detecta en tiempo real si tu bebé está llorando y por qué</h4>") # Descripción de la pestaña del monitor
239
  audio_stream = gr.Audio(
240
+ format="wav", # Formato de audio admitido
241
+ label="Baby recorder", # Etiqueta del campo de entrada de audio
242
+ type="filepath", # Tipo de entrada de audio (archivo)
243
+ streaming=True # Habilitar la transmisión de audio en tiempo real
244
  )
245
  threshold_db = gr.Slider(
246
+ minimum=0, # Valor mínimo del umbral de ruido
247
+ maximum=120, # Valor máximo del umbral de ruido
248
+ step=1, # Paso del umbral de ruido
249
+ value=20, # Valor inicial del umbral de ruido
250
+ label="Umbral de ruido para activar la predicción:" # Etiqueta del control deslizante del umbral de ruido
251
  )
252
+ volver = gr.Button("Volver") # Botón para volver a la pestaña del chatbot
253
  audio_stream.stream(
254
+ mostrar_decibelios, # Función para mostrar el nivel de decibelios
255
+ inputs=[audio_stream, threshold_db], # Entradas para la función de mostrar decibelios
256
+ outputs=gr.HTML() # Salida para mostrar el nivel de decibelios
257
  )
258
  audio_stream.stream(
259
+ predict_stream_decib, # Función para realizar la predicción en tiempo real
260
+ inputs=[audio_stream, threshold_db], # Entradas para la función de predicción en tiempo real
261
+ outputs=gr.HTML() # Salida para mostrar la predicción en tiempo real
262
  )
263
+ volver.click(cambiar_pestaña, outputs=[pag_monitor, chatbot]) # Botón para volver a la pestaña del chatbot
264
+ boton_inicial.click(cambiar_pestaña, outputs=[inicial, chatbot]) # Botón para cambiar a la pestaña inicial
265
+ boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor]) # Botón para cambiar a la pestaña del predictor
266
+ boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor]) # Botón para cambiar a la pestaña del monitor
267
+ demo.launch(share=True) # Lanzar la interfaz gráfica