Marcos12886 committed on
Commit
017e65e
·
verified ·
1 Parent(s): 0f6f358

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -24
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import os
2
  import torch
 
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient
5
  from model import predict_params, AudioDataset
6
  import torchaudio
7
- # TODO: Que no diga lo de que no hay 1s_normal al predecir
8
  token = os.getenv("HF_TOKEN")
9
  client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -44,7 +45,7 @@ def predict(audio_path_pred):
44
  def predict_stream(audio_path_stream):
45
  with torch.no_grad():
46
  logits = call(audio_path_stream, model=model_mon, dataset_path="data/baby_cry_detection", filter_white_noise=False, undersample_normal=False)
47
- probabilities = torch.nn.functional.softmax(logits, dim=-1)
48
  crying_probabilities = probabilities[:, 1]
49
  avg_crying_probability = crying_probabilities.mean()*100
50
  if avg_crying_probability < 15:
@@ -54,11 +55,11 @@ def predict_stream(audio_path_stream):
54
  return "No está llorando"
55
 
56
  def decibelios(audio_path_stream):
57
- waveform, sample_rate = torchaudio.load(audio_path_stream)
58
  rms = torch.sqrt(torch.mean(torch.square(waveform)))
59
  db_level = 20 * torch.log10(rms + 1e-6).item()
60
- min_db = -80
61
- max_db = 0
62
  scaled_db_level = (db_level - min_db) / (max_db - min_db)
63
  normalized_db_level = scaled_db_level * 100
64
  return normalized_db_level
@@ -66,15 +67,15 @@ def decibelios(audio_path_stream):
66
  def mostrar_decibelios(audio_path_stream, visual_threshold):
67
  db_level = decibelios(audio_path_stream)
68
  if db_level > visual_threshold:
69
- return f"Prediciendo... Decibelios: {db_level:.2f}"
70
- elif db_level < visual_threshold:
71
- return f"Esperando... Decibelios: {db_level:.2f}"
 
72
 
73
  def predict_stream_decib(audio_path_stream, visual_threshold):
74
  db_level = decibelios(audio_path_stream)
75
  if db_level > visual_threshold:
76
- llorando = predict_stream(audio_path_stream)
77
- return f"{llorando}"
78
  else:
79
  return ""
80
 
@@ -99,6 +100,16 @@ def chatbot_config(message, history: list[tuple[str, str]]):
99
  def cambiar_pestaña():
100
  return gr.update(visible=False), gr.update(visible=True)
101
 
 
 
 
 
 
 
 
 
 
 
102
  my_theme = gr.themes.Soft(
103
  primary_hue="emerald",
104
  secondary_hue="green",
@@ -201,39 +212,41 @@ with gr.Blocks(theme=my_theme) as demo:
201
  label="Baby recorder",
202
  type="filepath",
203
  )
 
204
  gr.Button("¿Por qué llora?").click(
205
- predict,
206
  inputs=audio_input,
207
- outputs=gr.Textbox(label="Tu bebé llora por:")
208
  )
209
  gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
210
  with gr.Column(visible=False) as pag_monitor:
211
  gr.Markdown("<h2>Monitor</h2>")
212
  gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Detecta en tiempo real si tu bebé está llorando y por qué</h4>")
213
  audio_stream = gr.Audio(
214
- format="wav",
215
- label="Baby recorder",
216
- type="filepath",
217
- streaming=True
218
- )
219
  threshold_db = gr.Slider(
220
  minimum=0,
221
- maximum=100,
222
  step=1,
223
- value=30,
224
  label="Umbral de ruido para activar la predicción:"
225
- )
 
226
  audio_stream.stream(
227
  mostrar_decibelios,
228
  inputs=[audio_stream, threshold_db],
229
- outputs=gr.Textbox(value="Esperando...", label="Estado")
230
- )
231
  audio_stream.stream(
232
  predict_stream_decib,
233
  inputs=[audio_stream, threshold_db],
234
- outputs=gr.Textbox(value="", label="Tu bebé:")
235
  )
236
- gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
237
  boton_inicial.click(cambiar_pestaña, outputs=[inicial, chatbot])
238
  boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
239
  boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
 
1
  import os
2
  import torch
