|
|
|
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from sqlmodel import select

from config.settings import settings
from agent import get_agent_executor
from models import ChatMessage, ChatSession
from models.db import get_session_context
from services.logger import app_logger
from services.metrics import log_consultation_start
|
|
|
|
|
# Login gate: without an authenticated user ID in session state, warn the
# visitor, try to send them to the login page, and end this script run.
if not st.session_state.get("authenticated_user_id"):
    st.warning("Please log in to access the consultation page.")
    try:
        st.switch_page("app.py")
    except st.errors.StreamlitAPIException:
        # Redirect was not possible (presumably "app.py" could not be
        # resolved as a page); fall back to a textual hint.
        st.info("Please navigate to the main login page.")
    # Halt regardless so nothing below renders for an anonymous visitor.
    st.stop()
|
|
|
# Identity of the logged-in user. The login gate earlier on this page stops
# the run when no user ID is present; the display name defaults to "User".
authenticated_user_id = st.session_state.get("authenticated_user_id")
authenticated_username = st.session_state.get("authenticated_username", "User")
app_logger.info(
    "User '{}' (ID: {}) accessed Consult page.".format(
        authenticated_username, authenticated_user_id
    )
)
|
|
|
|
|
# Build the agent executor for this run. A ValueError is reported as a
# configuration problem (API key); any other failure is reported generically.
# Either way the page stops, since nothing below works without the agent.
try:
    agent_executor = get_agent_executor()
    app_logger.info("OpenAI-based agent executor initialized successfully for Consult page.")
except ValueError as init_error:
    st.error(f"AI Agent Initialization Error: {init_error}")
    app_logger.critical(
        f"Fatal: AI Agent could not be initialized in Consult page: {init_error}",
        exc_info=True,
    )
    st.info(
        "Please ensure the OPENAI_API_KEY is correctly configured in the "
        "application settings (Hugging Face Secrets)."
    )
    st.stop()
except Exception as init_error:
    st.error(f"An unexpected error occurred while initializing the AI Agent: {init_error}")
    app_logger.critical(
        f"Fatal: Unexpected AI Agent initialization error: {init_error}",
        exc_info=True,
    )
    st.stop()
|
|
|
|
|
# Per-browser-session defaults for the consult workflow: the filtered patient
# context (empty until Step 1 is submitted) and whether Step 1 is done.
# setdefault only assigns when the key is absent, so existing values survive reruns.
st.session_state.setdefault("current_consult_patient_context_dict", {})
st.session_state.setdefault("consult_context_submitted", False)
|
|
|
|
|
|
|
@st.cache_data(ttl=30, show_spinner=False, max_entries=10)
def load_chat_history_for_agent(session_id: int) -> List[Any]:
    """Load a chat session's stored messages as LangChain message objects.

    Rows of ``ChatMessage`` for ``session_id`` are read in timestamp order and
    converted by role: ``"user"`` -> ``HumanMessage``, ``"assistant"`` ->
    ``AIMessage``, ``"system"`` -> ``SystemMessage``. Rows with any other role
    (e.g. ``"tool"``) are skipped.

    Args:
        session_id: ID of the chat session whose history should be loaded.

    Returns:
        The converted messages. Database errors are logged and shown as a
        toast rather than raised; in that case the messages converted before
        the failure (typically none) are returned.

    NOTE(review): ``st.cache_data`` caches the return value for 30 seconds
    per ``session_id``. That includes the empty/partial list produced by a
    failed load.
    """
    role_to_message_cls = {
        "user": HumanMessage,
        "assistant": AIMessage,
        "system": SystemMessage,
    }
    messages = []
    app_logger.debug(f"Loading agent chat history from DB for session_id: {session_id}")
    try:
        with get_session_context() as db:
            statement = (
                select(ChatMessage)
                .where(ChatMessage.session_id == session_id)
                .order_by(ChatMessage.timestamp)
            )
            for msg in db.exec(statement).all():
                message_cls = role_to_message_cls.get(msg.role)
                if message_cls is None:
                    # Roles without a mapping (such as "tool") are not replayed
                    # to the agent.
                    continue
                messages.append(message_cls(content=msg.content))
        app_logger.debug(f"Loaded {len(messages)} LangChain messages for agent history (session {session_id}).")
    except Exception as e:
        app_logger.error(f"Error loading chat history for session {session_id}: {e}", exc_info=True)
        # "\u26a0\ufe0f" is the warning-sign emoji; written as escapes because
        # st.toast rejects icons that are not valid emoji, and the previous
        # literal was mis-encoded.
        st.toast(f"Error loading history: {e}", icon="\u26a0\ufe0f")
    return messages
