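"""Gradio app for the Iremia baby-monitor Space.

Combines three tools: a baby-care chatbot backed by Meta-Llama-3-8B-Instruct
via the Hugging Face Inference API, a cry-cause predictor, and a streaming
cry monitor built on two fine-tuned DistilHuBERT audio classifiers.
"""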
import os
import torch
import gradio as gr
from huggingface_hub import InferenceClient
from model import predict_params, AudioDataset
from interfaz import estilo, my_theme
token = os.getenv("HF_TOKEN")
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
model_cache = {}
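
# Cache (model, id2label) pairs keyed by (model_path, dataset_path, filter_white_noise)
# so each model is only loaded once per configuration.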
def load_model_and_dataset(model_path, dataset_path, filter_white_noise):
    if (model_path, dataset_path, filter_white_noise) not in model_cache:
        model, _, _, id2label = predict_params(dataset_path, model_path, filter_white_noise)
        model_cache[(model_path, dataset_path, filter_white_noise)] = (model, id2label)
    return model_cache[(model_path, dataset_path, filter_white_noise)]
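
# Classify a single recording: load the requested model, preprocess the audio,
# and return the predicted label.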
def predict(audio_path, model_path, dataset_path, filter_white_noise):
    model, id2label = load_model_and_dataset(model_path, dataset_path, filter_white_noise)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    audios = AudioDataset(dataset_path, {}, filter_white_noise).preprocess_audio(audio_path)
    inputs = {"input_values": audios.to(device).unsqueeze(0)}
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    predicted_class_ids = torch.argmax(logits, dim=-1).item()
    label = id2label[predicted_class_ids]
    if dataset_path == "data/mixed_data":
        label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
        label = label_mapping.get(predicted_class_ids, label)
    return label
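
# Streaming monitor: first run the cry detector on the incoming chunk; if it is
# judged to be crying, run the mixed-data classifier to label the likely cause.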
def predict_stream(audio_path):
    model_mon, _ = load_model_and_dataset(
        model_path="distilhubert-finetuned-cry-detector",
        dataset_path="data/baby_cry_detection",
        filter_white_noise=False,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_mon.to(device)
    model_mon.eval()
    audio_dataset = AudioDataset(dataset_path="data/baby_cry_detection", label2id={}, filter_white_noise=False)
    processed_audio = audio_dataset.preprocess_audio(audio_path)
    inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
    with torch.no_grad():
        outputs = model_mon(**inputs)
        logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # NOTE: a low value of this score triggers the crying branch below, so
    # class index 1 appears to correspond to the "no cry" label.
    crying_probabilities = probabilities[:, 1]
    avg_crying_probability = crying_probabilities.mean().item() * 100
    if avg_crying_probability < 25:
        model_class, id2label = load_model_and_dataset(
            model_path="distilhubert-finetuned-mixed-data",
            dataset_path="data/mixed_data",
            filter_white_noise=True,
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_class.to(device)
        model_class.eval()
        audio_dataset_class = AudioDataset(dataset_path="data/mixed_data", label2id={}, filter_white_noise=True)
        processed_audio_class = audio_dataset_class.preprocess_audio(audio_path)
        inputs_class = {"input_values": processed_audio_class.to(device).unsqueeze(0)}
        with torch.no_grad():
            outputs_class = model_class(**inputs_class)
            logits_class = outputs_class.logits
        predicted_class_ids_class = torch.argmax(logits_class, dim=-1).item()
        label_class = id2label[predicted_class_ids_class]
        label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
        label_class = label_mapping.get(predicted_class_ids_class, label_class)
        return f"Bebé llorando por {label_class}. Probabilidad: {avg_crying_probability:.1f}%"
    else:
        return f"No está llorando. Probabilidad: {avg_crying_probability:.1f}%"
def chatbot_config(message, history: list[tuple[str, str]]):
    system_message = "You are a Chatbot specialized in baby health and care."
    max_tokens = 512
    temperature = 0.7
    top_p = 0.95
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})
    response = ""
    for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
        delta = message_response.choices[0].delta.content
        response += delta or ""  # delta can be None on some stream chunks
        yield response
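
# Toggle between two panels: hide the first output component, show the second.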
def cambiar_pestaña():
    return gr.update(visible=False), gr.update(visible=True)
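
# UI: a landing page with the chatbot plus two hidden panels (predictor and
# monitor) that the navigation buttons swap in and out.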
with gr.Blocks(theme=my_theme) as demo:
    estilo()
    with gr.Column(visible=True) as chatbot:
        gr.Markdown("<h2>Asistente</h2>")
        gr.ChatInterface(
            chatbot_config  # TODO: review the arguments
        )
        gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
        with gr.Row():
            with gr.Column():
                gr.Markdown("<h2>Predictor</h2>")
                boton_pagina_1 = gr.Button("Prueba el predictor")
                gr.Markdown("<p>Descubre por qué llora tu bebé y resuelve dudas sobre su cuidado con nuestro Iremia assistant</p>")
            with gr.Column():
                gr.Markdown("<h2>Monitor</h2>")
                boton_pagina_2 = gr.Button("Prueba el monitor")
                gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
    with gr.Column(visible=False) as pag_predictor:
        gr.Markdown("<h2>Predictor</h2>")
        audio_input = gr.Audio(
            min_length=1.0,
            format="wav",
            label="Baby recorder",
            type="filepath",
        )
        classify_btn = gr.Button("¿Por qué llora?")
        classify_btn.click(
            lambda audio: predict(  # TODO: check why a lambda is used here
                audio,
                model_path="distilhubert-finetuned-mixed-data",
                dataset_path="data/mixed_data",
                filter_white_noise=True,
            ),
            inputs=audio_input,
            outputs=gr.Textbox(label="Tu bebé llora por:"),
        )
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
    with gr.Column(visible=False) as pag_monitor:
        gr.Markdown("<h2>Monitor</h2>")
        audio_stream = gr.Audio(
            # min_length=1.0,  # TODO: check why this doesn't work
            format="wav",
            label="Baby recorder",
            type="filepath",
            streaming=True,
        )
        audio_stream.stream(
            predict_stream,
            inputs=audio_stream,
            outputs=gr.Textbox(label="Tu bebé está:"),
        )
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
    boton_pagina_1.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
    boton_pagina_2.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])

demo.launch(share=True)