"""Virtual Clinician: a Streamlit chat front end that converses with a patient
and, after enough exchanges, renders an AI-generated diagnosis summary together
with a per-expert confidence chart.

A FastAPI app object and request model are also defined; the HTTP routes that
used them are currently commented out below.
"""
import uuid

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, validator

from azure_openai import converse_with_patient, create_diagnosis
from memory import get_conversation, store_conversation, update_conversation


class ask_question(BaseModel):
    """Request body for the (currently disabled) ``/ask`` endpoint."""

    user_input: str  # the patient's message
    id: str  # conversation identifier


app = FastAPI()


def _confidence_to_color(confidence):
    """Map a 0-100 confidence score to an ``rgb(...)`` string.

    0 is pure red and 100 is pure green, with a linear blend in between.
    Scores outside 0-100 are clamped so the channels always stay in 0-255.
    """
    ratio = min(max(float(confidence), 0.0), 100.0) / 100
    red = np.array([255, 0, 0])
    green = np.array([0, 255, 0])
    blended = (1 - ratio) * red + ratio * green
    return f"rgb({int(blended[0])}, {int(blended[1])}, {int(blended[2])})"


def generate_expert_confidence_chart(diagnosis):
    """Build a horizontal bar chart of each expert's confidence in the diagnosis.

    Args:
        diagnosis: mapping with an ``"expert_distribution"`` entry whose keys
            look like ``"<expert>, <attribute>"``. Only the ``confidence`` and
            ``thinking`` attributes are used. Keys without a ``", "``
            separator are ignored.

    Returns:
        A ``plotly.graph_objects.Figure``. Bars are coloured red to green by
        confidence, and the expert's ``thinking`` text is shown on hover.
    """
    rows = []
    for key, value in diagnosis["expert_distribution"].items():
        # Split on the LAST ", " so expert names that contain commas survive.
        expert, sep, attribute = key.rpartition(", ")
        if not sep:
            continue  # not an "<expert>, <attribute>" key; nothing to plot
        rows.append({"Expert": expert, "Attribute": attribute, "Value": value})

    # Explicit columns keep the filters below valid even when nothing was parsed.
    df = pd.DataFrame(rows, columns=["Expert", "Attribute", "Value"])

    # Pair each expert's confidence with their thinking text. This is an inner
    # join, so experts missing either attribute are dropped.
    df_confidence = df[df["Attribute"] == "confidence"].merge(
        df[df["Attribute"] == "thinking"],
        on="Expert",
        suffixes=("_confidence", "_thinking"),
    )
    df_confidence["Value_confidence"] = pd.to_numeric(df_confidence["Value_confidence"])
    df_confidence["Color"] = df_confidence["Value_confidence"].apply(_confidence_to_color)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=df_confidence["Expert"],
        x=df_confidence["Value_confidence"],
        text=df_confidence["Value_confidence"],
        hovertext=df_confidence["Value_thinking"],
        orientation="h",
        marker=dict(color=df_confidence["Color"]),
        width=0.3,  # narrow bars so experts sit closer together
        textposition="inside",
    ))
    fig.update_layout(
        title="Expert Confidence in Diagnosis",
        xaxis_title="Confidence Score",
        yaxis_title="Medical Expert",
        yaxis=dict(tickmode="linear", dtick=1, automargin=True),
        height=max(400, 40 * len(df_confidence)),  # grow with the number of experts
        bargap=0.1,
    )
    fig.update_traces(
        hovertemplate="%{y}<br>Confidence: %{x}%<br>Thinking: %{hovertext}"
    )
    return fig


def _write_consensus(diagnosis):
    """Write the consensus diagnosis, confidence and reasoning."""
    st.write(f"**Diagnosis:** {diagnosis['concensus_diagnosis']}")
    st.write(f"**Consensus Confidence:** {diagnosis['concensus_confidence']}%")
    st.write(f"**Consensus Thinking:** {diagnosis['concensus_thinking']}")


def _write_evaluation(diagnosis):
    """Write the evaluation confidence and explanation."""
    st.write(f"**Evaluation Confidence:** {diagnosis['evaluate_confidence']}%")
    st.write(f"**Evaluation Explanation:** {diagnosis['evaluate_explanation']}")


def _write_see_gp_advice():
    """Write the fallback 'See GP' next-best-action used for incomplete diagnoses."""
    st.write("**Next Best Action:** See GP")
    st.write("**Next Best Action Explanation:** Please give more details to help the AI better understand your symptoms ")


def _write_expert_chart(diagnosis):
    """Render the per-expert confidence chart under its own subheader."""
    st.subheader("Expert Confidence Levels")
    st.plotly_chart(generate_expert_confidence_chart(diagnosis))


