# MedQA / pages/2_Consult.py — copied from the Hugging Face file viewer.
# Last commit 861d237 ("Update pages/2_Consult.py", verified) by mgbam; file size 11 kB.
# /home/user/app/pages/2_Consult.py
from datetime import datetime
from typing import List, Optional, Dict, Any

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from sqlmodel import select

from config.settings import settings
from agent import get_agent_executor # Assumes this is your LangChain agent
from models import ChatMessage, ChatSession
from models.db import get_session_context
from services.logger import app_logger
from services.metrics import log_consultation_start
# --- Authentication Check ---
# Visitors without a logged-in user id are sent back to the login page. If that
# redirect is unavailable, they are told where to go instead. Either way this run halts.
authenticated_user_id = st.session_state.get("authenticated_user_id")
if not authenticated_user_id:
    st.warning("Please log in to access the consultation page.")
    try:
        st.switch_page("app.py")
    except st.errors.StreamlitAPIException:
        st.info("Please navigate to the main login page.")
    st.stop()

authenticated_username = st.session_state.get("authenticated_username", "User")
app_logger.info(f"User '{authenticated_username}' (ID: {authenticated_user_id}) accessed Consult page.")
# --- Initialize Agent ---
# The consult page cannot work without an agent. Any construction error is shown
# in the UI and logged with its traceback, and the script run is halted.
try:
    agent_executor = get_agent_executor()
except Exception as agent_init_error:
    st.error(
        f"Fatal Error: Could not initialize AI Agent: {agent_init_error}. "
        "Please check API keys and configurations."
    )
    app_logger.critical(f"AI Agent initialization failed: {agent_init_error}", exc_info=True)
    st.stop()
