Spaces:
Running
Running
from fastapi import FastAPI, HTTPException | |
import streamlit as st | |
import pandas as pd | |
from pydantic import BaseModel, Field, validator | |
import numpy as np | |
import plotly.graph_objects as go | |
from azure_openai import converse_with_patient, create_diagnosis | |
from memory import get_conversation, store_conversation, update_conversation | |
import uuid | |
class AskQuestion(BaseModel):
    """Request payload for a single chat turn: the patient's message and its conversation."""

    # The patient's free-text message for this turn.
    user_input: str
    # Conversation identifier; the field is named ``id`` because that name is
    # part of the request schema clients send.
    id: str


# Backward-compatible alias for the original lowercase model name. New code
# should prefer ``AskQuestion``: a route handler function named
# ``ask_question`` would rebind this module-level name once it is defined.
ask_question = AskQuestion

app = FastAPI()
def _confidence_to_color(confidence):
    """Return an ``rgb(r, g, b)`` string blending red (0) into green (100).

    *confidence* is a score on a 0-100 scale. Values outside that range are
    clamped, so every channel stays within 0-255 and the colour is always valid.
    """
    blend_ratio = min(max(confidence / 100, 0.0), 1.0)
    red = np.array([255, 0, 0])
    green = np.array([0, 255, 0])
    r, g, b = (int(channel) for channel in (1 - blend_ratio) * red + blend_ratio * green)
    return f"rgb({r}, {g}, {b})"


def _expert_confidence_frame(expert_distribution):
    """Pivot ``{"<expert>, <attribute>": value}`` pairs into one row per expert.

    Returns a DataFrame containing (among others) the columns ``Expert``,
    ``Value_confidence`` (numeric) and ``Value_thinking`` (text; ``""`` when the
    expert has no "thinking" entry). Keys without a ``", "`` separator are skipped.

    Raises:
        ValueError: if a confidence value cannot be converted to a number.
    """
    rows = []
    for key, value in expert_distribution.items():
        # Split at the LAST ", " so expert names may themselves contain commas.
        expert, sep, attribute = key.rpartition(", ")
        if sep:
            rows.append({"Expert": expert, "Attribute": attribute, "Value": value})

    # Explicit columns keep the filters below valid even when nothing parsed.
    df = pd.DataFrame(rows, columns=["Expert", "Attribute", "Value"])
    confidence = df[df["Attribute"] == "confidence"]
    thinking = df[df["Attribute"] == "thinking"]

    # Left join: an expert with a confidence but no explanation still gets a bar.
    merged = confidence.merge(
        thinking, on="Expert", how="left", suffixes=("_confidence", "_thinking")
    )
    merged["Value_confidence"] = pd.to_numeric(merged["Value_confidence"])
    merged["Value_thinking"] = merged["Value_thinking"].fillna("")
    return merged


def generate_expert_confidence_chart(diagnosis):
    """Build a horizontal bar chart of each expert's confidence in the diagnosis.

    Args:
        diagnosis: dict whose ``"expert_distribution"`` entry maps
            ``"<expert>, <attribute>"`` keys to values. The attribute
            ``"confidence"`` holds a 0-100 score; ``"thinking"`` holds the
            explanation shown on hover. Other attributes are ignored.

    Returns:
        A plotly ``go.Figure`` with one bar per expert, coloured from red (0)
        to green (100). The caller is responsible for displaying it.

    Raises:
        KeyError: if ``diagnosis`` has no ``"expert_distribution"`` entry.
        ValueError: if a confidence value cannot be converted to a number.
    """
    df_confidence = _expert_confidence_frame(diagnosis["expert_distribution"])
    colors = df_confidence["Value_confidence"].apply(_confidence_to_color)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=df_confidence["Expert"],
        x=df_confidence["Value_confidence"],
        text=df_confidence["Value_confidence"],
        hovertext=df_confidence["Value_thinking"],
        orientation="h",
        marker=dict(color=colors),
        width=0.3,  # thin bars so the experts sit close together
        textposition="inside",
    ))
    fig.update_layout(
        title="Expert Confidence in Diagnosis",
        xaxis_title="Confidence Score",
        yaxis_title="Medical Expert",
        # Label every expert on the y axis and leave room for long names.
        yaxis=dict(tickmode="linear", dtick=1, automargin=True),
        # Grow with the number of experts, but never shorter than 400.
        height=max(400, 40 * len(df_confidence)),
        bargap=0.1,
    )
    fig.update_traces(
        hovertemplate="<b>%{y}</b><br>Confidence: %{x}%<br>Thinking: %{hovertext}"
    )
    return fig
# FastAPI interface routes | |
# @app.get("/") | |
# async def root(): | |
# return {"message": "Welcome to the GenAI Symptom Checker"} | |
# @app.post("/ask") | |
# async def ask_question(question: ask_question): | |
# try: | |
# user_input = question.user_input | |
# conversation_id = question.id | |
# exists, count, conversation_obj = get_conversation(conversation_id) | |
# if count == 6: | |
# response = converse_with_patient(conversation_obj, user_input) | |
# store_conversation(conversation_id, conversation_id, user_input, response) | |
# exists, count, conversation_obj = get_conversation(conversation_id) | |
# diagnosis = create_diagnosis(conversation_obj) | |
# return {"response": response, "count": count, "diagnosis": diagnosis} | |
# if count > 6: | |
# exists, count, conversation_obj = get_conversation(conversation_id) | |
# diagnosis_content = next((item['content'] for item in conversation_obj if item['role'] == 'diagnosis'), None) | |
# return {"response": "You have reached the maximum number of questions", "count": count, "diagnosis": diagnosis_content} | |
# if exists == "PASS": | |
# response = converse_with_patient(conversation_obj, user_input) | |
# update_conversation(conversation_id, conversation_id, user_input, response) | |
# return {"response": response, "count": count, "diagnosis": "none"} | |
# else: | |
# response = converse_with_patient("",user_input) | |
# store_conversation(conversation_id, conversation_id, user_input, response) | |
# return {"response": response, "count": count, "diagnosis": "none"} | |
# except Exception as e: | |
# raise HTTPException(status_code=500, detail=str(e)) | |
# app config
st.set_page_config(page_title="virtual clinician", page_icon=":medical_symbol:")
st.title("Virtual Clinician :medical_symbol:")

