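"""Gradio app for the Iremia assistant: a baby-care chatbot, a cry-cause
predictor, and a streaming cry monitor built on fine-tuned DistilHuBERT models."""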
import os
import torch
import gradio as gr
from huggingface_hub import InferenceClient
from model import predict_params, AudioDataset
from interfaz import estilo, my_theme

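# The chatbot is backed by Meta-Llama-3-8B-Instruct through the Hugging Face
# Inference API; HF_TOKEN must be set in the environment.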
token = os.getenv("HF_TOKEN")
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
model_cache = {}

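# Cache models by (model_path, dataset_path, filter_white_noise) so each
# configuration is only loaded from disk once.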
def load_model_and_dataset(model_path, dataset_path, filter_white_noise):
    if (model_path, dataset_path, filter_white_noise) not in model_cache:
        model, _, _, id2label = predict_params(dataset_path, model_path, filter_white_noise)
        model_cache[(model_path, dataset_path, filter_white_noise)] = (model, id2label)
    return model_cache[(model_path, dataset_path, filter_white_noise)]

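# Single-shot classification: preprocess one recording and return the predicted label.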
def predict(audio_path, model_path, dataset_path, filter_white_noise):
    model, id2label = load_model_and_dataset(model_path, dataset_path, filter_white_noise)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    audios = AudioDataset(dataset_path, {}, filter_white_noise).preprocess_audio(audio_path)
    inputs = {"input_values": audios.to(device).unsqueeze(0)}
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class_ids = torch.argmax(logits, dim=-1).item()
        label = id2label[predicted_class_ids]
        if dataset_path == "data/mixed_data":
            label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
            label = label_mapping.get(predicted_class_ids, label)
    return label

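# Streaming pipeline: first run the cry detector on the incoming audio, then,
# if crying is detected, classify the likely cause with the mixed-data model.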
def predict_stream(audio_path):
    model_mon, _ = load_model_and_dataset(
        model_path="distilhubert-finetuned-cry-detector",
        dataset_path="data/baby_cry_detection",
        filter_white_noise=False
        )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_mon.to(device)
    model_mon.eval()
    audio_dataset = AudioDataset(dataset_path="data/baby_cry_detection", label2id={}, filter_white_noise=False)
    processed_audio = audio_dataset.preprocess_audio(audio_path)
    inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
    with torch.no_grad():
        outputs = model_mon(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        crying_probabilities = probabilities[:, 1]
        avg_crying_probability = crying_probabilities.mean().item()*100
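    # NOTE: a low class-1 probability (< 25%) is treated as crying below, which
    # assumes index 1 of the detector's output corresponds to the "no crying" class.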
    if avg_crying_probability < 25:
        model_class, id2label = load_model_and_dataset(
            model_path="distilhubert-finetuned-mixed-data",
            dataset_path="data/mixed_data",
            filter_white_noise=True
            )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_class.to(device)
        model_class.eval()
        audio_dataset_class = AudioDataset(dataset_path="data/mixed_data", label2id={}, filter_white_noise=True)
        processed_audio_class = audio_dataset_class.preprocess_audio(audio_path)
        inputs_class = {"input_values": processed_audio_class.to(device).unsqueeze(0)}
        with torch.no_grad():
            outputs_class = model_class(**inputs_class)
            logits_class = outputs_class.logits
            predicted_class_ids_class = torch.argmax(logits_class, dim=-1).item()
            label_class = id2label[predicted_class_ids_class]
            label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
            label_class = label_mapping.get(predicted_class_ids_class, label_class)
        return f"Bebé llorando por {label_class}. Probabilidad: {avg_crying_probability:.1f})"
    else:
        return f"No está llorando. Proabilidad: {avg_crying_probability:.1f})"

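# Stream the chatbot reply chunk by chunk, replaying the chat history on each turn.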
def chatbot_config(message, history: list[tuple[str, str]]):
    system_message = "You are a Chatbot specialized in baby health and care."
    max_tokens = 512
    temperature = 0.7
    top_p = 0.95
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})
    response = ""
    for message_response in client.chat_completion(messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p):
        # delta.content can be None on some stream chunks; also avoid shadowing the module-level `token`
        delta = message_response.choices[0].delta.content or ""
        response += delta
        yield response

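# Hide the current column and show the target one (tab-style navigation).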
def cambiar_pestaña():
    return gr.update(visible=False), gr.update(visible=True)

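# UI: a landing column with the chatbot, plus hidden columns for the predictor and the monitor.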
with gr.Blocks(theme=my_theme) as demo:
    estilo()
    with gr.Column(visible=True) as chatbot:
        gr.Markdown("<h2>Asistente</h2>")
        gr.ChatInterface(
            chatbot_config  # TODO: review the ChatInterface arguments
            )
        gr.Markdown("Este chatbot no sustituye a un profesional de la salud. Ante cualquier preocupación o duda, consulta con tu pediatra.")
        with gr.Row():
            with gr.Column():
                gr.Markdown("<h2>Predictor</h2>")
                boton_pagina_1 = gr.Button("Prueba el predictor")
                gr.Markdown("<p>Descubre por qué llora tu bebé y resuelve dudas sobre su cuidado con nuestro Iremia assistant</p>")
            with gr.Column():
                gr.Markdown("<h2>Monitor</h2>")
                boton_pagina_2 = gr.Button("Prueba el monitor")
                gr.Markdown("<p>Un monitor inteligente que detecta si tu hijo está llorando y te indica el motivo antes de que puedas levantarte del sofá</p>")
    with gr.Column(visible=False) as pag_predictor:
        gr.Markdown("<h2>Predictor</h2>")
        audio_input = gr.Audio(
            min_length=1.0,
            format="wav",
            label="Baby recorder",
            type="filepath",
            )
        classify_btn = gr.Button("¿Por qué llora?")
        classify_btn.click(
            lambda audio: predict(  # TODO: check why a lambda is needed here
                audio,
                model_path="distilhubert-finetuned-mixed-data",
                dataset_path="data/mixed_data",
                filter_white_noise=True
                ),
            inputs=audio_input,
            outputs=gr.Textbox(label="Tu bebé llora por:")
            )
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_predictor, chatbot])
    with gr.Column(visible=False) as pag_monitor:
        gr.Markdown("<h2>Monitor</h2>")
        audio_stream = gr.Audio(
            # min_length=1.0,  # TODO: check why this does not work with streaming
            format="wav",
            label="Baby recorder",
            type="filepath",
            streaming=True
            )
        audio_stream.stream(
            predict_stream,
            inputs=audio_stream,
            outputs=gr.Textbox(label="Tu bebé está:"),
        )
        gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
    boton_pagina_1.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
    boton_pagina_2.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
demo.launch(share=True)