# --- Session State for Consult Page ---
# Seed defaults only when a key is absent, so values stored by earlier reruns survive.
for _state_key, _state_default in (
    ("current_consult_patient_context", {}),  # structured context for the current consult
    ("consult_context_submitted", False),     # True once the context form has been submitted
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Helper Functions ---
@st.cache_data(ttl=30, show_spinner=False)  # Short cache for agent history
def load_chat_history_for_agent(session_id: int) -> List:
    """Rebuild the LangChain message list for ``session_id`` from stored ChatMessages.

    Rows are read in timestamp order and mapped by role:
    "system" -> SystemMessage (the saved patient-context summary),
    "user" -> HumanMessage, "assistant" -> AIMessage.
    Any other role (e.g. "tool") is skipped.

    The result is cached by ``st.cache_data`` for 30 seconds per ``session_id``.
    A repeated call inside that window may miss messages saved in the meantime.
    """
    app_logger.debug(f"Loading agent chat history for session_id: {session_id}")
    # "system" rows are included so the persisted patient context is not lost when
    # the agent history is rebuilt from the DB. This matches the in-session history,
    # where the context form inserts it as a SystemMessage.
    role_to_message_cls = {
        "system": SystemMessage,
        "user": HumanMessage,
        "assistant": AIMessage,
    }
    messages = []
    with get_session_context() as db:
        statement = select(ChatMessage).where(ChatMessage.session_id == session_id).order_by(ChatMessage.timestamp)
        # Build the messages while the DB session is still open.
        for msg in db.exec(statement).all():
            message_cls = role_to_message_cls.get(msg.role)
            if message_cls is not None:
                messages.append(message_cls(content=msg.content))
    return messages
def save_chat_message_to_db(session_id: int, role: str, content: str, tool_call_id: Optional[str]=None, tool_name: Optional[str]=None):
    """Insert one ChatMessage row for ``session_id``.

    The row is stamped with ``datetime.utcnow()`` (naive UTC).
    ``tool_call_id`` and ``tool_name`` are stored as given (``None`` by default).
    Per the original note, committing is left to ``get_session_context``;
    presumably it commits on normal exit (confirm in models.db).
    """
    app_logger.debug(f"Saving message to DB for session {session_id}: Role={role}")
    with get_session_context() as db:
        row_fields = {
            "session_id": session_id,
            "role": role,
            "content": content,
            "timestamp": datetime.utcnow(),
            "tool_call_id": tool_call_id,
            "tool_name": tool_name,
        }
        db.add(ChatMessage(**row_fields))
    app_logger.info(f"Message saved to DB for session {session_id}. Role: {role}.")
def update_chat_session_with_context(session_id: int, context_summary: str):
    """Store ``context_summary`` on the ChatSession with id ``session_id``.

    If no such session exists, an error is logged and nothing is changed.
    The change is only staged with ``db.add``; committing is left to
    ``get_session_context`` (assumed to commit on exit — confirm in models.db).
    """
    with get_session_context() as db:
        chat_session = db.get(ChatSession, session_id)
        if not chat_session:
            app_logger.error(f"Could not find ChatSession {session_id} to update with context.")
            return
        chat_session.patient_context_summary = context_summary
        db.add(chat_session)  # Stage for commit
        app_logger.info(f"Updated ChatSession {session_id} with patient context summary.")
# --- Page Logic ---
st.title("AI Consultation Room")
st.markdown(f"Interacting as: **{authenticated_username}**")
st.info(settings.MAIN_DISCLAIMER_SHORT + " Do not enter real PHI.")

# The active chat session id is expected to be placed in session state elsewhere
# (presumably by app.py at login — confirm). This page never creates one itself,
# so the user message must not claim that a new session was created.
chat_session_id = st.session_state.get("current_chat_session_id")
if not chat_session_id:
    st.error("No active chat session. This may occur if you logged out and back in. Please re-login fully, or contact support if the issue persists.")
    app_logger.error(f"User '{authenticated_username}' on Consult page with no current_chat_session_id.")
    # Stopping is safer than guessing here, since session creation belongs to the login flow.
    st.stop()
# --- Patient Context Input Form ---
# Until context has been submitted, the page shows only this form and then halts.
# There is no separate skip button. Submitting with every field blank acts as a
# skip, and the agent then receives "No specific patient context provided."
# NOTE(review): consult_context_submitted is not keyed by chat session, so the form
# is not offered again if current_chat_session_id changes within the same browser session.
if not st.session_state.consult_context_submitted:
    st.subheader("Optional: Provide Patient Context (Simulated Data Only)")
    with st.form(key="patient_context_form"):
        st.markdown("**Reminder: Use only anonymized, simulated data for this demonstration.**")
        # The default age is 0, so an age of 0 is treated as "Not Specified" below.
        age = st.number_input("Patient Age (Simulated)", min_value=0, max_value=120, step=1)
        gender = st.selectbox("Patient Gender (Simulated)", ["Not Specified", "Male", "Female", "Other"])
        chief_complaint = st.text_area("Chief Complaint / Reason for Consult (Simulated)", height=100)
        key_history = st.text_area("Key Medical History (Simulated - e.g., diabetes, hypertension)", height=100)
        current_meds = st.text_area("Current Medications (Simulated - e.g., metformin, lisinopril)", height=100)
        submit_context_button = st.form_submit_button("Start Consult with this Context")

    if submit_context_button:
        # Blank entries are normalised to "Not Specified", so no value is ever "".
        context = {
            "age": age if age > 0 else "Not Specified",
            "gender": gender,
            "chief_complaint": chief_complaint.strip() or "Not Specified",
            "key_medical_history": key_history.strip() or "Not Specified",
            "current_medications": current_meds.strip() or "Not Specified",
        }

        # One summary string is used both for the agent history and the DB record.
        context_summary_parts = [
            f"{k.replace('_', ' ').title()}: {v}" for k, v in context.items() if v != "Not Specified"
        ]
        context_summary_for_agent = (
            "Patient Context: " + "; ".join(context_summary_parts)
            if context_summary_parts
            else "No specific patient context provided."
        )

        # Persist first. If a DB write fails, the page is not left flagged as
        # "context submitted" with nothing recorded.
        update_chat_session_with_context(chat_session_id, context_summary_for_agent)
        save_chat_message_to_db(chat_session_id, "system", context_summary_for_agent)

        # Put the context at the front of the in-memory agent history as a system
        # message, so it guides every subsequent turn.
        agent_history_key = f"agent_chat_history_{chat_session_id}"
        if agent_history_key not in st.session_state:
            st.session_state[agent_history_key] = []
        st.session_state[agent_history_key].insert(0, SystemMessage(content=context_summary_for_agent))

        st.session_state.current_consult_patient_context = context
        st.session_state.consult_context_submitted = True
        app_logger.info(f"Patient context submitted for session {chat_session_id}: {context_summary_for_agent}")
        st.rerun()  # Rerun to hide the form and show the chat
    st.stop()  # Don't proceed to the chat until the form has been submitted
# --- Chat Interface (Shown after the context form has been submitted) ---
agent_history_key = f"agent_chat_history_{chat_session_id}"
if agent_history_key not in st.session_state:
    st.session_state[agent_history_key] = load_chat_history_for_agent(chat_session_id)

# After the context form runs, the history holds a SystemMessage with the patient
# context, so a plain emptiness check would never fire. That would skip both the
# greeting and the consultation-start metric. Only real conversation turns
# (Human/AI messages) count.
has_conversation = any(
    isinstance(message, (HumanMessage, AIMessage))
    for message in st.session_state[agent_history_key]
)
if not has_conversation:
    try:
        log_consultation_start(user_id=authenticated_user_id, session_id=chat_session_id)
    except Exception as e:
        # Metrics are best-effort and must never block the consult.
        app_logger.warning(f"Failed to log consultation start: {e}")
    initial_ai_message_content = "Hello! I am your AI Health Navigator. How can I assist you today?"
    st.session_state[agent_history_key].append(AIMessage(content=initial_ai_message_content))
    save_chat_message_to_db(chat_session_id, "assistant", initial_ai_message_content)
# Display chat messages from DB for UI
with st.container():
    # Copy role/content out while the DB session is open, so rendering never reads
    # ORM objects after the session context has exited.
    with get_session_context() as db:
        stmt = select(ChatMessage).where(ChatMessage.session_id == chat_session_id).order_by(ChatMessage.timestamp)
        ui_messages = [(msg.role, msg.content) for msg in db.exec(stmt).all()]

    for role, content in ui_messages:
        if role == "system":  # Patient-context records guide the agent; don't show them in the chat UI
            continue
        # Avatars are written as escape sequences. The previous literals were UTF-8
        # emoji mis-decoded into mojibake ("πŸ..."), which are not valid emoji avatars.
        if role == "assistant":
            avatar = "\U0001F9D1\u200D\u2695\uFE0F"  # health worker emoji
        elif role == "tool":
            avatar = "\U0001F6E0\uFE0F"  # hammer-and-wrench emoji
        else:
            avatar = "\U0001F464"  # bust-in-silhouette emoji
        with st.chat_message(role, avatar=avatar):
            st.markdown(content)  # Add source/confidence here if the stored message supports it
# --- Handle a new user prompt ---
if prompt := st.chat_input("Ask the AI..."):
    # Avatars use escape sequences in place of the previous mojibake literals.
    with st.chat_message("user", avatar="\U0001F464"):  # bust-in-silhouette emoji
        st.markdown(prompt)
    save_chat_message_to_db(chat_session_id, "user", prompt)

    # The agent receives the new prompt through "input". Snapshot the history
    # *before* recording the prompt, so it is not sent a second time inside
    # "chat_history".
    prior_history = list(st.session_state[agent_history_key])
    st.session_state[agent_history_key].append(HumanMessage(content=prompt))

    with st.chat_message("assistant", avatar="\U0001F9D1\u200D\u2695\uFE0F"):  # health worker emoji
        with st.spinner("AI is thinking..."):
            try:
                # NOTE: if the agent is designed to take patient context explicitly,
                # pass it here (e.g. built from
                # st.session_state.current_consult_patient_context) under the key the agent expects.
                response = agent_executor.invoke({
                    "input": prompt,
                    "chat_history": prior_history,
                })
                ai_response_content = response.get('output', "I could not generate a response.")
            except Exception as e:
                app_logger.error(f"Error during agent invocation for session {chat_session_id}: {e}", exc_info=True)
                error_msg_user = f"Sorry, an error occurred: {type(e).__name__}. Please try again."
                st.error(error_msg_user)
                # Record the failure as the assistant turn, keeping user/assistant
                # alternation in both the DB and the agent history.
                error_content = f"Internal error: {type(e).__name__}"
                save_chat_message_to_db(chat_session_id, "assistant", error_content)
                st.session_state[agent_history_key].append(AIMessage(content=error_content))
            else:
                if not isinstance(ai_response_content, str):
                    ai_response_content = str(ai_response_content)
                st.markdown(ai_response_content)  # Display sources/confidence here if available
                save_chat_message_to_db(chat_session_id, "assistant", ai_response_content)
                st.session_state[agent_history_key].append(AIMessage(content=ai_response_content))