MedQA / pages /2_Consult.py
mgbam's picture
Update pages/2_Consult.py
71a1c43 verified
raw
history blame
9.24 kB
# /home/user/app/pages/2_Consult.py
import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
from datetime import datetime
from typing import List, Optional # Corrected import for List and Optional
from config.settings import settings
from agent import get_agent_executor
from models import ChatMessage, ChatSession, User # User might not be needed directly if ID is used
from models.db import get_session_context # Or from models import get_session_context
from services.logger import app_logger
from services.metrics import log_consultation_start # Assuming this function exists
# Page config typically in app.py
# st.set_page_config(page_title=f"Consult - {settings.APP_TITLE}", layout="wide")
# --- Authentication Check ---
# Visitors without an authenticated user ID are redirected to the main app
# page. If the redirect itself fails, explain what to do and halt this page.
if not st.session_state.get("authenticated_user_id"):
    st.warning("Please log in to access the consultation page.")
    try:
        st.switch_page("app.py")
    except st.errors.StreamlitAPIException as switch_error:
        if "st.switch_page can only be called when running in MPA mode" not in str(switch_error):
            # Any other switch_page failure is unexpected: log it as an error.
            app_logger.error(f"Consult: Error during st.switch_page: {switch_error}")
            st.error("Redirection error. Please go to the login page manually.")
        else:
            # Not running as a multipage app, so there is no page to switch to.
            app_logger.warning("Consult: Running in single-page mode or st.switch_page issue. Stopping script.")
            st.info("Please navigate to the main login page.")
    st.stop()

# From here on the user is authenticated; capture identity for UI and logs.
authenticated_user_id = st.session_state.get("authenticated_user_id")
authenticated_username = st.session_state.get("authenticated_username", "User")
app_logger.info(f"User {authenticated_username} (ID: {authenticated_user_id}) accessed Consult page.")
# --- Initialize Agent ---
# The page cannot serve any request without the agent. Any construction
# failure is therefore shown to the user, logged as critical (with
# traceback), and rendering stops.
try:
    agent_executor = get_agent_executor()
except Exception as init_error:
    if isinstance(init_error, ValueError):
        # ValueError is treated as a configuration problem (presumably e.g. a
        # missing API key — confirm against get_agent_executor).
        user_text = f"Could not initialize AI Agent: {init_error}"
        log_text = f"AI Agent initialization failed: {init_error}"
    else:
        user_text = f"An unexpected error occurred while initializing the AI Agent: {init_error}"
        log_text = f"Unexpected AI Agent initialization error: {init_error}"
    st.error(user_text)
    app_logger.critical(log_text, exc_info=True)
    st.stop()
# --- Helper Functions ---
def load_chat_history_for_agent(session_id: int) -> List:
    """Load a chat session's stored messages as LangChain message objects.

    ``ChatMessage`` rows for ``session_id`` are read in timestamp order and
    mapped by role:

    * ``"user"`` -> ``HumanMessage``
    * ``"assistant"`` -> ``AIMessage``
    * ``"tool"`` -> ``ToolMessage``, only when the row has a truthy
      ``tool_call_id`` attribute

    Rows with any other role are skipped.

    This function is intentionally not wrapped in ``st.cache_data``. That cache
    is shared across browser sessions and would return a snapshot up to its
    TTL old. For example, a newly opened tab could receive ``[]`` for a session
    that already has messages, and the page would then seed a duplicate
    greeting. The page already keeps the result in ``st.session_state``, so a
    shared cache saved no database work.

    Args:
        session_id: ID of the chat session whose messages are loaded.

    Returns:
        A new list of LangChain messages, oldest first.
    """
    app_logger.debug(f"Loading agent chat history for session_id: {session_id}")
    messages: List = []
    with get_session_context() as db:
        db_messages = (
            db.query(ChatMessage)
            .filter(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.timestamp)
            .all()
        )
        # Convert while the DB session is still open, so ORM attributes are
        # read before the session context exits.
        for msg in db_messages:
            if msg.role == "user":
                messages.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                messages.append(AIMessage(content=msg.content))
            elif msg.role == "tool":
                # Tool rows without a call ID cannot form a valid ToolMessage.
                # NOTE(review): the AIMessages rebuilt here carry no tool_calls.
                # Confirm the agent accepts tool messages in that position if
                # tool rows are ever persisted.
                tool_call_id = getattr(msg, "tool_call_id", None)
                if tool_call_id:
                    messages.append(ToolMessage(content=msg.content, tool_call_id=str(tool_call_id)))
    app_logger.debug(f"Loaded {len(messages)} messages for agent history for session {session_id}.")
    return messages
def save_chat_message_to_db(session_id: int, role: str, content: str, tool_call_id: Optional[str] = None, tool_name: Optional[str] = None) -> None:
    """Persist one chat message for a session and commit it.

    The message is stamped with the current time as a naive UTC ``datetime``.

    Args:
        session_id: ID of the chat session the message belongs to.
        role: Value stored in ``ChatMessage.role`` (this page writes
            ``"user"`` and ``"assistant"``).
        content: Message text; its first 50 characters appear in the debug log.
        tool_call_id: Optional tool-call identifier, stored as given.
        tool_name: Optional tool name, stored as given.
    """
    # Local import so the file's top-level import block stays untouched.
    from datetime import timezone

    app_logger.debug(f"Saving message to DB for session {session_id}: Role={role}, Content snippet='{content[:50]}...'")
    with get_session_context() as db:
        chat_message = ChatMessage(
            session_id=session_id,
            role=role,
            content=content,
            # Produces the same naive-UTC value as datetime.utcnow(). utcnow()
            # is deprecated since Python 3.12 and emits a DeprecationWarning.
            timestamp=datetime.now(timezone.utc).replace(tzinfo=None),
            tool_call_id=tool_call_id,
            tool_name=tool_name,
        )
        db.add(chat_message)
        db.commit()
    app_logger.info(f"Message saved to DB for session {session_id}. Role: {role}.")