# Once get_conversation reports a count above this value, every new message
# also triggers a diagnosis. (Presumably the number of stored turns; confirm
# against the memory module.)
_DIAGNOSIS_COUNT_THRESHOLD = 5

# Status line shown above an incomplete diagnosis, keyed by create_diagnosis outcome.
_FAILURE_MESSAGES = {
    "FAIL1": "Diagnosis not available: failed to find consensus.",
    "FAIL2": (
        "Diagnosis not available: failed to match the described symptoms "
        "with known symptoms for AI diagnosis."
    ),
}


def _render_diagnosis(outcome, diagnosis):
    """Render a ``create_diagnosis`` result below the chat.

    Args:
        outcome: ``"SUCCESS"``, ``"FAIL1"`` or ``"FAIL2"``. Any other value is
            reported with ``st.error`` instead of being silently ignored.
        diagnosis: dict produced by ``create_diagnosis``. The (misspelled)
            ``concensus_*`` key names must match that producer, so they are
            kept as-is.
    """
    if outcome == "SUCCESS":
        st.subheader("Diagnosis Summary")
        st.write(f"**Diagnosis:** {diagnosis['concensus_diagnosis']}")
        st.write(f"**Consensus Confidence:** {diagnosis['concensus_confidence']}%")
        st.write(f"**Consensus Thinking:** {diagnosis['concensus_thinking']}")
        st.write(f"**Evaluation Confidence:** {diagnosis['evaluate_confidence']}%")
        st.write(f"**Evaluation Explanation:** {diagnosis['evaluate_explanation']}")
        # NOTE(review): trailing underscore in this key; confirm it matches create_diagnosis.
        st.write(f"**Next Best Action:** {diagnosis['next_best_action_']}")
        st.write(f"**Next Best Action Explanation:** {diagnosis['next_best_action_explanation']}")
        st.write(f"**Next Best Action Confidence:** {diagnosis['next_best_action_confidence']}%")
    elif outcome in _FAILURE_MESSAGES:
        st.write(_FAILURE_MESSAGES[outcome])
        st.subheader("Incomplete Diagnosis")
        st.write(f"**Diagnosis:** {diagnosis['concensus_diagnosis']}")
        st.write(f"**Consensus Confidence:** {diagnosis['concensus_confidence']}%")
        st.write(f"**Consensus Thinking:** {diagnosis['concensus_thinking']}")
        if outcome == "FAIL2":
            # Evaluation details are displayed only for FAIL2 results.
            st.write(f"**Evaluation Confidence:** {diagnosis['evaluate_confidence']}%")
            st.write(f"**Evaluation Explanation:** {diagnosis['evaluate_explanation']}")
        st.write("**Next Best Action:** See GP")
        st.write("**Next Best Action Explanation:** Please give more details to help the AI better understand your symptoms ")
    else:
        st.error("Diagnosis could not be generated. Please try again.")
        return

    st.subheader("Expert Confidence Levels")
    st.plotly_chart(generate_expert_confidence_chart(diagnosis))


user_id = st.text_input("Name:", key="user_id")

# Key the stored conversation by a per-session UUID rather than by the typed
# name. With name-based ids, two visitors entering the same name would continue
# (and be diagnosed from) each other's medical conversation.
if "conversation_id" not in st.session_state:
    st.session_state.conversation_id = str(uuid.uuid4())
conversation_id = st.session_state.conversation_id

if not user_id:
    st.warning("Hi, Who am I speaking with?")
else:
    # Seed the visible transcript once per session.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            {"role": "AI", "content": f"Hello, {user_id} I am the virtual clinician. How can I help you today?"},
        ]

    # Streamlit reruns the whole script on each interaction, so replay the transcript.
    for message in st.session_state.chat_history:
        if message["role"] in ("AI", "Human"):
            with st.chat_message(message["role"]):
                st.write(message["content"])

    user_input = st.chat_input("Type your message here...")
    if user_input:
        st.session_state.chat_history.append({"role": "Human", "content": user_input})
        with st.chat_message("Human"):
            st.markdown(user_input)

        exists, count, conversation_obj = get_conversation(conversation_id)
        diagnose = count > _DIAGNOSIS_COUNT_THRESHOLD

        # Exactly one model call and one persistence call per message. The
        # diagnosis path must not also run the regular branch, or the message
        # is sent to the model and stored twice.
        if diagnose or exists == "PASS":
            response = converse_with_patient(conversation_obj, user_input)
            conversation_obj = update_conversation(conversation_id, user_input, response)
        else:
            # No stored conversation yet: start a new one.
            response = converse_with_patient("", user_input)
            store_conversation(conversation_id, user_input, response)

        st.session_state.chat_history.append({"role": "AI", "content": response})
        with st.chat_message("AI"):
            st.write(response)

        if diagnose:
            with st.spinner("Creating a diagnosis..."):
                outcome, diagnosis = create_diagnosis(conversation_obj)
            _render_diagnosis(outcome, diagnosis)