File size: 15,676 Bytes
45ea80d
2b9aa0c
45ea80d
2b9aa0c
45ea80d
861d237
2b9aa0c
 
45ea80d
861d237
352a295
2b9aa0c
45ea80d
2b9aa0c
861d237
71a1c43
2b9aa0c
7ae304b
45ea80d
7ae304b
45ea80d
7ae304b
45ea80d
71a1c43
45ea80d
 
861d237
2b9aa0c
45ea80d
 
2b9aa0c
743ac85
45ea80d
 
7ae304b
45ea80d
 
7ae304b
45ea80d
 
 
71a1c43
2b9aa0c
861d237
45ea80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b9aa0c
45ea80d
 
 
 
2b9aa0c
 
 
45ea80d
 
2b9aa0c
 
45ea80d
861d237
45ea80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743ac85
45ea80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
861d237
 
45ea80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71a1c43
45ea80d
 
 
 
 
 
 
 
 
 
 
 
 
7ae304b
45ea80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1aff0c6
45ea80d
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# /home/user/app/pages/2_Consult.py
import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from datetime import datetime
from typing import List, Optional, Dict, Any
from sqlmodel import select

from config.settings import settings
from agent import get_agent_executor # This now returns the OpenAI-based agent executor
from models import ChatMessage, ChatSession
from models.db import get_session_context
from services.logger import app_logger
from services.metrics import log_consultation_start # Assuming this function exists

# --- Authentication Check ---
# Read the login marker once. A visitor without it is sent back to the login
# page (app.py) and the remainder of this script is halted via st.stop().
authenticated_user_id = st.session_state.get("authenticated_user_id")
if not authenticated_user_id:
    st.warning("Please log in to access the consultation page.")
    try:
        st.switch_page("app.py")  # Redirect to the main login page (app.py)
    except st.errors.StreamlitAPIException:
        # switch_page can raise outside a multipage-app context (e.g. during
        # development); fall back to a plain instruction in that case.
        st.info("Please navigate to the main login page.")
    st.stop()  # Halt script execution

# Display name for this page; "User" when none was stored at login.
authenticated_username = st.session_state.get("authenticated_username", "User")
app_logger.info(f"User '{authenticated_username}' (ID: {authenticated_user_id}) accessed Consult page.")

# --- Initialize Agent ---
# get_agent_executor() (agent.py) builds the OpenAI-based agent executor.
# Any failure here is fatal for this page: report it, log it, and stop.
try:
    agent_executor = get_agent_executor()
except ValueError as init_err:
    # ValueError is treated as a configuration problem (e.g. a missing API key).
    st.error(f"AI Agent Initialization Error: {init_err}")
    app_logger.critical(f"Fatal: AI Agent could not be initialized in Consult page: {init_err}", exc_info=True)
    st.info("Please ensure the OPENAI_API_KEY is correctly configured in the application settings (Hugging Face Secrets).")
    st.stop()
except Exception as init_err:
    # Anything else raised while building the agent.
    st.error(f"An unexpected error occurred while initializing the AI Agent: {init_err}")
    app_logger.critical(f"Fatal: Unexpected AI Agent initialization error: {init_err}", exc_info=True)
    st.stop()
else:
    app_logger.info("OpenAI-based agent executor initialized successfully for Consult page.")


