File size: 13,384 Bytes
71a1c43
2b9aa0c
7ae304b
2b9aa0c
861d237
 
2b9aa0c
 
7ae304b
861d237
352a295
2b9aa0c
352a295
2b9aa0c
861d237
71a1c43
2b9aa0c
7ae304b
 
 
 
71a1c43
 
 
 
861d237
2b9aa0c
861d237
2b9aa0c
7ae304b
 
 
 
 
 
 
352a295
7ae304b
 
71a1c43
2b9aa0c
7ae304b
861d237
 
7ae304b
861d237
 
 
2b9aa0c
7ae304b
 
2b9aa0c
71a1c43
861d237
352a295
861d237
2b9aa0c
7ae304b
 
 
 
 
 
 
2b9aa0c
 
861d237
 
 
 
 
 
 
 
 
 
7ae304b
861d237
 
 
 
 
 
 
7ae304b
2b9aa0c
 
 
71a1c43
7ae304b
2b9aa0c
 
 
7ae304b
 
2b9aa0c
 
861d237
 
7ae304b
 
861d237
7ae304b
 
 
 
 
 
861d237
 
 
7ae304b
 
 
 
 
 
861d237
7ae304b
 
 
 
 
 
 
 
 
 
 
861d237
 
 
 
7ae304b
 
 
 
 
861d237
7ae304b
 
 
 
 
 
 
71a1c43
7ae304b
71a1c43
 
7ae304b
861d237
7ae304b
 
71a1c43
7ae304b
 
 
71a1c43
861d237
7ae304b
71a1c43
7ae304b
352a295
861d237
 
 
71a1c43
7ae304b
71a1c43
7ae304b
71a1c43
7ae304b
352a295
 
861d237
71a1c43
 
2b9aa0c
352a295
71a1c43
 
7ae304b
 
 
 
 
 
 
 
 
71a1c43
861d237
7ae304b
 
 
 
 
 
 
861d237
 
7ae304b
 
352a295
 
7ae304b
71a1c43
 
7ae304b
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# /home/user/app/pages/2_Consult.py
import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage # Ensure SystemMessage is imported
from datetime import datetime
from typing import List, Optional, Dict, Any
from sqlmodel import select

from config.settings import settings
from agent import get_agent_executor # This now returns the Gemini-based agent
from models import ChatMessage, ChatSession
from models.db import get_session_context
from services.logger import app_logger
from services.metrics import log_consultation_start

# --- Authentication Check ---
# The page only runs for a logged-in user: the login flow is expected to have
# stored the user's id in session state. Otherwise redirect (or explain) and halt.
authenticated_user_id = st.session_state.get("authenticated_user_id")
if not authenticated_user_id:
    st.warning("Please log in to access the consultation page.")
    try:
        st.switch_page("app.py")
    except st.errors.StreamlitAPIException:
        # Redirect rejected by Streamlit: fall back to a plain instruction.
        st.info("Please navigate to the main login page.")
    st.stop()

authenticated_username = st.session_state.get("authenticated_username", "User")
app_logger.info(
    f"User '{authenticated_username}' (ID: {authenticated_user_id}) "
    "accessed Consult page."
)

# --- Initialize Agent ---
# Without an agent the page cannot work, so any failure here is shown to the
# user, logged as critical, and ends this script run via st.stop().
try:
    agent_executor = get_agent_executor()  # Gemini-based agent executor
    app_logger.info("Gemini-based agent executor initialized for Consult page.")
except ValueError as init_error:
    # ValueError is what get_agent_executor is expected to raise for
    # configuration problems such as a missing API key.
    st.error(f"AI Agent Initialization Error: {init_error}")
    app_logger.critical(
        f"Fatal: AI Agent could not be initialized in Consult page: {init_error}",
        exc_info=True,
    )
    st.info(
        "Please ensure the necessary API keys (e.g., Google API Key for Gemini) "
        "are configured in the application settings."
    )
    st.stop()
except Exception as init_error:
    st.error(f"An unexpected error occurred while initializing the AI Agent: {init_error}")
    app_logger.critical(
        f"Fatal: Unexpected AI Agent initialization error: {init_error}",
        exc_info=True,
    )
    st.stop()


