# MedQA / pages/2_Consult.py
# Hugging Face Space file by mgbam — last commit 45ea80d ("Update pages/2_Consult.py", verified), 15.7 kB
# /home/user/app/pages/2_Consult.py
import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from datetime import datetime
from typing import List, Optional, Dict, Any
from sqlmodel import select
from config.settings import settings
from agent import get_agent_executor # This now returns the OpenAI-based agent executor
from models import ChatMessage, ChatSession
from models.db import get_session_context
from services.logger import app_logger
from services.metrics import log_consultation_start # Assuming this function exists
# --- Authentication Check ---
# Gate the whole page: visitors without an authenticated user id are sent back to the
# login page (app.py) and nothing below this block runs for them.
if not st.session_state.get("authenticated_user_id"):
    st.warning("Please log in to access the consultation page.")
    try:
        st.switch_page("app.py")
    except st.errors.StreamlitAPIException:
        # switch_page is unavailable outside a multipage-app context (e.g. local dev),
        # so fall back to telling the user where to go.
        st.info("Please navigate to the main login page.")
    st.stop()

# Identity of the logged-in user; the display name falls back to "User" when absent.
authenticated_user_id, authenticated_username = (
    st.session_state.get("authenticated_user_id"),
    st.session_state.get("authenticated_username", "User"),
)
app_logger.info(
    f"User '{authenticated_username}' (ID: {authenticated_user_id}) accessed Consult page."
)
# --- Initialize Agent ---
# get_agent_executor() (agent.py) builds the OpenAI-based agent executor. Any failure is
# fatal for this page: report it to the user, log it, and stop the script run.
try:
    agent_executor = get_agent_executor()
    app_logger.info("OpenAI-based agent executor initialized successfully for Consult page.")
except Exception as init_error:
    if isinstance(init_error, ValueError):
        # ValueError is what get_agent_executor raises for configuration problems such as a
        # missing API key, so point the user at the secret that needs fixing.
        st.error(f"AI Agent Initialization Error: {init_error}")
        app_logger.critical(
            f"Fatal: AI Agent could not be initialized in Consult page: {init_error}",
            exc_info=True,
        )
        st.info(
            "Please ensure the OPENAI_API_KEY is correctly configured in the application "
            "settings (Hugging Face Secrets)."
        )
    else:
        st.error(f"An unexpected error occurred while initializing the AI Agent: {init_error}")
        app_logger.critical(
            f"Fatal: Unexpected AI Agent initialization error: {init_error}",
            exc_info=True,
        )
    st.stop()
