# Doctor_v2 / app.py
# Last commit: 339056d "Update app.py" by davidfearne (verified)
from fastapi import FastAPI, HTTPException
import streamlit as st
import pandas as pd
from pydantic import BaseModel, Field, validator
import numpy as np
import plotly.graph_objects as go
from azure_openai import converse_with_patient, create_diagnosis
from memory import get_conversation, store_conversation, update_conversation
import uuid
# Request payload for the commented-out /ask endpoint: the patient's latest
# message plus the identifier of the conversation it belongs to.
class ask_question(BaseModel):
    # Free-text message typed by the patient (required).
    user_input: str = Field(...)
    # Conversation identifier passed to get_conversation (required).
    id: str = Field(...)


# FastAPI application instance; no routes are currently registered on it
# in this file (the /ask handler is commented out).
app = FastAPI()
def _confidence_to_color(confidence):
    """Map a 0-100 confidence score to an ``rgb(r, g, b)`` string.

    0 maps to pure red, 100 to pure green, values in between are blended
    linearly. Scores outside [0, 100] are clamped so every channel stays in
    the valid 0-255 range.
    """
    red = np.array([255, 0, 0])
    green = np.array([0, 255, 0])
    # Clamp before normalising: an unclamped 120 would give a red channel of
    # -51 and a green channel of 306, which is not a valid colour.
    blend_ratio = min(max(float(confidence), 0.0), 100.0) / 100
    blended = (1 - blend_ratio) * red + blend_ratio * green
    return f"rgb({int(blended[0])}, {int(blended[1])}, {int(blended[2])})"


def _expert_confidence_frame(diagnosis):
    """Build one row per expert holding its confidence, thinking text and bar colour.

    Returns a DataFrame with (at least) the columns ``Expert``,
    ``Value_confidence`` (numeric), ``Value_thinking`` (str, "" when the
    expert has no thinking entry) and ``Color``.
    """
    rows = []
    for key, value in diagnosis["expert_distribution"].items():
        # Split at the LAST ", " so expert names that contain commas survive.
        expert, sep, attribute = key.rpartition(", ")
        if not sep:
            # Not of the form "<expert>, <attribute>"; skip it instead of
            # failing the whole chart.
            continue
        rows.append({"Expert": expert, "Attribute": attribute, "Value": value})

    # Explicit columns keep the column lookups below valid when rows is empty.
    df = pd.DataFrame(rows, columns=["Expert", "Attribute", "Value"])
    confidence = df[df["Attribute"] == "confidence"]
    thinking = df[df["Attribute"] == "thinking"]

    # Left merge: an expert with a confidence but no thinking still gets a bar.
    merged = confidence.merge(
        thinking, on="Expert", how="left", suffixes=("_confidence", "_thinking")
    )
    merged["Value_confidence"] = pd.to_numeric(merged["Value_confidence"])
    merged["Value_thinking"] = merged["Value_thinking"].fillna("")
    merged["Color"] = merged["Value_confidence"].apply(_confidence_to_color)
    return merged


def generate_expert_confidence_chart(diagnosis):
    """
    Build a horizontal bar chart of each expert's confidence in the diagnosis.

    ``diagnosis["expert_distribution"]`` maps keys of the form
    ``"<expert>, <attribute>"`` to values. The ``confidence`` attribute sets
    the bar length and colour (red at 0 through green at 100), and the
    ``thinking`` attribute is shown as hover text.

    Args:
        diagnosis: Mapping that contains an ``"expert_distribution"`` dict.

    Returns:
        A ``plotly.graph_objects.Figure`` with one horizontal bar per expert.

    Raises:
        KeyError: if ``diagnosis`` has no ``"expert_distribution"`` entry.
        ValueError: if a confidence value cannot be converted to a number.
    """
    df_confidence = _expert_confidence_frame(diagnosis)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=df_confidence["Expert"],
        x=df_confidence["Value_confidence"],
        text=df_confidence["Value_confidence"],
        hovertext=df_confidence["Value_thinking"],
        orientation="h",
        marker=dict(color=df_confidence["Color"]),
        width=0.3,  # narrow bars for closer spacing
        textposition="inside",
    ))
    fig.update_layout(
        title="Expert Confidence in Diagnosis",
        xaxis_title="Confidence Score",
        yaxis_title="Medical Expert",
        yaxis=dict(tickmode="linear", dtick=1, automargin=True),
        # Grow the figure with the number of experts, with a 400px minimum.
        height=max(400, 40 * len(df_confidence)),
        bargap=0.1,
    )
    fig.update_traces(
        hovertemplate="<b>%{y}</b><br>Confidence: %{x}%<br>Thinking: %{hovertext}"
    )
    return fig
