File size: 2,943 Bytes
6295354
 
 
6d82372
6295354
 
6d82372
 
 
 
296923b
6d82372
6295354
6d82372
296923b
 
6295354
402fd8a
6295354
402fd8a
6295354
 
296923b
 
 
 
 
6295354
 
296923b
 
6295354
296923b
6295354
402fd8a
 
 
6d82372
402fd8a
6295354
296923b
 
 
 
 
 
 
 
 
6295354
 
296923b
 
 
6295354
296923b
 
6d82372
296923b
 
6d82372
296923b
6295354
 
 
 
296923b
 
6295354
6d82372
 
 
 
296923b
 
 
6d82372
 
296923b
 
 
 
 
 
6d82372
 
 
 
 
 
 
 
 
296923b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()

import json
import os
from functools import partial
from pathlib import Path
from pprint import pprint

import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.prompts import (HumanMessagePromptTemplate,
                               PromptTemplate, SystemMessagePromptTemplate)

# import whisper

# model = whisper.load_model("base", device="cuda")


# System prompt: the patient-persona template, rendered with a single
# ``patient`` variable (one record loaded from data/patients.json below).
system_message_prompt = SystemMessagePromptTemplate(
    prompt=PromptTemplate(
        input_variables=["patient"],
        # Read explicitly as UTF-8 so loading the prompt does not depend on
        # the platform's locale encoding.
        template=Path("prompts/patient.prompt").read_text(encoding="utf-8"),
    )
)

# Every user turn is presented to the model as "Doctor: <text>".
human_template = "Doctor: {text}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)

chat = ChatOpenAI(temperature=0.7)

# Patient records; each is expected to carry at least a "name" key (used for
# the dropdown below). NOTE: the misspelled name ``patiens`` is referenced
# throughout this module, so it is kept as-is.
with open("data/patients.json", encoding="utf-8") as f:
    patiens = json.load(f)

# Display names, in the same order as ``patiens`` so a name's index is also
# the index of its record.
patients_names = [el["name"] for el in patiens]


def run_text_prompt(message, chat_history, messages, patient=None):
    """Send one doctor message to the chat model and record the reply.

    Args:
        message: Text typed by the user (the doctor).
        chat_history: List of ``(user_text, reply_text)`` pairs shown in the
            chatbot; appended to in place. ``None`` is treated as empty.
        messages: LangChain message list for the conversation so far; the
            new human message and the model's reply are appended to it. When
            empty, a new list is started with the patient system prompt.
        patient: Patient record used to render the system prompt when
            ``messages`` is empty. Defaults to ``patiens[0]``, the patient
            the UI starts with.

    Returns:
        ``("", chat_history, messages)`` -- the empty string clears the
        input textbox.
    """
    if patient is None:
        # Previously this fell through to the module-level ``patient``, which
        # is a ``gr.State`` component rather than a patient record, so the
        # system prompt was rendered from the component's repr.
        patient = patiens[0]
    if chat_history is None:
        chat_history = []
    if not messages:
        messages = [system_message_prompt.format(patient=patient)]
    messages.append(human_message_prompt.format(text=message))
    messages.append(chat(messages))
    pprint(messages)
    chat_history.append((message, messages[-1].content))
    return "", chat_history, messages


def on_clear_button_click(patient, messages):
    """Start the conversation over for ``patient``.

    ``messages`` (the previous conversation) is accepted to match the
    event's inputs but is ignored: it is replaced by a list holding only
    the rendered system prompt.

    Returns:
        A tuple of an empty chat history and the fresh message list.
    """
    fresh_conversation = [system_message_prompt.format(patient=patient)]
    empty_history = []
    return empty_history, fresh_conversation


def on_drop_down_change(selected_item, messages):
    """Switch to the patient named ``selected_item`` and reset the chat.

    ``messages`` is ignored; a new conversation containing only the system
    prompt for the chosen patient is returned in its place.

    Returns:
        A 4-tuple: the patient record as a fenced JSON Markdown snippet, the
        patient record itself, an empty chat history, and the new message
        list.

    Raises:
        ValueError: if ``selected_item`` is not one of ``patients_names``.
    """
    # Names and records share ordering, so a name's position indexes its record.
    position = patients_names.index(selected_item)
    chosen_patient = patiens[position]
    fresh_conversation = [system_message_prompt.format(patient=chosen_patient)]
    print(f"You selected: {selected_item}", position)
    pretty_json = json.dumps(chosen_patient, indent=2)
    return (
        "```json\n" + pretty_json + "\n```",
        chosen_patient,
        [],
        fresh_conversation,
    )


# UI layout: a chat column (textbox + clear button) next to a patient column
# (dropdown selector + JSON view of the selected patient record).
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    # Seed the conversation with the system prompt for the initially selected
    # patient (the same record as the ``patient`` state and dropdown default)
    # so the very first message is answered in that patient's persona.
    messages = gr.State([system_message_prompt.format(patient=patiens[0])])
    patient = gr.State(patiens[0])

    with gr.Row():
        with gr.Column():
            msg = gr.Textbox()
            msg.submit(
                run_text_prompt,
                [msg, chatbot, messages],
                [msg, chatbot, messages],
            )
            clear = gr.Button("Clear")
            clear.click(
                on_clear_button_click,
                [patient, messages],
                [chatbot, messages],
                queue=False,
            )

        with gr.Column():
            # ``patients_names`` is the module-level list built from ``patiens``.
            dropdown = gr.Dropdown(
                choices=patients_names,
                value=patients_names[0],
                interactive=True,
                label="Patient",
            )
            markdown = gr.Markdown(
                f"```json\n{json.dumps(patient.value, indent=2)}\n```"
            )
            # Selecting a patient refreshes the JSON view, the patient state,
            # and resets the chat and message history.
            dropdown.change(
                fn=on_drop_down_change,
                inputs=[dropdown, messages],
                outputs=[markdown, patient, chatbot, messages],
            )
# demo.launch(debug=True)