|
|
|
def save_chat_message_to_db(session_id: int, role: str, content: str, tool_call_id: Optional[str] = None, tool_name: Optional[str] = None):
    """Persist one chat message for a session.

    Args:
        session_id: ID of the chat session the message belongs to.
        role: Message role as stored in the DB (this page writes "user",
            "assistant" and "system").
        content: Message text.
        tool_call_id: Optional tool-call identifier stored alongside the message.
        tool_name: Optional tool name stored alongside the message.

    Errors are logged and shown as a toast; they are not raised to the caller.
    """
    app_logger.debug(f"Saving message to DB for session {session_id}: Role='{role}', Content snippet='{content[:50]}...'")
    try:
        with get_session_context() as db:
            chat_message_obj = ChatMessage(
                session_id=session_id,
                role=role,
                content=content,
                # Naive UTC timestamp: the same value datetime.utcnow() gave,
                # without the deprecated call, and still comparable with
                # existing naive rows.
                timestamp=datetime.now(timezone.utc).replace(tzinfo=None),
                tool_call_id=tool_call_id,
                tool_name=tool_name,
            )
            db.add(chat_message_obj)
        # Logged only after the session context exits without raising.
        # NOTE(review): assumes get_session_context() commits on a clean exit;
        # confirm in models.db.
        app_logger.info(f"Message (Role: {role}) saved to DB for session {session_id}.")
    except Exception as e:
        app_logger.error(f"Error saving chat message to DB for session {session_id}: {e}", exc_info=True)
        # "\u26a0\ufe0f" is the warning-sign emoji (the previous literal was
        # mis-encoded and is rejected by st.toast).
        st.toast(f"Error saving message: {e}", icon="\u26a0\ufe0f")
|
|
|
def update_chat_session_with_context_summary_in_db(session_id: int, context_summary: str):
    """Store a patient-context summary on an existing chat session.

    Args:
        session_id: ID of the ``ChatSession`` row to update.
        context_summary: Text written to ``patient_context_summary``.

    A missing session is logged as an error. Exceptions from the database work
    are logged and shown as a toast rather than raised.
    """
    try:
        with get_session_context() as db:
            session_to_update = db.get(ChatSession, session_id)
            if session_to_update is None:
                app_logger.error(f"Could not find ChatSession {session_id} in DB to update context summary.")
                return
            session_to_update.patient_context_summary = context_summary
            db.add(session_to_update)
        # Logged only after the session context exits without raising.
        # NOTE(review): assumes get_session_context() commits on a clean exit.
        app_logger.info(f"Updated ChatSession {session_id} with patient context summary in DB.")
    except Exception as e:
        app_logger.error(f"Error updating chat session {session_id} context summary: {e}", exc_info=True)
        # "\u26a0\ufe0f" is the warning-sign emoji (the previous literal was
        # mis-encoded and is rejected by st.toast).
        st.toast(f"Error saving context: {e}", icon="\u26a0\ufe0f")
|
|
|
|
|
# Page header, identity line, and the mandatory disclaimer banner.
st.title("AI Consultation Room")
st.markdown(f"Interacting as: **{authenticated_username}**")
disclaimer_text = f"{settings.MAIN_DISCLAIMER_LONG} {settings.SIMULATION_DISCLAIMER}"
st.warning(f"**Reminder & Disclaimer:** {disclaimer_text}")

# Every message on this page is stored under the current chat session; without
# one there is nothing to attach messages to, so the page stops.
chat_session_id = st.session_state.get("current_chat_session_id")
if not chat_session_id:
    st.error("Error: No active chat session ID found. Please try logging out and back in.")
    app_logger.critical(
        f"User '{authenticated_username}' (ID: {authenticated_user_id}) "
        "on Consult page encountered MISSING current_chat_session_id."
    )
    st.stop()