# --- Session State for Consult Page ---
# Seed per-browser-session defaults; values set by earlier reruns are kept.
for _state_key, _default_value in (
    ("current_consult_patient_context", {}),
    ("consult_context_submitted", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value

# --- Helper Functions ---
@st.cache_data(ttl=30, show_spinner=False)
def load_chat_history_for_agent(session_id: int) -> List:  # List of LangChain messages
    """Rebuild the agent-facing message history for one chat session.

    Stored ``ChatMessage`` rows are read in timestamp order and converted:
    ``user`` -> HumanMessage, ``assistant`` -> AIMessage and
    ``system`` -> SystemMessage. Rows with any other role (e.g. ``tool``)
    are skipped.

    The result is memoized by ``st.cache_data`` for 30 seconds per
    ``session_id``, so rows written inside that window may not show up in a
    repeated call.
    """
    message_type_for_role = {
        "user": HumanMessage,
        "assistant": AIMessage,
        "system": SystemMessage,
    }
    app_logger.debug(f"Loading agent chat history for session_id: {session_id}")
    with get_session_context() as db:
        history_query = (
            select(ChatMessage)
            .where(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.timestamp)
        )
        stored_rows = db.exec(history_query).all()
        history = [
            message_type_for_role[row.role](content=row.content)
            for row in stored_rows
            if row.role in message_type_for_role
        ]
    app_logger.debug(f"Loaded {len(history)} messages for agent history for session {session_id}.")
    return history

def save_chat_message_to_db(
    session_id: int,
    role: str,
    content: str,
    tool_call_id: Optional[str] = None,
    tool_name: Optional[str] = None,
):
    """Persist one message of a chat session's transcript.

    Args:
        session_id: Id of the owning ChatSession.
        role: Message role as stored in the DB (this page writes "user",
            "assistant" and "system").
        content: Message text.
        tool_call_id: Optional tool-call id stored alongside the message.
        tool_name: Optional tool name stored alongside the message.

    The row is timestamped with ``datetime.utcnow()`` (naive UTC). The commit
    is left to ``get_session_context()`` on exit.
    """
    app_logger.debug(f"Saving message to DB for session {session_id}: Role={role}")
    with get_session_context() as db:
        db.add(
            ChatMessage(
                session_id=session_id,
                role=role,
                content=content,
                timestamp=datetime.utcnow(),
                tool_call_id=tool_call_id,
                tool_name=tool_name,
            )
        )
    app_logger.info(f"Message saved to DB for session {session_id}. Role: {role}.")

def update_chat_session_with_context_summary(session_id: int, context_summary: str):
    """Store ``context_summary`` on the ChatSession row with id ``session_id``.

    A missing session is logged as an error and otherwise ignored; nothing
    is raised. Committing is left to ``get_session_context()`` on exit.
    """
    with get_session_context() as db:
        chat_session = db.get(ChatSession, session_id)
        if chat_session is None:
            app_logger.error(f"Could not find ChatSession {session_id} to update with context summary.")
            return
        chat_session.patient_context_summary = context_summary
        db.add(chat_session)  # staged; committed when the context exits
        app_logger.info(f"Updated ChatSession {session_id} with patient context summary.")

# --- Page Logic ---
st.title("AI Consultation Room")
st.markdown(f"Interacting as: **{authenticated_username}**")
st.info(f"{settings.MAIN_DISCLAIMER_SHORT} Remember to use only anonymized, simulated data.")

# Everything below reads/writes rows of this ChatSession, so halt without it.
chat_session_id = st.session_state.get("current_chat_session_id")
if not chat_session_id:
    st.error(
        "No active chat session ID. This can occur if a session wasn't properly "
        "created on login. Please try logging out and then logging back in. "
        "If the problem persists, contact support."
    )
    app_logger.error(
        f"User '{authenticated_username}' (ID: {authenticated_user_id}) "
        "on Consult page with NO current_chat_session_id."
    )
    st.stop()

# --- Patient Context Input Form ---
# Step 1 is shown until the form has been submitted once in this browser
# session; the st.stop() at the end keeps Step 2 hidden until then.
if not st.session_state.consult_context_submitted:
    st.subheader("Step 1: Provide Patient Context (Optional, Simulated Data Only)")
    with st.form(key="patient_context_form_consult"):
        st.markdown("**Reminder: Use only anonymized, simulated data for this demonstration.**")
        # value=None leaves the field blank, so None (not 0) is the "not specified" marker.
        age = st.number_input("Patient Age (Simulated)", min_value=0, max_value=120, step=1, value=None)
        gender_options = ["Not Specified", "Male", "Female", "Other"]
        gender = st.selectbox("Patient Gender (Simulated)", gender_options, index=0)
        chief_complaint = st.text_area("Chief Complaint / Reason for Consult (Simulated)", height=100, placeholder="e.g., Persistent cough for 2 weeks")
        key_history = st.text_area("Key Medical History (Simulated)", height=100, placeholder="e.g., Type 2 Diabetes, Hypertension, Asthma")
        current_meds = st.text_area("Current Medications (Simulated)", height=100, placeholder="e.g., Metformin 500mg BID, Lisinopril 10mg OD")
        submit_context_button = st.form_submit_button("Start Consult with this Context")

        if submit_context_button:
            # Unspecified inputs become None and are filtered out below.
            context_dict = {
                # An explicitly entered 0 (e.g. a newborn) is a real age; only the
                # untouched field (None) counts as missing.
                "age": age,
                "gender": gender if gender != "Not Specified" else None,
                "chief_complaint": chief_complaint.strip() or None,
                "key_medical_history": key_history.strip() or None,
                "current_medications": current_meds.strip() or None,
            }
            valid_context_parts = {k: v for k, v in context_dict.items() if v is not None}
            st.session_state.current_consult_patient_context = valid_context_parts

            if valid_context_parts:
                # e.g. "Age: 42; Chief Complaint: cough"
                context_summary_for_db_and_agent = "; ".join(
                    f"{k.replace('_', ' ').title()}: {v}" for k, v in valid_context_parts.items()
                )
            else:
                context_summary_for_db_and_agent = "No specific patient context provided for this session."

            update_chat_session_with_context_summary(chat_session_id, context_summary_for_db_and_agent)

            # Start the agent's in-memory history empty. Because the key now exists,
            # Step 2 does not reload this session's DB rows into it, so the context
            # system message saved below stays a DB record only. The context reaches
            # the agent through the separate "patient_context" invoke variable.
            agent_history_key = f"agent_chat_history_{chat_session_id}"
            if agent_history_key not in st.session_state:
                st.session_state[agent_history_key] = []

            if valid_context_parts:
                save_chat_message_to_db(chat_session_id, "system", f"Initial Patient Context Provided: {context_summary_for_db_and_agent}")

            st.session_state.consult_context_submitted = True
            app_logger.info(f"Patient context submitted for session {chat_session_id}: {context_summary_for_db_and_agent}")
            st.rerun()
    st.stop()

# --- Chat Interface (reached only after the Step 1 form has been submitted) ---
st.subheader("Step 2: Interact with AI Health Navigator")
agent_history_key = f"agent_chat_history_{chat_session_id}"

if agent_history_key not in st.session_state:
    # First time this chat session is seen in this browser session:
    # hydrate the agent's memory from the stored transcript.
    st.session_state[agent_history_key] = load_chat_history_for_agent(chat_session_id)

# Greet once per consultation. "New" means "no user/assistant turn yet".
# Testing for an *empty* freshly-loaded history missed two cases:
# - The Step 1 form pre-seeds the history as [], which skips the load above.
# - A loaded history may hold only the saved patient-context SystemMessage.
# As a result the greeting, including its "noted the patient context" variant,
# never appeared after a context submission.
consultation_has_turns = any(
    isinstance(message, (HumanMessage, AIMessage))
    for message in st.session_state[agent_history_key]
)
if not consultation_has_turns:
    try:
        log_consultation_start(user_id=authenticated_user_id, session_id=chat_session_id)
    except Exception as e:
        # Metrics are best-effort; never block the consultation on them.
        app_logger.warning(f"Failed to log consultation start metric: {e}")

    initial_ai_message_content = "Hello! I am your AI Health Navigator. How can I assist you today?"
    if st.session_state.get('current_consult_patient_context'):
        initial_ai_message_content += " I have noted the patient context you provided."

    st.session_state[agent_history_key].append(AIMessage(content=initial_ai_message_content))
    save_chat_message_to_db(chat_session_id, "assistant", initial_ai_message_content)
    app_logger.info(f"Initialized new consultation (session {chat_session_id}) with a greeting.")

# Display chat messages for UI: the transcript is re-read from the DB on every rerun.
# Avatar values must be a genuine emoji, an image URL/path or a Material icon
# (st.chat_message API). The previous literals were mis-encoded UTF-8 byte
# sequences rather than emoji, so they are replaced with the intended characters.
_AVATAR_BY_ROLE = {"assistant": "🧑‍⚕️", "tool": "🛠️"}
_DEFAULT_AVATAR = "👤"

with st.container():
    with get_session_context() as db:
        stmt = select(ChatMessage).where(ChatMessage.session_id == chat_session_id).order_by(ChatMessage.timestamp)
        ui_messages = db.exec(stmt).all()
        for msg in ui_messages:
            if msg.role == "system":
                continue  # patient-context records are stored but not shown in the chat
            avatar = _AVATAR_BY_ROLE.get(msg.role, _DEFAULT_AVATAR)
            with st.chat_message(msg.role, avatar=avatar):
                st.markdown(msg.content)

def _agent_output_to_text(raw_output: Any) -> str:
    """Coerce an agent ``output`` value into text for display and storage.

    Strings pass through unchanged. If the value is a list, its plain-string
    parts and the ``"text"`` values of dict parts are concatenated. This guards
    against chat-model integrations that return message content as a list of
    content parts (NOTE(review): assumed possible for the Gemini model; confirm).
    A list with no such parts, or any other value, falls back to ``str()``, as
    before.
    """
    if isinstance(raw_output, str):
        return raw_output
    if isinstance(raw_output, list):
        pieces = []
        for part in raw_output:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        if pieces:
            return "".join(pieces)
    return str(raw_output)


if prompt := st.chat_input("Ask the AI..."):
    with st.chat_message("user", avatar="👤"):
        st.markdown(prompt)
    save_chat_message_to_db(chat_session_id, "user", prompt)

    agent_history = st.session_state[agent_history_key]
    # The new question travels as "input"; chat_history must carry only the turns
    # before it, otherwise the model is shown the same question twice.
    # NOTE(review): assumes the agent prompt renders chat_history followed by a
    # human turn built from {input} (the usual AgentExecutor layout); confirm in
    # the agent module.
    prior_turns = list(agent_history)
    agent_history.append(HumanMessage(content=prompt))

    with st.chat_message("assistant", avatar="🧑‍⚕️"):
        with st.spinner("AI is thinking..."):
            try:
                # Render the submitted patient context (if any) as one prompt string.
                patient_context_dict = st.session_state.get('current_consult_patient_context', {})
                if patient_context_dict:
                    patient_context_str_for_invoke = "; ".join(
                        f"{k.replace('_', ' ').title()}: {v}" for k, v in patient_context_dict.items()
                    )
                else:
                    patient_context_str_for_invoke = "No specific patient context was provided for this interaction."

                invoke_payload = {
                    "input": prompt,
                    "chat_history": prior_turns,
                    "patient_context": patient_context_str_for_invoke,
                }
                app_logger.debug(f"Invoking agent with payload: {invoke_payload}")

                response = agent_executor.invoke(invoke_payload)

                ai_response_content = _agent_output_to_text(
                    response.get('output', "I could not generate a valid response.")
                )

                app_logger.info(f"Agent response for session {chat_session_id}: '{ai_response_content[:100]}...'")
                st.markdown(ai_response_content)
                save_chat_message_to_db(chat_session_id, "assistant", ai_response_content)
                agent_history.append(AIMessage(content=ai_response_content))

            except Exception as e:
                # UI boundary: log full details, show the user only the exception type.
                app_logger.error(f"Error during agent invocation for session {chat_session_id}: {e}", exc_info=True)
                error_type_name = type(e).__name__  # e.g. "ValidationError", "APIError"
                user_friendly_error_message = f"Sorry, an error occurred ({error_type_name}). Please try rephrasing your query or contact support if the issue persists."
                st.error(user_friendly_error_message)

                # Record the failed assistant turn in the transcript.
                db_error_message = f"System encountered an error: {error_type_name} while processing user query. Details logged."
                save_chat_message_to_db(chat_session_id, "assistant", db_error_message)
                # Let the agent see on the next turn that the previous one failed.
                agent_history.append(AIMessage(content=f"Note to self: Encountered an error ({error_type_name}) on the previous turn."))