# --- Session State for Consult Page ---
# Seed per-browser-session defaults exactly once; values already present survive reruns.
#   current_consult_patient_context_dict -> filtered patient context from the Step 1 form
#   consult_context_submitted            -> whether the Step 1 form has been submitted
for _state_key, _default_factory in (
    ("current_consult_patient_context_dict", dict),
    ("consult_context_submitted", bool),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_factory()
del _state_key, _default_factory
# --- Helper Functions ---
@st.cache_data(ttl=30, show_spinner=False, max_entries=10)
def _fetch_chat_history_for_agent(session_id: int) -> List[Any]:
    """Query a session's stored messages and convert them to LangChain messages.

    Roles "user", "assistant" and "system" map to HumanMessage, AIMessage and
    SystemMessage; any other role (e.g. tool messages) is skipped, since tool output
    belongs to the agent scratchpad rather than the chat history.

    This helper deliberately does NOT catch exceptions. Streamlit does not cache a call
    that raises, so a transient DB failure is retried on the next call instead of an
    empty history being cached for the TTL.
    """
    history: List[Any] = []
    with get_session_context() as db:
        statement = (
            select(ChatMessage)
            .where(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.timestamp)
        )
        for msg in db.exec(statement).all():
            if msg.role == "user":
                history.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                history.append(AIMessage(content=msg.content))
            elif msg.role == "system":
                history.append(SystemMessage(content=msg.content))
    return history


def load_chat_history_for_agent(session_id: int) -> List[Any]:  # List of LangChain BaseMessage
    """Load chat history from the DB, formatted for the LangChain agent.

    Args:
        session_id: Primary key of the ChatSession whose messages to load.

    Returns:
        Messages ordered by timestamp. Returns an empty list if loading fails; the error
        is logged and shown as a non-blocking toast, so callers never see an exception.
    """
    app_logger.debug(f"Loading agent chat history from DB for session_id: {session_id}")
    try:
        messages = _fetch_chat_history_for_agent(session_id)
    except Exception as e:
        app_logger.error(f"Error loading chat history for session {session_id}: {e}", exc_info=True)
        # The toast lives outside the cached function so it is shown on every failure.
        st.toast(f"Error loading chat history: {e}", icon="⚠️")
        return []
    app_logger.debug(f"Loaded {len(messages)} LangChain messages for agent history (session {session_id}).")
    return messages
def save_chat_message_to_db(
    session_id: int,
    role: str,
    content: str,
    tool_call_id: Optional[str] = None,
    tool_name: Optional[str] = None,
):
    """Persist one chat message for ``session_id`` with a UTC timestamp.

    Args:
        session_id: ChatSession the message belongs to.
        role: Message role as stored in the DB (e.g. "user", "assistant", "system").
        content: Message text.
        tool_call_id: Optional id of the tool call this message relates to.
        tool_name: Optional name of the tool that produced the message.

    Failures are logged and surfaced as a non-blocking toast; nothing is raised.
    """
    app_logger.debug(
        f"Saving message to DB for session {session_id}: "
        f"Role='{role}', Content snippet='{content[:50]}...'"
    )
    try:
        with get_session_context() as db:
            # No explicit commit: per the original note, the context manager commits on exit.
            db.add(
                ChatMessage(
                    session_id=session_id,
                    role=role,
                    content=content,
                    timestamp=datetime.utcnow(),
                    tool_call_id=tool_call_id,
                    tool_name=tool_name,
                )
            )
        app_logger.info(f"Message (Role: {role}) saved to DB for session {session_id}.")
    except Exception as save_error:
        app_logger.error(
            f"Error saving chat message to DB for session {session_id}: {save_error}",
            exc_info=True,
        )
        st.toast(f"Error saving message: {save_error}", icon="⚠️")
def update_chat_session_with_context_summary_in_db(session_id: int, context_summary: str):
    """Write ``context_summary`` to the ``patient_context_summary`` field of a ChatSession.

    Args:
        session_id: Primary key of the ChatSession to update.
        context_summary: Human-readable summary of the patient context.

    A missing session is logged as an error. DB failures are logged and shown as a
    toast. Nothing is raised to the caller.
    """
    try:
        with get_session_context() as db:
            chat_session = db.get(ChatSession, session_id)  # lookup by primary key
            if not chat_session:
                app_logger.error(f"Could not find ChatSession {session_id} in DB to update context summary.")
                return
            chat_session.patient_context_summary = context_summary
            db.add(chat_session)  # re-adding an object with a PK stages the update
            app_logger.info(f"Updated ChatSession {session_id} with patient context summary in DB.")
    except Exception as update_error:
        app_logger.error(
            f"Error updating chat session {session_id} context summary: {update_error}",
            exc_info=True,
        )
        st.toast(f"Error saving context summary: {update_error}", icon="⚠️")
# --- Page Logic ---
st.title("AI Consultation Room")
st.markdown(f"Interacting as: **{authenticated_username}**")
# Keep the full disclaimer prominent on the consult page itself.
st.warning(
    "**Reminder & Disclaimer:** "
    f"{settings.MAIN_DISCLAIMER_LONG} {settings.SIMULATION_DISCLAIMER}"
)

# The active chat session id is presumably created at login (see the error text below);
# without it nothing can be persisted, so the page stops here.
chat_session_id = st.session_state.get("current_chat_session_id")
if not chat_session_id:
    st.error(
        "Error: No active chat session ID found. This is unexpected after login. "
        "Please try logging out and logging back in. "
        "If the problem persists, please contact support."
    )
    app_logger.critical(
        f"User '{authenticated_username}' (ID: {authenticated_user_id}) "
        "on Consult page encountered MISSING current_chat_session_id."
    )
    st.stop()
# --- Patient Context Input Form ---
def _filter_patient_context(raw_context: Dict[str, Any]) -> Dict[str, Any]:
    """Drop blank or placeholder entries from the raw patient-context form values.

    - ``None`` (e.g. a blank number input) is dropped.
    - Strings are stripped; empty strings and "Not Specified" (any case) are dropped.
    - Numbers (the Age field) are kept only when > 0.
    - Any other value is kept if its string form is non-blank.
    """
    cleaned: Dict[str, Any] = {}
    for key, value in raw_context.items():
        if value is None:
            continue
        if isinstance(value, str):
            value = value.strip()
            if not value or value.lower() == "not specified":
                continue
        elif isinstance(value, (int, float)):
            # Previously every non-string value (i.e. Age) was discarded by mistake.
            if value <= 0:
                continue
        elif not str(value).strip():
            continue
        cleaned[key] = value
    return cleaned


if not st.session_state.consult_context_submitted:
    st.subheader("Step 1: Provide Patient Context (Optional, Use Simulated Data Only)")
    with st.form(key="patient_context_form_consult_page_openai"):  # unique form key
        st.markdown("**Crucial Reminder: Use only anonymized, simulated data for this demonstration. Do NOT enter real PHI.**")
        age_in = st.number_input("Patient Age (Simulated)", min_value=0, max_value=120, step=1, value=None, help="Leave blank if not applicable.")
        gender_options = ["Not Specified", "Male", "Female", "Other"]
        gender_in = st.selectbox("Patient Gender (Simulated)", gender_options, index=0)
        cc_in = st.text_area("Chief Complaint / Reason for Consult (Simulated)", height=100, placeholder="e.g., Persistent cough for 2 weeks, fatigue")
        hist_in = st.text_area("Key Medical History (Simulated)", height=100, placeholder="e.g., Type 2 Diabetes (controlled), Hypertension (on medication), Asthma (mild intermittent)")
        meds_in = st.text_area("Current Medications (Simulated)", height=100, placeholder="e.g., Metformin 500mg BID, Lisinopril 10mg OD, Salbutamol inhaler PRN")
        submit_context_btn = st.form_submit_button("Start Consult with this Context")

    if submit_context_btn:
        raw_context = {
            "Age": age_in, "Gender": gender_in, "Chief Complaint": cc_in,
            "Key Medical History": hist_in, "Current Medications": meds_in,
        }
        filtered_context_dict = _filter_patient_context(raw_context)
        st.session_state.current_consult_patient_context_dict = filtered_context_dict

        if filtered_context_dict:
            context_summary_str = "; ".join(f"{k}: {v}" for k, v in filtered_context_dict.items())
        else:
            context_summary_str = "No specific patient context was provided for this session."

        update_chat_session_with_context_summary_in_db(chat_session_id, context_summary_str)
        # Record in the transcript (as a system message) that context was set for this session.
        save_chat_message_to_db(chat_session_id, "system", f"Initial Patient Context Set: {context_summary_str}")
        st.session_state.consult_context_submitted = True
        app_logger.info(f"Patient context submitted for session {chat_session_id}: {context_summary_str}")
        st.rerun()  # rerun to hide the form and show the chat interface

    st.stop()  # don't render the chat interface until the context form has been submitted
# --- Chat Interface (Shown after context is submitted) ---
st.subheader("Step 2: Interact with AI Health Navigator")
# Session-specific key so each chat session keeps its own LangChain message list.
agent_history_key = f"agent_chat_history_{chat_session_id}"

if agent_history_key not in st.session_state:
    # First time this browser session sees this chat session: hydrate history from the DB.
    st.session_state[agent_history_key] = load_chat_history_for_agent(chat_session_id)

    # Step 1 saves a "system" message ("Initial Patient Context Set: ...") before we get
    # here, so the loaded history is normally non-empty even for a brand-new consult.
    # Only user/assistant turns mean a conversation already exists.
    has_conversation_turns = any(
        not isinstance(message, SystemMessage)
        for message in st.session_state[agent_history_key]
    )
    if not has_conversation_turns:
        try:
            log_consultation_start(user_id=authenticated_user_id, session_id=chat_session_id)
        except Exception as e_metric:
            # Metrics are best-effort; never block the consult on them.
            app_logger.warning(f"Failed to log consultation start metric: {e_metric}")

        initial_ai_message_content = "Hello! I am your AI Health Navigator. How can I assist you today?"
        if st.session_state.get('current_consult_patient_context_dict', {}):
            initial_ai_message_content += " I have noted the patient context you provided."

        # Keep the in-memory agent history and the DB transcript in step.
        st.session_state[agent_history_key].append(AIMessage(content=initial_ai_message_content))
        save_chat_message_to_db(chat_session_id, "assistant", initial_ai_message_content)
        app_logger.info(f"Initialized new consultation (session {chat_session_id}) with a greeting.")
# Display chat messages for UI (fetched fresh from the DB on every run for UI consistency)
# Avatars are written as escapes; the previous literals had been mis-encoded into mojibake.
_CHAT_AVATARS = {
    "assistant": "\U0001F9D1\u200D\u2695\uFE0F",  # health worker emoji
    "tool": "\U0001F6E0\uFE0F",                   # hammer and wrench emoji
}
_DEFAULT_CHAT_AVATAR = "\U0001F464"  # bust in silhouette emoji (user and any other role)

# Fixed height makes this a scrollable container.
chat_display_container = st.container(height=450)
with chat_display_container:
    try:
        with get_session_context() as db:
            stmt = (
                select(ChatMessage)
                .where(ChatMessage.session_id == chat_session_id)
                .order_by(ChatMessage.timestamp)
            )
            ui_messages_from_db = db.exec(stmt).all()
            for msg in ui_messages_from_db:
                if msg.role == "system":
                    continue  # system/context notes are not shown in the chat UI
                avatar_icon = _CHAT_AVATARS.get(msg.role, _DEFAULT_CHAT_AVATAR)
                with st.chat_message(msg.role, avatar=avatar_icon):
                    st.markdown(msg.content)
    except Exception as e:
        # Match the helpers' style: log with traceback and show a user-facing notice
        # instead of crashing the page.
        app_logger.error(
            f"Error loading chat messages for display (session {chat_session_id}): {e}",
            exc_info=True,
        )
        st.error("Could not load the conversation history. Please try refreshing the page.")
# Chat input from user
if user_prompt := st.chat_input("Ask the AI... (e.g., 'What is hypertension?')"):
    # Show the user's message immediately inside the scrollable transcript container.
    with chat_display_container:
        with st.chat_message("user", avatar="\U0001F464"):  # bust in silhouette emoji
            st.markdown(user_prompt)
    save_chat_message_to_db(chat_session_id, "user", user_prompt)

    # Snapshot the history BEFORE appending the new turn. The agent prompt receives the
    # new message via "input"; if it were also in "chat_history" it would be sent twice.
    prior_chat_history = list(st.session_state[agent_history_key])
    st.session_state[agent_history_key].append(HumanMessage(content=user_prompt))

    with chat_display_container:
        with st.chat_message("assistant", avatar="\U0001F9D1\u200D\u2695\uFE0F"):  # health worker emoji
            thinking_message = st.empty()  # placeholder replaced by the answer or an error
            thinking_message.markdown("\u258C")  # left half block, used as a typing cursor
            try:
                # Patient context string for the agent prompt.
                patient_context_dict = st.session_state.get('current_consult_patient_context_dict', {})
                if patient_context_dict:
                    patient_context_str_for_invoke = "; ".join(
                        f"{k}: {v}" for k, v in patient_context_dict.items()
                    )
                else:
                    patient_context_str_for_invoke = "No specific patient context was provided by the user for this interaction."

                # Keys expected by the OpenAI Functions Agent prompt.
                invoke_payload = {
                    "input": user_prompt,
                    "chat_history": prior_chat_history,  # list of BaseMessage, excluding this turn
                    "patient_context": patient_context_str_for_invoke,
                }
                app_logger.debug(f"Invoking OpenAI agent with payload: {invoke_payload}")
                thinking_message.markdown("AI is thinking...")
                response = agent_executor.invoke(invoke_payload)

                # Normalize the output: a missing/None/blank output falls back to a default
                # message instead of rendering "None" or an empty bubble.
                raw_output = response.get('output')
                ai_response_content = "" if raw_output is None else str(raw_output)
                if not ai_response_content.strip():
                    ai_response_content = "I could not generate a valid response at this time."
                app_logger.info(f"OpenAI Agent response for session {chat_session_id}: '{ai_response_content[:100]}...'")

                thinking_message.empty()
                st.markdown(ai_response_content)
                save_chat_message_to_db(chat_session_id, "assistant", ai_response_content)
                st.session_state[agent_history_key].append(AIMessage(content=ai_response_content))
            except Exception as e:
                app_logger.error(f"Error during OpenAI agent invocation for session {chat_session_id}: {e}", exc_info=True)
                error_type_name = type(e).__name__
                user_friendly_error = f"Sorry, an error occurred ({error_type_name}). Please try rephrasing your query or contact support if the issue persists."
                thinking_message.empty()
                st.error(user_friendly_error)  # rendered inside the assistant's bubble
                db_error_msg = f"System encountered an error: {error_type_name}. Details logged."
                save_chat_message_to_db(chat_session_id, "assistant", db_error_msg)
                st.session_state[agent_history_key].append(AIMessage(content=f"Note to self: Encountered error ({error_type_name})."))
    st.rerun()  # rerun so the new messages render from the DB at the bottom of the container