|
|
|
|
|
def _filter_patient_context(raw_context: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the patient-context fields the user actually filled in.

    Dropped values:
    - ``None`` (e.g. a blank age input);
    - blank or whitespace-only strings;
    - the ``"Not Specified"`` placeholder (compared case-insensitively);
    - numbers that are not positive.

    Everything else is kept unchanged. The previous inline filter required
    every value to be a string, which silently discarded the integer age.
    """
    filtered = {}
    for key, value in raw_context.items():
        if value is None:
            continue
        if isinstance(value, str):
            stripped = value.strip()
            if not stripped or stripped.lower() == "not specified":
                continue
        elif isinstance(value, (int, float)) and value <= 0:
            continue
        filtered[key] = value
    return filtered


# Step 1: collect optional (simulated) patient context before chatting.
if not st.session_state.consult_context_submitted:
    st.subheader("Step 1: Provide Patient Context (Optional, Use Simulated Data Only)")

    with st.form(key="patient_context_form_consult_page_openai_v2"):
        st.markdown("**Crucial Reminder: Use only anonymized, simulated data. Do NOT enter real PHI.**")
        age_in = st.number_input("Patient Age (Simulated)", min_value=0, max_value=120, step=1, value=None, help="Leave blank if not applicable.")
        gender_in = st.selectbox("Patient Gender (Simulated)", ["Not Specified", "Male", "Female", "Other"], index=0)
        cc_in = st.text_area("Chief Complaint / Reason for Consult (Simulated)", height=100, placeholder="e.g., Persistent cough")
        hist_in = st.text_area("Key Medical History (Simulated)", height=100, placeholder="e.g., Type 2 Diabetes")
        meds_in = st.text_area("Current Medications (Simulated)", height=100, placeholder="e.g., Metformin")
        submit_context_btn = st.form_submit_button("Start Consult with this Context")

    if submit_context_btn:
        raw_context = {
            "Age": age_in,
            "Gender": gender_in,
            "Chief Complaint": cc_in,
            "Key Medical History": hist_in,
            "Current Medications": meds_in,
        }
        filtered_context_dict = _filter_patient_context(raw_context)
        st.session_state.current_consult_patient_context_dict = filtered_context_dict
        context_summary_str = (
            "; ".join(f"{k}: {v}" for k, v in filtered_context_dict.items())
            if filtered_context_dict
            else "No specific patient context was provided."
        )
        update_chat_session_with_context_summary_in_db(chat_session_id, context_summary_str)
        save_chat_message_to_db(chat_session_id, "system", f"Initial Patient Context Set: {context_summary_str}")
        st.session_state.consult_context_submitted = True
        app_logger.info(f"Patient context submitted for session {chat_session_id}: {context_summary_str}")
        st.rerun()

    # Step 2 is only rendered once the context form has been submitted.
    st.stop()
|
|
|
|
|
st.subheader("Step 2: Interact with AI Health Navigator")
# In-memory LangChain history for this chat session, kept across reruns.
agent_history_key = f"agent_chat_history_{chat_session_id}"

if agent_history_key not in st.session_state:
    st.session_state[agent_history_key] = load_chat_history_for_agent(chat_session_id)
    # Step 1 always stores an "Initial Patient Context Set" system message, so
    # the loaded history is never simply empty. A consultation counts as new
    # when it has no non-system (user/assistant) turns yet.
    has_conversation_turns = any(
        not isinstance(message, SystemMessage)
        for message in st.session_state[agent_history_key]
    )
    if not has_conversation_turns:
        try:
            log_consultation_start(user_id=authenticated_user_id, session_id=chat_session_id)
        except Exception as e_metric:
            # Metrics are best-effort and must not block the consultation.
            app_logger.warning(f"Failed log_consultation_start: {e_metric}")
        initial_ai_msg = "Hello! I am your AI Health Navigator. How can I assist you today?"
        if st.session_state.get('current_consult_patient_context_dict'):
            initial_ai_msg += " I have noted the patient context you provided."
        st.session_state[agent_history_key].append(AIMessage(content=initial_ai_msg))
        save_chat_message_to_db(chat_session_id, "assistant", initial_ai_msg)
|
|
|
chat_display_container = st.container(height=450)

# Avatar per stored role. Roles not listed (e.g. "user") get the person icon.
# Emoji are written as escapes because the previous literals were mis-encoded,
# and a non-emoji avatar string is treated as an image path by st.chat_message.
_AVATAR_BY_ROLE = {
    "assistant": "\U0001f9d1\u200d\u2695\ufe0f",  # health worker
    "tool": "\U0001f6e0\ufe0f",  # hammer and wrench
}
_DEFAULT_AVATAR = "\U0001f464"  # bust in silhouette

with chat_display_container:
    try:
        with get_session_context() as db:
            stmt = (
                select(ChatMessage)
                .where(ChatMessage.session_id == chat_session_id)
                .order_by(ChatMessage.timestamp)
            )
            # Copy out the plain fields while the DB session is open, so that
            # rendering never touches ORM instances after the context exits.
            # System messages (e.g. the patient-context note) are not shown.
            ui_messages = [
                (msg.role, msg.content)
                for msg in db.exec(stmt).all()
                if msg.role != "system"
            ]
    except Exception as e:
        app_logger.error(f"Error loading chat transcript for session {chat_session_id}: {e}", exc_info=True)
        st.error("Could not load the conversation history. Please try again later.")
        ui_messages = []

    for role, content in ui_messages:
        with st.chat_message(role, avatar=_AVATAR_BY_ROLE.get(role, _DEFAULT_AVATAR)):
            st.markdown(content)
|
|
|
# Handle a new user prompt: echo it, persist it, invoke the agent, then show,
# persist and record the reply. Finally rerun so the transcript is re-rendered
# from the database.
if user_prompt := st.chat_input("Ask the AI... (e.g., 'What is hypertension?')"):
    with chat_display_container:
        with st.chat_message("user", avatar="\U0001f464"):  # bust in silhouette
            st.markdown(user_prompt)
    save_chat_message_to_db(chat_session_id, "user", user_prompt)

    # Snapshot the earlier turns BEFORE appending the new prompt. The prompt is
    # already sent to the agent as "input", so putting it in "chat_history"
    # too would hand the agent the same question twice.
    # NOTE(review): assumes the agent prompt renders chat_history followed by
    # input (usual LangChain agent layout); confirm in the agent module.
    prior_chat_history = list(st.session_state[agent_history_key])
    st.session_state[agent_history_key].append(HumanMessage(content=user_prompt))

    with chat_display_container:
        with st.chat_message("assistant", avatar="\U0001f9d1\u200d\u2695\ufe0f"):  # health worker
            thinking_msg_placeholder = st.empty()
            thinking_msg_placeholder.markdown("\u231b AI is thinking...")  # hourglass
            try:
                patient_context_dict = st.session_state.get('current_consult_patient_context_dict', {})
                patient_context_str_for_invoke = (
                    "; ".join(f"{k}: {v}" for k, v in patient_context_dict.items())
                    if patient_context_dict
                    else "No specific patient context provided."
                )

                invoke_payload = {
                    "input": user_prompt,
                    "chat_history": prior_chat_history,
                    "patient_context": patient_context_str_for_invoke,
                }
                app_logger.debug(f"Invoking OpenAI agent with payload: {invoke_payload}")
                response = agent_executor.invoke(invoke_payload)
                # A missing, None or empty "output" falls back to a fixed message.
                ai_response_content = response.get('output') or "Could not generate a response."
                if not isinstance(ai_response_content, str):
                    ai_response_content = str(ai_response_content)

                app_logger.info(f"OpenAI Agent response for session {chat_session_id}: '{ai_response_content[:100]}...'")
                thinking_msg_placeholder.empty()
                st.markdown(ai_response_content)
                save_chat_message_to_db(chat_session_id, "assistant", ai_response_content)
                st.session_state[agent_history_key].append(AIMessage(content=ai_response_content))
            except Exception as e:
                app_logger.error(f"Error during OpenAI agent invocation for session {chat_session_id}: {e}", exc_info=True)
                error_type_name = type(e).__name__
                user_friendly_error = f"Sorry, an error occurred ({error_type_name}). Please try rephrasing or contact support."
                thinking_msg_placeholder.empty()
                st.error(user_friendly_error)
                # Persist only the exception type; full details are in the log.
                db_error_msg = f"System encountered an error: {error_type_name}. Details logged."
                save_chat_message_to_db(chat_session_id, "assistant", db_error_msg)
                st.session_state[agent_history_key].append(AIMessage(content=f"Note: Error ({error_type_name})."))

    st.rerun()