3
+ import torch.nn.functional as F
4
  import gradio as gr
5
  from huggingface_hub import InferenceClient
6
  from model import predict_params, AudioDataset
7
  import torchaudio
8
+
9
  token = os.getenv("HF_TOKEN")
10
  client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
45
  def predict_stream(audio_path_stream):
46
  with torch.no_grad():
47
  logits = call(audio_path_stream, model=model_mon, dataset_path="data/baby_cry_detection", filter_white_noise=False, undersample_normal=False)
48
+ probabilities = F.softmax(logits, dim=-1)
49
  crying_probabilities = probabilities[:, 1]
50
  avg_crying_probability = crying_probabilities.mean()*100
51
  if avg_crying_probability < 15:
 
55
  return "No está llorando"
56
 
57
  def decibelios(audio_path_stream):
58
+ waveform, _ = torchaudio.load(audio_path_stream)
59
  rms = torch.sqrt(torch.mean(torch.square(waveform)))
60
  db_level = 20 * torch.log10(rms + 1e-6).item()
61
+ min_db = -80
62
+ max_db = 0
63
  scaled_db_level = (db_level - min_db) / (max_db - min_db)
64
  normalized_db_level = scaled_db_level * 100
65
  return normalized_db_level
 
67
  def mostrar_decibelios(audio_path_stream, visual_threshold):
68
  db_level = decibelios(audio_path_stream)
69
  if db_level > visual_threshold:
70
+ status = "Prediciendo..."
71
+ else:
72
+ status = "Esperando..."
73
+ return f"<h3 style='text-align: center; font-size: 1.5em;'>{status} Decibelios: {db_level:.2f}</h3>"
74
 
75
  def predict_stream_decib(audio_path_stream, visual_threshold):
76
  db_level = decibelios(audio_path_stream)
77
  if db_level > visual_threshold:
78
+ return display_prediction_stream(audio_path_stream)
 
79
  else:
80
  return ""
81
 
 
100
  def cambiar_pestaña():
101
  return gr.update(visible=False), gr.update(visible=True)
102
 
103
+ def display_prediction(audio, prediction_func):
104
+ prediction = prediction_func(audio)
105
+ return f"<h3 style='text-align: center; font-size: 1.5em;'>{prediction}</h3>"
106
+
107
+ def display_prediction_wrapper(audio):
108
+ return display_prediction(audio, predict)
109
+
110
+ def display_prediction_stream(audio):
111
+ return display_prediction(audio, predict_stream)
112
+
113
  my_theme = gr.themes.Soft(
114
  primary_hue="emerald",
115
  secondary_hue="green",
 
212
  label="Baby recorder",
213
  type="filepath",
214
  )
215
+ prediction_output = gr.Markdown()
216
  gr.Button("¿Por qué llora?").click(
217
+ display_prediction_wrapper,
218
  inputs=audio_input,
219
+ outputs=gr.Markdown()
220
  )
221
  gr.Button("Volver").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
222
  with gr.Column(visible=False) as pag_monitor:
223
  gr.Markdown("<h2>Monitor</h2>")
224
  gr.Markdown("<h4 style='text-align: center; font-size: 1.5em'>Detecta en tiempo real si tu bebé está llorando y por qué</h4>")
225
  audio_stream = gr.Audio(
226
+ format="wav",
227
+ label="Baby recorder",
228
+ type="filepath",
229
+ streaming=True
230
+ )
231
  threshold_db = gr.Slider(
232
  minimum=0,
233
+ maximum=120,
234
  step=1,
235
+ value=20,
236
  label="Umbral de ruido para activar la predicción:"
237
+ )
238
+ volver = gr.Button("Volver")
239
  audio_stream.stream(
240
  mostrar_decibelios,
241
  inputs=[audio_stream, threshold_db],
242
+ outputs=gr.HTML(value="<h3 style='text-align: center; font-size: 1.5em;'>Esperando...</h3>")
243
+ )
244
  audio_stream.stream(
245
  predict_stream_decib,
246
  inputs=[audio_stream, threshold_db],
247
+ outputs=gr.HTML()
248
  )
249
+ volver.click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
250
  boton_inicial.click(cambiar_pestaña, outputs=[inicial, chatbot])
251
  boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
252
  boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])