import logging
from pathlib import Path
from typing import List, Optional, Tuple, Dict
import json
from dotenv import load_dotenv
load_dotenv()
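# load_dotenv() is called this early so that values from a local .env file,
# e.g. GRADIO_PASSWORD used below (and presumably OPENAI_API_KEY for
# ChatOpenAI), are in os.environ before anything reads them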
from queue import Empty, Queue
from threading import Thread
import os
import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import AIMessage, BaseMessage, HumanMessage
from js import get_window_url_params
from callback import QueueCallback
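# QueueCallback is defined in the local callback.py (not shown here); it is
# assumed to be a LangChain callback handler that pushes every streamed token
# into the Queue it was constructed with, roughly:
#
#   class QueueCallback(BaseCallbackHandler):
#       def __init__(self, queue):
#           self.queue = queue
#
#       def on_llm_new_token(self, token: str, **kwargs) -> None:
#           self.queue.put(token)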
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from db import (
    User,
    Chat,
    create_user,
    get_client,
    get_user_by_username,
    add_chat_by_uid,
)
MODELS_NAMES = ["gpt-3.5-turbo", "gpt-4"]
DEFAULT_TEMPERATURE = 0.7
ChatHistory = List[Tuple[str, str]]  # (user message, bot reply) pairs, as shown by gr.Chatbot
logging.basicConfig(
    format="[%(asctime)s %(levelname)s]: %(message)s", level=logging.INFO
)
# load redis client
client = get_client()
# load up our system prompt
system_message_prompt = SystemMessagePromptTemplate.from_template(
    Path("prompts/system.prompt").read_text()
)
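# prompts/system.prompt is expected to contain a {patient} placeholder: every
# call site below formats the template with format(patient=...)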
# for the human, we will just inject the text
human_message_prompt_template = HumanMessagePromptTemplate.from_template("{text}")
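# each entry of data/patients.json is assumed to be a patient card, i.e. a dict
# with at least "name", "difficulty" and "communicative" keys (the values shown
# in the settings UI below); the whole card gets injected into the system prompt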
with open("data/patients.json") as f:
    patients = json.load(f)
patients_names = [el["name"] for el in patients]

def message_handler(
    chat: Optional[ChatOpenAI],
    message: str,
    chatbot_messages: ChatHistory,
    messages: List[BaseMessage],
) -> Tuple[ChatOpenAI, str, ChatHistory, List[BaseMessage]]:
    if chat is None:
        # in the queue we will store our streamed tokens
        queue = Queue()
        print("Creating new chat")
        # let's create our default chat
        chat = ChatOpenAI(
            model_name=MODELS_NAMES[0],
            temperature=DEFAULT_TEMPERATURE,
            streaming=True,
            callbacks=[QueueCallback(queue)],
        )
    else:
        # hacky way to get the queue back
        queue = chat.callbacks[0].queue
    # sentinel object that marks the end of the generation
    job_done = object()
    logging.info("asking question to GPT")
    # let's add the messages to our state
    messages.append(HumanMessage(content=f"Doctor:{message}"))
    chatbot_messages.append((message, ""))

    # little wrapper so we can put job_done on the queue once generation is over
    def task():
        chat(messages)
        queue.put(job_done)

    # run the generation in a background thread so we can stream from the queue
    t = Thread(target=task)
    t.start()
    # this will hold the content as we generate
    content = ""
    # read the next token from the queue and append it to the last chatbot message
    while True:
        try:
            next_token = queue.get(block=True, timeout=1)
            if next_token is job_done:
                break
            content += next_token
            chatbot_messages[-1] = (message, content)
            yield chat, "", chatbot_messages, messages
        except Empty:
            continue
    # finally we can add our reply to messages
    messages.append(AIMessage(content=content))
    logging.debug(f"reply = {content}")
    logging.info("Done!")
    # this function is a generator, so the final state must be yielded, not returned
    yield chat, "", chatbot_messages, messages

def on_clear_click(patient: dict) -> Tuple[str, List, List]:
    # reset the textbox and the chatbot, keeping only the system prompt
    # for the current patient in the LangChain message list
    messages = [system_message_prompt.format(patient=patient)]
    return "", [], messages

def on_done_click(
    chatbot_messages: ChatHistory, patient: dict, user: User
) -> Tuple[str, List, List]:
    logging.info(f"Saving chat for user={user}")
    add_chat_by_uid(
        client, Chat(patient=patient, messages=chatbot_messages), user["uid"]
    )
    return on_clear_click(patient)

def on_apply_settings_click(
    model_name: str,
    temperature: float,
    patient: dict,
    difficulty: str,
    communicative: str,
):
    logging.info(
        f"Applying settings: model_name={model_name}, temperature={temperature}"
    )
    # recreate the chat with a brand-new (and therefore empty) queue;
    # note that Queue.empty() only reports emptiness, it does not clear anything
    chat = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        streaming=True,
        callbacks=[QueueCallback(Queue())],
    )
    patient["difficulty"] = difficulty
    patient["communicative"] = communicative
    print("patient", patient)
    message, chatbot, messages = on_clear_click(patient)
    return chat, message, chatbot, messages, patient

def on_drop_down_change(selected_item, messages):
    index = patients_names.index(selected_item)
    patient = patients[index]
    messages = [system_message_prompt.format(patient=patient)]
    print(f"You selected: {selected_item}", index)
    print(f"on_drop_down_change {patient}")
    # the patient is returned twice: once for the patient card JSON component
    # and once for the patient state
    return (
        patient,
        patient,
        [],
        messages,
        patient["difficulty"],
        patient["communicative"],
    )

def on_demo_load(url_params, request: gr.Request):
    username = request.username or url_params.get("username", "test")
    logging.info(f"Getting user for username={username}")
    create_user(client, User(username=username, uid=None))
    user = get_user_by_username(client, username)
    logging.info(f"User {user}")
    print(f"got url_params: {url_params}")
    return user, f"Nice to see you {user['username']} 👋"

url_params = gr.JSON({}, visible=False, label="URL Params")
# some css why not, "borrowed" from https://huggingface.co/spaces/ysharma/Gradio-demo-streaming/blob/main/app.py
with gr.Blocks(
    css="""#col_container {width: 700px; margin-left: auto; margin-right: auto;}
    #chatbot {height: 400px; overflow: auto;}"""
) as demo:
    # here we keep our state so multiple users can use the app at the same time!
    messages = gr.State([system_message_prompt.format(patient=patients[0])])
    # same thing for the chat: one chat per user, so every session has its own callbacks
    chat = gr.State(None)
    user = gr.State(None)
    patient = gr.State(patients[0])
    # see here https://github.com/gradio-app/gradio/discussions/2949#discussioncomment-5278991
    url_params.render()
    with gr.Column(elem_id="col_container"):
        gr.Markdown("# Welcome to OscePal! 👨‍⚕️🧑‍⚕️")
        welcome_markdown = gr.Markdown("")
        demo.load(
            fn=on_demo_load,
            inputs=[url_params],
            outputs=[user, welcome_markdown],
            _js=get_window_url_params,
        )
        chatbot = gr.Chatbot()
        with gr.Column():
            message = gr.Textbox(label="chat input")
            message.submit(
                message_handler,
                [chat, message, chatbot, messages],
                [chat, message, chatbot, messages],
                queue=True,
            )
            submit = gr.Button("Send Message", variant="primary")
            submit.click(
                message_handler,
                [chat, message, chatbot, messages],
                [chat, message, chatbot, messages],
            )
        with gr.Row():
            with gr.Column():
                js = "(x) => confirm('Press a button!')"
                done = gr.Button("Done", variant="stop")
                done.click(
                    on_done_click,
                    [chatbot, patient, user],
                    [message, chatbot, messages],
                )
                with gr.Accordion("Settings", open=False):
                    model_name = gr.Dropdown(
                        choices=MODELS_NAMES, value=MODELS_NAMES[0], label="model"
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=DEFAULT_TEMPERATURE,
                        step=0.1,
                        label="temperature",
                        interactive=True,
                    )
                    difficulty = gr.Dropdown(
                        choices=["easy", "medium", "hard"],
                        value=patient.value["difficulty"],
                        label="difficulty",
                        interactive=True,
                    )
                    communicative = gr.Slider(
                        minimum=0,
                        maximum=5,
                        value=patient.value["communicative"],
                        step=1,
                        label="communicative",
                        interactive=True,
                    )
                    apply_settings = gr.Button("Apply")
            with gr.Column():
                # patients_names is already computed at module level
                dropdown = gr.Dropdown(
                    choices=patients_names,
                    value=patients_names[0],
                    interactive=True,
                    label="Patient",
                )
                patient_card = gr.JSON(patient.value, visible=True, label="Patient card")
        apply_settings.click(
            on_apply_settings_click,
            [model_name, temperature, patient, difficulty, communicative],
            [chat, message, chatbot, messages, patient_card],
        )
        dropdown.change(
            fn=on_drop_down_change,
            inputs=[dropdown, messages],
            outputs=[
                patient_card,
                patient,
                chatbot,
                messages,
                difficulty,
                communicative,
            ],
        )

# app = FastAPI()
# os.makedirs("static", exist_ok=True)
# app.mount("/static", StaticFiles(directory="static"), name="static")
# templates = Jinja2Templates(directory="templates")
# @app.get("/", response_class=HTMLResponse)
# async def home(request: Request):
#     return templates.TemplateResponse(
#         "home.html", {"request": request, "videos": []})

def auth_handler(username: str, password: str) -> bool:
    # gradio passes the submitted credentials here: any username is accepted,
    # but the password must match the GRADIO_PASSWORD environment variable
    return password == os.environ["GRADIO_PASSWORD"]
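
# NOTE: queue() must stay enabled, otherwise Gradio cannot drive the
# generator-based message_handler and stream partial replies into the chatbot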
demo.queue()
demo.launch(auth=auth_handler)
# gradio_app = gr.routes.App.create_app(demo)
# app.mount("/gradio", gradio_app)