# --- Session State for Consult Page ---
# Per-browser-session defaults: the patient-context dict collected by the form
# below, and a flag recording whether that form has been submitted yet.
for _state_key, _state_default in (
    ("current_consult_patient_context_dict", {}),
    ("consult_context_submitted", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# --- Helper Functions ---
@st.cache_data(ttl=30, show_spinner=False, max_entries=10) # Cache agent history for current session
def load_chat_history_for_agent(session_id: int) -> List[Any]: # List of LangChain BaseMessage
    """Rebuild the LangChain message list for one chat session from the DB.

    ChatMessage rows for ``session_id`` are read in timestamp order and mapped
    by role: "user" -> HumanMessage, "assistant" -> AIMessage,
    "system" -> SystemMessage. Rows with any other role (e.g. tool rows) are
    skipped.

    On a database error the failure is logged, a toast is shown, and whatever
    was converted before the error is returned (possibly an empty list).
    Results are memoized by ``st.cache_data`` for 30 s (at most 10 entries).
    """
    message_type_for_role = {
        "user": HumanMessage,
        "assistant": AIMessage,
        "system": SystemMessage,
    }
    history: List[Any] = []
    app_logger.debug(f"Loading agent chat history from DB for session_id: {session_id}")
    try:
        with get_session_context() as db:
            query = (
                select(ChatMessage)
                .where(ChatMessage.session_id == session_id)
                .order_by(ChatMessage.timestamp)
            )
            for row in db.exec(query).all():
                message_type = message_type_for_role.get(row.role)
                if message_type is None:
                    # Not part of the chat_history handed to the agent.
                    continue
                history.append(message_type(content=row.content))
        app_logger.debug(f"Loaded {len(history)} LangChain messages for agent history (session {session_id}).")
    except Exception as err:
        app_logger.error(f"Error loading chat history for session {session_id}: {err}", exc_info=True)
        st.toast(f"Error loading chat history: {err}", icon="⚠️")  # Non-blocking error
    return history

def save_chat_message_to_db(session_id: int, role: str, content: str, tool_call_id: Optional[str] = None, tool_name: Optional[str] = None):
    """Insert one ChatMessage row belonging to ``session_id``.

    Args:
        session_id: Id of the owning chat session.
        role: Stored message role (this page uses "user", "assistant", "system").
        content: Message text.
        tool_call_id: Optional tool-call id stored with the row.
        tool_name: Optional tool name stored with the row.

    The row is timestamped with ``datetime.utcnow()`` (naive UTC). No explicit
    commit is issued here; ``get_session_context()`` is expected to commit on
    exit. Failures are logged and shown as a toast rather than raised.
    """
    snippet = content[:50]
    app_logger.debug(f"Saving message to DB for session {session_id}: Role='{role}', Content snippet='{snippet}...'")
    try:
        with get_session_context() as db:
            db.add(
                ChatMessage(
                    session_id=session_id,
                    role=role,
                    content=content,
                    timestamp=datetime.utcnow(),
                    tool_call_id=tool_call_id,
                    tool_name=tool_name,
                )
            )
        app_logger.info(f"Message (Role: {role}) saved to DB for session {session_id}.")
    except Exception as err:
        app_logger.error(f"Error saving chat message to DB for session {session_id}: {err}", exc_info=True)
        st.toast(f"Error saving message: {err}", icon="⚠️")


def update_chat_session_with_context_summary_in_db(session_id: int, context_summary: str):
    """Store ``context_summary`` on the ChatSession row with id ``session_id``.

    A missing session row is logged as an error, not raised. Database failures
    are logged and surfaced as a toast; nothing propagates to the caller.
    """
    try:
        with get_session_context() as db:
            chat_session = db.get(ChatSession, session_id)  # Primary-key lookup
            if not chat_session:
                app_logger.error(f"Could not find ChatSession {session_id} in DB to update context summary.")
                return
            chat_session.patient_context_summary = context_summary
            # Re-add the modified instance; persistence is presumably committed
            # by get_session_context() on exit.
            db.add(chat_session)
            app_logger.info(f"Updated ChatSession {session_id} with patient context summary in DB.")
    except Exception as err:
        app_logger.error(f"Error updating chat session {session_id} context summary: {err}", exc_info=True)
        st.toast(f"Error saving context summary: {err}", icon="⚠️")


# --- Page Logic ---
st.title("AI Consultation Room")
st.markdown(f"Interacting as: **{authenticated_username}**")
# Repeat the long-form and simulation disclaimers prominently on this page.
disclaimer_text = f"**Reminder & Disclaimer:** {settings.MAIN_DISCLAIMER_LONG} {settings.SIMULATION_DISCLAIMER}"
st.warning(disclaimer_text)

# Every message on this page is stored against this id, which is expected to
# have been placed in session state at login; without it the page cannot work.
chat_session_id = st.session_state.get("current_chat_session_id")
if not chat_session_id:
    missing_session_text = "Error: No active chat session ID found. This is unexpected after login. Please try logging out and logging back in. If the problem persists, please contact support."
    st.error(missing_session_text)
    app_logger.critical(f"User '{authenticated_username}' (ID: {authenticated_user_id}) on Consult page encountered MISSING current_chat_session_id.")
    st.stop()

# --- Patient Context Input Form ---
def _clean_patient_context(raw_context: Dict[str, Any]) -> Dict[str, Any]:
    """Return only the patient-context fields the user actually filled in.

    Age arrives from ``st.number_input`` as an int (or None when left blank),
    so numeric values must not be put through string-only checks.

    Rules:
      * ``None`` values are dropped.
      * Strings are stripped; blank strings and the "Not Specified"
        placeholder (case-insensitive) are dropped.
      * Numbers are kept only when > 0 (the existing "age > 0" rule).
      * Any other value is kept unchanged.
    """
    cleaned: Dict[str, Any] = {}
    for field_name, value in raw_context.items():
        if value is None:
            continue
        if isinstance(value, str):
            value = value.strip()
            if not value or value.lower() == "not specified":
                continue
        elif isinstance(value, (int, float)) and value <= 0:
            continue
        cleaned[field_name] = value
    return cleaned


if not st.session_state.consult_context_submitted:
    st.subheader("Step 1: Provide Patient Context (Optional, Use Simulated Data Only)")
    with st.form(key="patient_context_form_consult_page_openai"): # Unique key
        st.markdown("**Crucial Reminder: Use only anonymized, simulated data for this demonstration. Do NOT enter real PHI.**")
        age_in = st.number_input("Patient Age (Simulated)", min_value=0, max_value=120, step=1, value=None, help="Leave blank if not applicable.")
        gender_options = ["Not Specified", "Male", "Female", "Other"]
        gender_in = st.selectbox("Patient Gender (Simulated)", gender_options, index=0)
        cc_in = st.text_area("Chief Complaint / Reason for Consult (Simulated)", height=100, placeholder="e.g., Persistent cough for 2 weeks, fatigue")
        hist_in = st.text_area("Key Medical History (Simulated)", height=100, placeholder="e.g., Type 2 Diabetes (controlled), Hypertension (on medication), Asthma (mild intermittent)")
        meds_in = st.text_area("Current Medications (Simulated)", height=100, placeholder="e.g., Metformin 500mg BID, Lisinopril 10mg OD, Salbutamol inhaler PRN")
        submit_context_btn = st.form_submit_button("Start Consult with this Context")

        if submit_context_btn:
            raw_context = {
                "Age": age_in, "Gender": gender_in, "Chief Complaint": cc_in,
                "Key Medical History": hist_in, "Current Medications": meds_in,
            }
            # Drop blank / placeholder entries (Age is now kept when provided).
            filtered_context_dict = _clean_patient_context(raw_context)
            st.session_state.current_consult_patient_context_dict = filtered_context_dict

            if filtered_context_dict:
                context_summary_str = "; ".join(f"{k}: {v}" for k, v in filtered_context_dict.items())
            else:
                context_summary_str = "No specific patient context was provided for this session."

            update_chat_session_with_context_summary_in_db(chat_session_id, context_summary_str)
            # Save a system message to DB indicating context was set for this session
            save_chat_message_to_db(chat_session_id, "system", f"Initial Patient Context Set: {context_summary_str}")

            st.session_state.consult_context_submitted = True
            app_logger.info(f"Patient context submitted for session {chat_session_id}: {context_summary_str}")
            st.rerun() # Rerun to hide form and proceed to chat interface
    st.stop() # Don't proceed to chat interface until context form is handled

# --- Chat Interface (Shown after context is submitted) ---
st.subheader("Step 2: Interact with AI Health Navigator")
agent_history_key = f"agent_chat_history_{chat_session_id}" # Session-specific key for agent's message history

# Load the agent's chat history (list of LangChain messages) once per browser
# session; later reruns reuse the list kept in session state.
if agent_history_key not in st.session_state:
    st.session_state[agent_history_key] = load_chat_history_for_agent(chat_session_id)
    # Submitting the context form saves an "Initial Patient Context Set" system
    # message before this code runs, so the loaded history is normally non-empty
    # even for a brand-new consultation. Treat the consultation as new when it
    # has no user/assistant turns yet; an empty-list check would never fire and
    # the greeting and start metric would be skipped.
    has_conversation_turns = any(
        not isinstance(message, SystemMessage)
        for message in st.session_state[agent_history_key]
    )
    if not has_conversation_turns:
        try:
            log_consultation_start(user_id=authenticated_user_id, session_id=chat_session_id)
        except Exception as e_metric:
            # Metrics are best-effort; a failure must not block the consultation.
            app_logger.warning(f"Failed to log consultation start metric: {e_metric}")

        initial_ai_message_content = "Hello! I am your AI Health Navigator. How can I assist you today?"
        patient_context_for_greeting = st.session_state.get('current_consult_patient_context_dict', {})
        if patient_context_for_greeting: # Check if the dict itself is non-empty
            initial_ai_message_content += " I have noted the patient context you provided."

        # Add initial AI message to agent's history and save to DB
        st.session_state[agent_history_key].append(AIMessage(content=initial_ai_message_content))
        save_chat_message_to_db(chat_session_id, "assistant", initial_ai_message_content)
        app_logger.info(f"Initialized new consultation (session {chat_session_id}) with a greeting.")

# Display chat messages for UI, fetched fresh from the DB on every rerun so the
# transcript shows exactly what was persisted. Rendered in a fixed-height,
# scrollable container that the chat-input code below also writes into.
# Avatar emoji are written as escape sequences: the previous literals had been
# mis-decoded (UTF-8 bytes read through a Greek code page) and were garbage.
chat_display_container = st.container(height=450)
with chat_display_container:
    transcript = []
    try:
        with get_session_context() as db:
            stmt = select(ChatMessage).where(ChatMessage.session_id == chat_session_id).order_by(ChatMessage.timestamp)
            # Copy out plain (role, content) pairs while the DB session is open,
            # so rendering does not touch ORM objects after the session closes.
            transcript = [(msg.role, msg.content) for msg in db.exec(stmt).all()]
    except Exception as e_display:
        app_logger.error(f"Error loading chat messages for display (session {chat_session_id}): {e_display}", exc_info=True)
        st.error("Could not load the conversation history. Please refresh the page.")

    avatar_by_role = {
        "assistant": "\U0001F9D1\u200D\u2695\uFE0F",  # health worker
        "tool": "\U0001F6E0\uFE0F",  # hammer and wrench (if tool rows are ever logged)
    }
    user_avatar = "\U0001F440"  # eyes; used for every other role
    for role, content in transcript:
        if role == "system":
            continue  # System rows (e.g. the patient-context note) are not shown in the chat UI
        with st.chat_message(role, avatar=avatar_by_role.get(role, user_avatar)):
            st.markdown(content)

# Chat input from user
if user_prompt := st.chat_input("Ask the AI... (e.g., 'What is hypertension?')"):
    # Same list object as in session state, so appends below persist across reruns.
    agent_history = st.session_state[agent_history_key]
    # The agent receives the new prompt as "input". Snapshot the history *before*
    # recording this turn so chat_history does not repeat the same user message.
    prior_history = list(agent_history)

    # Display user message in UI immediately (inside the transcript container).
    with chat_display_container:
        with st.chat_message("user", avatar="\U0001F440"):  # eyes
            st.markdown(user_prompt)

    save_chat_message_to_db(chat_session_id, "user", user_prompt)
    agent_history.append(HumanMessage(content=user_prompt))

    # Get AI response and show it in the same container.
    with chat_display_container:
        with st.chat_message("assistant", avatar="\U0001F9D1\u200D\u2695\uFE0F"):  # health worker
            thinking_message = st.empty()  # Placeholder replaced by the answer or an error
            thinking_message.markdown("\u258C")  # Left-half-block cursor while the request starts

            try:
                # Prepare patient context string for the agent
                patient_context_dict = st.session_state.get('current_consult_patient_context_dict', {})
                if patient_context_dict:
                    patient_context_str_for_invoke = "; ".join(f"{k}: {v}" for k, v in patient_context_dict.items())
                else:  # No context provided, or every field was empty/default
                    patient_context_str_for_invoke = "No specific patient context was provided by the user for this interaction."

                # These are the keys expected by the OpenAI Functions Agent prompt
                invoke_payload = {
                    "input": user_prompt,
                    "chat_history": prior_history,  # Earlier turns only (list of BaseMessage)
                    "patient_context": patient_context_str_for_invoke,
                }
                app_logger.debug(f"Invoking OpenAI agent with payload: {invoke_payload}")

                thinking_message.markdown("AI is thinking...")

                response = agent_executor.invoke(invoke_payload)

                # Fall back when "output" is missing, None, or empty.
                ai_response_content = response.get('output') or "I could not generate a valid response at this time."
                if not isinstance(ai_response_content, str):
                    ai_response_content = str(ai_response_content)

                app_logger.info(f"OpenAI Agent response for session {chat_session_id}: '{ai_response_content[:100]}...'")
                thinking_message.empty()
                st.markdown(ai_response_content)

                save_chat_message_to_db(chat_session_id, "assistant", ai_response_content)
                agent_history.append(AIMessage(content=ai_response_content))

            except Exception as e:
                app_logger.error(f"Error during OpenAI agent invocation for session {chat_session_id}: {e}", exc_info=True)
                error_type_name = type(e).__name__
                user_friendly_error = f"Sorry, an error occurred ({error_type_name}). Please try rephrasing your query or contact support if the issue persists."
                thinking_message.empty()
                st.error(user_friendly_error)  # Display error in the AI's bubble

                db_error_msg = f"System encountered an error: {error_type_name}. Details logged."
                save_chat_message_to_db(chat_session_id, "assistant", db_error_msg)
                # Keep the in-memory history identical to what the DB now holds,
                # i.e. what load_chat_history_for_agent would rebuild after a reload.
                agent_history.append(AIMessage(content=db_error_msg))

    st.rerun() # Rerun to ensure the new messages are at the bottom of the scrollable container