# --- Page Logic ---
st.title("AI Consultation Room")
st.markdown(f"Interacting as: **{authenticated_username}**")

# Everything below reads and writes messages against the current chat
# session, so its ID must already be present in session state.
chat_session_id = st.session_state.get("current_chat_session_id")
if not chat_session_id:
    st.error("No active chat session ID found in session state. This might happen if you logged in before this feature was fully active. Please try logging out and logging back in.")
    app_logger.error(f"User {authenticated_username} (ID: {authenticated_user_id}) on Consult page with no current_chat_session_id.")
    st.stop()

# The agent's LangChain-format history lives in session state under a key
# that embeds the chat session ID, so each chat session gets its own entry.
agent_history_key = f"agent_chat_history_{chat_session_id}"
history_needs_loading = agent_history_key not in st.session_state
if history_needs_loading:
    st.session_state[agent_history_key] = load_chat_history_for_agent(chat_session_id)

# A freshly loaded, empty history means a brand-new consultation: record the
# start (best effort) and seed it with a greeting in both the agent's
# history and the database.
if history_needs_loading and not st.session_state[agent_history_key]:
    try:
        log_consultation_start(user_id=authenticated_user_id, session_id=chat_session_id)
    except Exception as metrics_error:
        # Metrics must never block the consultation; just record the failure.
        app_logger.warning(f"Failed to log consultation start: {metrics_error}")
    greeting_text = "Hello! I am your AI Health Navigator. How can I assist you today?"
    st.session_state[agent_history_key].append(AIMessage(content=greeting_text))
    save_chat_message_to_db(chat_session_id, "assistant", greeting_text)
    app_logger.info(f"Initialized new consultation for session {chat_session_id} with a greeting.")
# Render the transcript from the database on every rerun, so the UI always
# reflects what has actually been persisted.
#
# Avatars are written as Unicode escapes so they survive re-encoding of this
# file. The previous literals were mojibake: multi-character, non-emoji
# strings that st.chat_message would try, and fail, to load as images.
with st.container():  # Groups the chat display area.
    with get_session_context() as db:
        ui_messages = (
            db.query(ChatMessage)
            .filter(ChatMessage.session_id == chat_session_id)
            .order_by(ChatMessage.timestamp)
            .all()
        )
        for msg in ui_messages:
            if msg.role == "assistant":
                avatar = "\U0001F9D1\u200D\u2695\uFE0F"  # health worker
            elif msg.role == "tool":
                avatar = "\U0001F6E0\uFE0F"  # hammer and wrench
            else:
                avatar = "\U0001F464"  # bust in silhouette
            with st.chat_message(msg.role, avatar=avatar):
                st.markdown(msg.content)
# Chat input: handle one new user turn per rerun.
if prompt := st.chat_input("Ask the AI... (e.g., 'What is hypertension?' or 'Suggest diagnostic tests for chest pain')"):
    # Avatars are written as Unicode escapes. The previous literals were
    # mojibake that st.chat_message cannot interpret as emoji.
    user_avatar = "\U0001F464"  # bust in silhouette
    assistant_avatar = "\U0001F9D1\u200D\u2695\uFE0F"  # health worker

    # Optimistic UI update: echo the user's message before the agent answers.
    with st.chat_message("user", avatar=user_avatar):
        st.markdown(prompt)

    # Persist the user's turn first, so it survives an agent failure.
    save_chat_message_to_db(chat_session_id, "user", prompt)

    # The new prompt is passed to the agent as "input". "chat_history" must
    # therefore hold only the earlier turns; otherwise an agent prompt of the
    # form [chat_history placeholder, {input} human turn] sees the message
    # twice (assumed prompt layout; confirm in agent.py). Snapshot first,
    # then record the turn in the session's history.
    prior_history = list(st.session_state[agent_history_key])
    st.session_state[agent_history_key].append(HumanMessage(content=prompt))

    # Get the AI response inside the assistant's chat bubble.
    with st.chat_message("assistant", avatar=assistant_avatar):
        with st.spinner("AI is thinking..."):
            try:
                response = agent_executor.invoke({
                    "input": prompt,
                    "chat_history": prior_history,
                })
                raw_output = response.get("output")
                # A missing, None or empty output gets the placeholder,
                # instead of rendering and saving "None" or an empty bubble.
                if raw_output is None or raw_output == "":
                    ai_response_content = "No output from AI."
                else:
                    ai_response_content = str(raw_output)
                st.markdown(ai_response_content)
                save_chat_message_to_db(chat_session_id, "assistant", ai_response_content)
                st.session_state[agent_history_key].append(AIMessage(content=ai_response_content))
            except Exception as e:
                app_logger.error(f"Error during agent invocation for session {chat_session_id}: {e}", exc_info=True)
                error_message_user = f"Sorry, I encountered an error and could not process your request. Please try again or rephrase. (Error: {type(e).__name__})"
                st.error(error_message_user)  # Shown inside the assistant's bubble.
                # Persist a generic marker for the assistant's turn, so the
                # transcript shows that this turn failed.
                save_chat_message_to_db(chat_session_id, "assistant", f"Error processing request: {type(e).__name__}")
                # Record the failure in the agent's history, so later turns can see it.
                st.session_state[agent_history_key].append(AIMessage(content=f"Observed internal error: {type(e).__name__}"))
    # No st.rerun() here: both messages of this turn are already on screen,
    # and the next rerun redraws the transcript from the database.