mgbam commited on
Commit
2b9aa0c
·
verified ·
1 Parent(s): 93cbb0d

Update pages/2_Consult.py

Browse files
Files changed (1) hide show
  1. pages/2_Consult.py +129 -0
pages/2_Consult.py CHANGED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
3
+ from datetime import datetime
4
+
5
+ from config.settings import settings
6
+ from agent import get_agent_executor
7
+ from models import ChatMessage, ChatSession, User # Assuming User is in session_state
8
+ from models.db import get_session_context
9
+ from services.logger import app_logger
10
+ from services.metrics import log_consultation_start
11
+
12
+ st.set_page_config(page_title=f"Consult - {settings.APP_TITLE}", layout="wide")
13
+
14
+ if not st.session_state.get("authenticated_user"):
15
+ st.warning("Please log in to access the consultation page.")
16
+ st.switch_page("app.py") # Redirect to login
17
+
18
+ # --- Initialize Agent ---
19
+ try:
20
+ agent_executor = get_agent_executor()
21
+ except ValueError as e: # Handles missing API key
22
+ st.error(f"Could not initialize AI Agent: {e}")
23
+ st.stop()
24
+
25
+
26
+ # --- Helper Functions ---
27
+ def load_chat_history(session_id: int) -> list:
28
+ """Loads chat history from DB for the current session"""
29
+ messages = []
30
+ with get_session_context() as db:
31
+ db_messages = db.query(ChatMessage).filter(ChatMessage.session_id == session_id).order_by(ChatMessage.timestamp).all()
32
+ for msg in db_messages:
33
+ if msg.role == "user":
34
+ messages.append(HumanMessage(content=msg.content))
35
+ elif msg.role == "assistant":
36
+ messages.append(AIMessage(content=msg.content))
37
+ # Add tool message handling if you store them as distinct roles in DB
38
+ # elif msg.role == "tool":
39
+ # messages.append(ToolMessage(content=msg.content, tool_call_id=msg.tool_call_id))
40
+ return messages
41
+
42
+ def save_chat_message(session_id: int, role: str, content: str, tool_call_id: Optional[str]=None, tool_name: Optional[str]=None):
43
+ """Saves a chat message to the database."""
44
+ with get_session_context() as db:
45
+ chat_message = ChatMessage(
46
+ session_id=session_id,
47
+ role=role,
48
+ content=content,
49
+ timestamp=datetime.utcnow(),
50
+ tool_call_id=tool_call_id,
51
+ tool_name=tool_name
52
+ )
53
+ db.add(chat_message)
54
+ db.commit()
55
+
56
+ # --- Page Logic ---
57
+ st.title("AI Consultation Room")
58
+ st.markdown("Interact with the Quantum Health Navigator AI.")
59
+
60
+ current_user: User = st.session_state.authenticated_user
61
+ chat_session_id = st.session_state.get("current_chat_session_id")
62
+
63
+ if not chat_session_id:
64
+ st.error("No active chat session. Please re-login or contact support.")
65
+ st.stop()
66
+
67
+ # Load initial chat history for the agent (from Langchain Message objects)
68
+ # For the agent, we need history in LangChain message format
69
+ if "agent_chat_history" not in st.session_state:
70
+ st.session_state.agent_chat_history = load_chat_history(chat_session_id)
71
+ if not st.session_state.agent_chat_history: # If no history, maybe add a system greeting
72
+ log_consultation_start()
73
+ # You could add an initial AIMessage here if desired
74
+ # initial_ai_message = AIMessage(content="Hello! How can I assist you today?")
75
+ # st.session_state.agent_chat_history.append(initial_ai_message)
76
+ # save_chat_message(chat_session_id, "assistant", initial_ai_message.content)
77
+
78
+
79
+ # Display chat messages from DB (for UI)
80
+ with get_session_context() as db:
81
+ ui_messages = db.query(ChatMessage).filter(ChatMessage.session_id == chat_session_id).order_by(ChatMessage.timestamp).all()
82
+ for msg in ui_messages:
83
+ with st.chat_message(msg.role):
84
+ st.markdown(msg.content)
85
+
86
+ # Chat input
87
+ if prompt := st.chat_input("Ask the AI... (e.g., 'What is hypertension?' or 'Optimize treatment for patient X with diabetes')"):
88
+ # Add user message to UI and save to DB
89
+ with st.chat_message("user"):
90
+ st.markdown(prompt)
91
+ save_chat_message(chat_session_id, "user", prompt)
92
+
93
+ # Add to agent's history (LangChain format)
94
+ st.session_state.agent_chat_history.append(HumanMessage(content=prompt))
95
+
96
+ # Get AI response
97
+ with st.spinner("AI is thinking..."):
98
+ try:
99
+ response = agent_executor.invoke({
100
+ "input": prompt,
101
+ "chat_history": st.session_state.agent_chat_history
102
+ })
103
+ ai_response_content = response['output']
104
+
105
+ # Display AI response in UI and save to DB
106
+ with st.chat_message("assistant"):
107
+ st.markdown(ai_response_content)
108
+ save_chat_message(chat_session_id, "assistant", ai_response_content)
109
+
110
+ # Add AI response to agent's history
111
+ st.session_state.agent_chat_history.append(AIMessage(content=ai_response_content))
112
+
113
+ # Note: The agent executor might make tool calls. The create_openai_functions_agent
114
+ # and AgentExecutor handle the tool invocation and adding ToolMessages to history internally
115
+ # before producing the final 'output'. If you need to log individual tool calls/results
116
+ # to your DB, you might need a more custom agent loop or callbacks.
117
+
118
+ except Exception as e:
119
+ app_logger.error(f"Error during agent invocation: {e}")
120
+ st.error(f"An error occurred: {e}")
121
+ # Save error message as AI response?
122
+ error_message = f"Sorry, I encountered an error: {str(e)[:200]}" # Truncate for DB
123
+ with st.chat_message("assistant"): # Or a custom error role
124
+ st.markdown(error_message)
125
+ save_chat_message(chat_session_id, "assistant", error_message) # Or "error" role
126
+ st.session_state.agent_chat_history.append(AIMessage(content=error_message))
127
+
128
+ # Rerun to show the latest messages immediately (though Streamlit usually does this)
129
+ # st.rerun() # Usually not needed with st.chat_input and context managers