# FastAPI interface routes
# @app.get("/")
# async def root():
# return {"message": "Welcome to the GenAI Symptom Checker"}
# @app.post("/ask")
# async def ask_question(question: ask_question):
# try:
# user_input = question.user_input
# conversation_id = question.id
# exists, count, conversation_obj = get_conversation(conversation_id)
# if count == 6:
# response = converse_with_patient(conversation_obj, user_input)
# store_conversation(conversation_id, conversation_id, user_input, response)
# exists, count, conversation_obj = get_conversation(conversation_id)
# diagnosis = create_diagnosis(conversation_obj)
# return {"response": response, "count": count, "diagnosis": diagnosis}
# if count > 6:
# exists, count, conversation_obj = get_conversation(conversation_id)
# diagnosis_content = next((item['content'] for item in conversation_obj if item['role'] == 'diagnosis'), None)
# return {"response": "You have reached the maximum number of questions", "count": count, "diagnosis": diagnosis_content}
# if exists == "PASS":
# response = converse_with_patient(conversation_obj, user_input)
# update_conversation(conversation_id, conversation_id, user_input, response)
# return {"response": response, "count": count, "diagnosis": "none"}
# else:
# response = converse_with_patient("",user_input)
# store_conversation(conversation_id, conversation_id, user_input, response)
# return {"response": response, "count": count, "diagnosis": "none"}
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e))
# Panel headline shown for each non-successful diagnosis outcome.
_FAILURE_MESSAGES = {
    "FAIL1": "Diagnosis not available: failed to find consensus.",
    "FAIL2": "Diagnosis not available: failed to match the described symptoms with known symptoms for AI diagnosis.",
}


def _show_ai_message(response):
    """Append an AI reply to the session chat history and render it."""
    st.session_state.chat_history.append({"role": "AI", "content": response})
    with st.chat_message("AI"):
        st.write(response)


def _render_expert_chart(diagnosis):
    """Render the per-expert confidence bar chart for *diagnosis*."""
    st.subheader("Expert Confidence Levels")
    st.plotly_chart(generate_expert_confidence_chart(diagnosis))


def _render_diagnosis(outcome, diagnosis):
    """Render the diagnosis panel for the status string returned by create_diagnosis.

    "SUCCESS" shows the full summary; "FAIL1"/"FAIL2" show an incomplete
    summary that advises seeing a GP; any other outcome shows only a notice.
    """
    if outcome == "SUCCESS":
        st.subheader("Diagnosis Summary")
        st.write(f"**Diagnosis:** {diagnosis['concensus_diagnosis']}")
        st.write(f"**Consensus Confidence:** {diagnosis['concensus_confidence']}%")
        st.write(f"**Consensus Thinking:** {diagnosis['concensus_thinking']}")
        st.write(f"**Evaluation Confidence:** {diagnosis['evaluate_confidence']}%")
        st.write(f"**Evaluation Explanation:** {diagnosis['evaluate_explanation']}")
        st.write(f"**Next Best Action:** {diagnosis['next_best_action_']}")
        st.write(f"**Next Best Action Explanation:** {diagnosis['next_best_action_explanation']}")
        st.write(f"**Next Best Action Confidence:** {diagnosis['next_best_action_confidence']}%")
    elif outcome in _FAILURE_MESSAGES:
        st.write(_FAILURE_MESSAGES[outcome])
        st.subheader("Incomplete Diagnosis")
        st.write(f"**Diagnosis:** {diagnosis['concensus_diagnosis']}")
        st.write(f"**Consensus Confidence:** {diagnosis['concensus_confidence']}%")
        st.write(f"**Consensus Thinking:** {diagnosis['concensus_thinking']}")
        if outcome == "FAIL2":
            # Only the FAIL2 panel reports the evaluation step's results.
            st.write(f"**Evaluation Confidence:** {diagnosis['evaluate_confidence']}%")
            st.write(f"**Evaluation Explanation:** {diagnosis['evaluate_explanation']}")
        st.write("**Next Best Action:** See GP")
        st.write("**Next Best Action Explanation:** Please give more details to help the AI better understand your symptoms")
    else:
        # Unrecognised outcome: do not index into a payload of unknown shape.
        st.warning("Diagnosis not available.")
        return
    _render_expert_chart(diagnosis)


# app config
st.set_page_config(page_title="virtual clinician", page_icon=":medical_symbol:")
st.title("Virtual Clinician :medical_symbol:")
user_id = st.text_input("Name:", key="user_id")
# NOTE(review): the entered name is used directly as the conversation key, so
# two patients who type the same name look up the same conversation.
conversation_id = user_id

if not user_id:
    st.warning("Hi, Who am I speaking with?")
else:
    # Seed the per-session chat history with a greeting on first use.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            {"role": "AI", "content": f"Hello, {user_id} I am the virtual clinician. How can I help you today?"},
        ]

    # Replay the chat history stored in this session.
    for message in st.session_state.chat_history:
        if message["role"] in ("AI", "Human"):
            with st.chat_message(message["role"]):
                st.write(message["content"])

    user_input = st.chat_input("Type your message here...")
    if user_input:
        st.session_state.chat_history.append({"role": "Human", "content": user_input})
        with st.chat_message("Human"):
            st.markdown(user_input)

        # Look up any stored conversation for this id.
        exists, count, conversation_obj = get_conversation(conversation_id)

        if count > 5:
            # Enough turns collected: answer this turn, then produce a diagnosis
            # from the updated conversation. This branch is exclusive so the
            # turn is not sent to the model and stored a second time below.
            response = converse_with_patient(conversation_obj, user_input)
            conversation_obj = update_conversation(conversation_id, user_input, response)
            with st.spinner("Creating a diagnosis..."):
                outcome, diagnosis = create_diagnosis(conversation_obj)
            _render_diagnosis(outcome, diagnosis)
            _show_ai_message(response)
        elif exists == "PASS":
            # Existing conversation: continue it and append the new turn.
            response = converse_with_patient(conversation_obj, user_input)
            update_conversation(conversation_id, user_input, response)
            _show_ai_message(response)
        else:
            # No stored conversation yet: start one with this first turn.
            response = converse_with_patient("", user_input)
            store_conversation(conversation_id, user_input, response)
            _show_ai_message(response)