def _render_diagnosis(outcome, diagnosis):
    """Render the result of ``create_diagnosis``.

    Args:
        outcome: status string. "SUCCESS" means a full diagnosis, "FAIL1" means
            no consensus was reached, and "FAIL2" means the symptoms could not
            be matched. Any other value shows a generic notice.
        diagnosis: the diagnosis mapping returned alongside ``outcome``.
    """
    if outcome == "SUCCESS":
        st.subheader("Diagnosis Summary")
        _write_consensus(diagnosis)
        _write_evaluation(diagnosis)
        st.write(f"**Next Best Action:** {diagnosis['next_best_action_']}")
        st.write(f"**Next Best Action Explanation:** {diagnosis['next_best_action_explanation']}")
        st.write(f"**Next Best Action Confidence:** {diagnosis['next_best_action_confidence']}%")
    elif outcome == "FAIL1":
        st.write("Diagnosis not available: failed to find consensus.")
        st.subheader("Incomplete Diagnosis")
        _write_consensus(diagnosis)
        _write_see_gp_advice()
    elif outcome == "FAIL2":
        st.write("Diagnosis not available: failed to match the described symptoms with known symptoms for AI diagnosis.")
        st.subheader("Incomplete Diagnosis")
        _write_consensus(diagnosis)
        _write_evaluation(diagnosis)
        _write_see_gp_advice()
    else:
        # Unknown status: don't fail silently, and don't assume `diagnosis` has chart data.
        st.warning("Diagnosis not available.")
        return
    _write_expert_chart(diagnosis)


def _show_ai_reply(response):
    """Append the AI reply to the chat history and display it."""
    st.session_state.chat_history.append({"role": "AI", "content": response})
    with st.chat_message("AI"):
        st.write(response)


# FastAPI interface routes (currently disabled).
# NOTE(review): these call store_conversation/update_conversation with four
# arguments, whereas the live Streamlit code below passes three. Update the
# calls before re-enabling.
# @app.get("/")
# async def root():
#     return {"message": "Welcome to the GenAI Symptom Checker"}
#
# @app.post("/ask")
# async def ask_question(question: ask_question):
#     try:
#         user_input = question.user_input
#         conversation_id = question.id
#         exists, count, conversation_obj = get_conversation(conversation_id)
#         if count == 6:
#             response = converse_with_patient(conversation_obj, user_input)
#             store_conversation(conversation_id, conversation_id, user_input, response)
#             exists, count, conversation_obj = get_conversation(conversation_id)
#             diagnosis = create_diagnosis(conversation_obj)
#             return {"response": response, "count": count, "diagnosis": diagnosis}
#         if count > 6:
#             exists, count, conversation_obj = get_conversation(conversation_id)
#             diagnosis_content = next((item['content'] for item in conversation_obj if item['role'] == 'diagnosis'), None)
#             return {"response": "You have reached the maximum number of questions", "count": count, "diagnosis": diagnosis_content}
#         if exists == "PASS":
#             response = converse_with_patient(conversation_obj, user_input)
#             update_conversation(conversation_id, conversation_id, user_input, response)
#             return {"response": response, "count": count, "diagnosis": "none"}
#         else:
#             response = converse_with_patient("", user_input)
#             store_conversation(conversation_id, conversation_id, user_input, response)
#             return {"response": response, "count": count, "diagnosis": "none"}
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))


# App config: must be the first Streamlit command on the page.
st.set_page_config(page_title="virtual clinician", page_icon=":medical_symbol:")
st.title("Virtual Clinician :medical_symbol:")

user_id = st.text_input("Name:", key="user_id")
# NOTE(review): the entered name doubles as the conversation key, so two users
# with the same name would share one stored conversation. Confirm this is intended.
conversation_id = user_id

if not user_id:
    st.warning("Hi, Who am I speaking with?")
else:
    # Seed the per-session chat history with a greeting on first load.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = [
            {"role": "AI", "content": f"Hello, {user_id} I am the virtual clinician. How can I help you today?"},
        ]

    # Replay the conversation so far (Streamlit reruns the script on every input).
    for message in st.session_state.chat_history:
        if message["role"] in ("AI", "Human"):
            with st.chat_message(message["role"]):
                st.write(message["content"])

    user_input = st.chat_input("Type your message here...")
    if user_input:
        st.session_state.chat_history.append({"role": "Human", "content": user_input})
        with st.chat_message("Human"):
            st.markdown(user_input)

        # Look up any stored conversation for this user.
        exists, count, conversation_obj = get_conversation(conversation_id)

        if count > 5:
            # Enough exchanges: answer this message, then produce a diagnosis.
            # `elif` below ensures the model is not called (and the conversation
            # not updated) a second time for the same message.
            response = converse_with_patient(conversation_obj, user_input)
            conversation_obj = update_conversation(conversation_id, user_input, response)
            _show_ai_reply(response)
            with st.spinner("Creating a diagnosis..."):
                outcome, diagnosis = create_diagnosis(conversation_obj)
            _render_diagnosis(outcome, diagnosis)
        elif exists == "PASS":
            # Existing conversation: continue it.
            response = converse_with_patient(conversation_obj, user_input)
            update_conversation(conversation_id, user_input, response)
            _show_ai_reply(response)
        else:
            # No stored conversation yet: start a new one.
            response = converse_with_patient("", user_input)
            store_conversation(conversation_id, user_input, response)
            _show_ai_reply(response)