mgbam commited on
Commit
4b23857
Β·
verified Β·
1 Parent(s): 6b6515d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -48
app.py CHANGED
@@ -6,9 +6,9 @@ import os
6
  import traceback
7
  from dotenv import load_dotenv
8
 
9
- # Import agent logic and message types from agent.py
10
  try:
11
- from agent import ClinicalAgent, AgentState, check_red_flags
12
  from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
13
  except ImportError as e:
14
  st.error(f"Failed to import from agent.py: {e}. Make sure agent.py is in the same directory.")
@@ -46,26 +46,22 @@ def main():
46
  if "messages" not in st.session_state: st.session_state.messages = []
47
  if "patient_data" not in st.session_state: st.session_state.patient_data = None
48
  if "summary" not in st.session_state: st.session_state.summary = None
49
- # Initialize the agent instance only once
50
  if "agent" not in st.session_state:
51
  try:
52
  st.session_state.agent = ClinicalAgent()
53
  print("ClinicalAgent successfully initialized in Streamlit session state.")
54
  except Exception as e:
55
  st.error(f"Failed to initialize Clinical Agent: {e}. Check API keys and dependencies.")
56
- print(f"ERROR Initializing ClinicalAgent: {e}")
57
- traceback.print_exc()
58
- st.stop()
59
-
60
 
61
  # --- Patient Data Input Sidebar ---
62
  with st.sidebar:
63
  st.header("πŸ“„ Patient Intake Form")
64
- # Input fields... (Using shorter versions for brevity, assume full fields are here)
65
  st.subheader("Demographics"); age = st.number_input("Age", 0, 120, 55, key="sb_age"); sex = st.selectbox("Sex", ["Male", "Female", "Other"], key="sb_sex")
66
  st.subheader("HPI"); chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="sb_cc"); hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100, key="sb_hpi"); symptoms = st.multiselect("Symptoms", ["Nausea", "Diaphoresis", "SOB", "Dizziness", "Severe Headache", "Syncope", "Hemoptysis"], default=["Nausea", "Diaphoresis"], key="sb_sym")
67
  st.subheader("History"); pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI", key="sb_pmh"); psh = st.text_area("PSH", "Appendectomy", key="sb_psh")
68
- st.subheader("Meds & Allergies"); current_meds_str = st.text_area("Current Meds", "Lisinopril 10mg daily\nMetformin 1000mg BID\nWarfarin 5mg daily", key="sb_meds"); allergies_str = st.text_area("Allergies", "Penicillin (rash), Aspirin", key="sb_allergies") # Added Warfarin/Aspirin for testing
69
  st.subheader("Social/Family"); social_history = st.text_area("SH", "Smoker", key="sb_sh"); family_history = st.text_area("FHx", "Father MI", key="sb_fhx")
70
  st.subheader("Vitals & Exam"); col1, col2 = st.columns(2);
71
  with col1: temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f", key="sb_temp"); hr_bpm = st.number_input("HR", 30, 250, 95, key="sb_hr"); rr_rpm = st.number_input("RR", 5, 50, 18, key="sb_rr")
@@ -92,7 +88,6 @@ def main():
92
  st.session_state.messages = [HumanMessage(content=initial_prompt)]
93
  st.session_state.summary = None # Reset summary
94
  st.success("Patient data loaded/updated.")
95
- # Rerun might be needed if the main area should clear or update based on new data
96
  st.rerun()
97
 
98
  # --- Main Chat Interface Area ---
@@ -110,20 +105,44 @@ def main():
110
  if prefix: st.markdown(prefix); structured_output = json.loads(json_str);
111
  if suffix: st.markdown(suffix)
112
  elif ai_content.strip().startswith("{") and ai_content.strip().endswith("}"): structured_output = json.loads(ai_content); ai_content = ""
113
- else: st.markdown(ai_content) # Display non-JSON content
114
  except Exception as e: st.markdown(ai_content); print(f"Error parsing/displaying AI JSON: {e}")
 
115
  if structured_output and isinstance(structured_output, dict): # Structured JSON display logic...
116
  st.divider(); st.subheader("πŸ“Š AI Analysis & Recommendations")
117
  cols = st.columns(2);
118
- with cols[0]: st.markdown("**Assessment:**"); st.markdown(f"> {structured_output.get('assessment', 'N/A')}"); st.markdown("**Differential Diagnosis:**"); ddx = structured_output.get('differential_diagnosis', []);
119
- if ddx: [st.expander(f"{'πŸ₯‡πŸ₯ˆπŸ₯‰'[('High','Medium','Low').index(item.get('likelihood','Low')[0])] if item.get('likelihood','?')[0] in 'HML' else '?'} {item.get('diagnosis', 'Unknown')} ({item.get('likelihood','?')})").write(f"**Rationale:** {item.get('rationale', 'N/A')}") for item in ddx]
120
- else: st.info("No DDx provided."); st.markdown("**Risk Assessment:**"); risk = structured_output.get('risk_assessment', {}); flags=risk.get('identified_red_flags',[]); concerns=risk.get("immediate_concerns",[]); comps=risk.get("potential_complications",[])
121
- if flags: st.warning(f"**Flags:** {', '.join(flags)}"); if concerns: st.warning(f"**Concerns:** {', '.join(concerns)}"); if comps: st.info(f"**Potential Complications:** {', '.join(comps)}");
122
- if not flags and not concerns: st.success("No major risks highlighted.")
123
- with cols[1]: st.markdown("**Recommended Plan:**"); plan = structured_output.get('recommended_plan', {});
124
- for section in ["investigations","therapeutics","consultations","patient_education"]: st.markdown(f"_{section.replace('_',' ').capitalize()}:_"); items = plan.get(section); [st.markdown(f"- {item}") for item in items] if items and isinstance(items, list) else (st.markdown(f"- {items}") if items else st.markdown("_None_")); st.markdown("")
125
- st.markdown("**Rationale & Guideline Check:**"); st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}"); interaction_summary = structured_output.get("interaction_check_summary", "");
126
- if interaction_summary: st.markdown("**Interaction Check Summary:**"); st.markdown(f"> {interaction_summary}"); st.divider()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  # Tool Call Display
129
  if getattr(msg, 'tool_calls', None):
@@ -138,9 +157,7 @@ def main():
138
  with st.chat_message(tool_name_display, avatar="πŸ› οΈ"):
139
  try: # Tool message display logic...
140
  tool_data = json.loads(msg.content); status = tool_data.get("status", "info"); message = tool_data.get("message", msg.content); details = tool_data.get("details"); warnings = tool_data.get("warnings");
141
- # Display flagged risks immediately if the tool signals it
142
- if tool_name_display == "flag_risk" and status == "flagged":
143
- st.error(f"🚨 **RISK FLAGGED:** {message}", icon="🚨") # Show flag in UI too
144
  elif status == "success" or status == "clear": st.success(f"{message}", icon="βœ…")
145
  elif status == "warning": st.warning(f"{message}", icon="⚠️");
146
  if warnings and isinstance(warnings, list): st.caption("Details:"); [st.caption(f"- {warn}") for warn in warnings]
@@ -154,37 +171,15 @@ def main():
154
  if not st.session_state.patient_data: st.warning("Please load patient data first."); st.stop()
155
  if 'agent' not in st.session_state or not st.session_state.agent: st.error("Agent not initialized. Check logs."); st.stop()
156
 
157
- # Append user message and display immediately
158
- user_message = HumanMessage(content=prompt)
159
- st.session_state.messages.append(user_message)
160
  with st.chat_message("user"): st.markdown(prompt)
161
-
162
- # Prepare state for the agent
163
- current_state_dict = {
164
- "messages": st.session_state.messages,
165
- "patient_data": st.session_state.patient_data,
166
- "summary": st.session_state.get("summary"),
167
- "interaction_warnings": None # Start clean
168
- }
169
-
170
- # Invoke the agent's graph for one turn
171
  with st.spinner("SynapseAI is processing..."):
172
  try:
173
- # Call the agent instance's method
174
  final_state = st.session_state.agent.invoke_turn(current_state_dict)
175
-
176
- # Update Streamlit session state from the returned agent state
177
  st.session_state.messages = final_state.get('messages', [])
178
  st.session_state.summary = final_state.get('summary')
179
-
180
- except Exception as e:
181
- print(f"CRITICAL ERROR during agent invocation: {type(e).__name__} - {e}")
182
- traceback.print_exc()
183
- st.error(f"An error occurred during processing: {e}", icon="❌")
184
- # Append error to messages for user visibility
185
- st.session_state.messages.append(AIMessage(content=f"Error during processing: {e}"))
186
-
187
- # Rerun Streamlit script to update the chat display
188
  st.rerun()
189
 
190
  # Disclaimer
 
6
  import traceback
7
  from dotenv import load_dotenv
8
 
9
+ # Import agent logic and message types
10
  try:
11
+ from agent import ClinicalAgent, AgentState, check_red_flags # Import necessary components
12
  from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
13
  except ImportError as e:
14
  st.error(f"Failed to import from agent.py: {e}. Make sure agent.py is in the same directory.")
 
46
  if "messages" not in st.session_state: st.session_state.messages = []
47
  if "patient_data" not in st.session_state: st.session_state.patient_data = None
48
  if "summary" not in st.session_state: st.session_state.summary = None
 
49
  if "agent" not in st.session_state:
50
  try:
51
  st.session_state.agent = ClinicalAgent()
52
  print("ClinicalAgent successfully initialized in Streamlit session state.")
53
  except Exception as e:
54
  st.error(f"Failed to initialize Clinical Agent: {e}. Check API keys and dependencies.")
55
+ print(f"ERROR Initializing ClinicalAgent: {e}"); traceback.print_exc(); st.stop()
 
 
 
56
 
57
  # --- Patient Data Input Sidebar ---
58
  with st.sidebar:
59
  st.header("πŸ“„ Patient Intake Form")
60
+ # Input fields... (Assume full fields as before)
61
  st.subheader("Demographics"); age = st.number_input("Age", 0, 120, 55, key="sb_age"); sex = st.selectbox("Sex", ["Male", "Female", "Other"], key="sb_sex")
62
  st.subheader("HPI"); chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="sb_cc"); hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100, key="sb_hpi"); symptoms = st.multiselect("Symptoms", ["Nausea", "Diaphoresis", "SOB", "Dizziness", "Severe Headache", "Syncope", "Hemoptysis"], default=["Nausea", "Diaphoresis"], key="sb_sym")
63
  st.subheader("History"); pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI", key="sb_pmh"); psh = st.text_area("PSH", "Appendectomy", key="sb_psh")
64
+ st.subheader("Meds & Allergies"); current_meds_str = st.text_area("Current Meds", "Lisinopril 10mg daily\nMetformin 1000mg BID\nWarfarin 5mg daily", key="sb_meds"); allergies_str = st.text_area("Allergies", "Penicillin (rash), Aspirin", key="sb_allergies")
65
  st.subheader("Social/Family"); social_history = st.text_area("SH", "Smoker", key="sb_sh"); family_history = st.text_area("FHx", "Father MI", key="sb_fhx")
66
  st.subheader("Vitals & Exam"); col1, col2 = st.columns(2);
67
  with col1: temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f", key="sb_temp"); hr_bpm = st.number_input("HR", 30, 250, 95, key="sb_hr"); rr_rpm = st.number_input("RR", 5, 50, 18, key="sb_rr")
 
88
  st.session_state.messages = [HumanMessage(content=initial_prompt)]
89
  st.session_state.summary = None # Reset summary
90
  st.success("Patient data loaded/updated.")
 
91
  st.rerun()
92
 
93
  # --- Main Chat Interface Area ---
 
105
  if prefix: st.markdown(prefix); structured_output = json.loads(json_str);
106
  if suffix: st.markdown(suffix)
107
  elif ai_content.strip().startswith("{") and ai_content.strip().endswith("}"): structured_output = json.loads(ai_content); ai_content = ""
108
+ else: st.markdown(ai_content)
109
  except Exception as e: st.markdown(ai_content); print(f"Error parsing/displaying AI JSON: {e}")
110
+
111
  if structured_output and isinstance(structured_output, dict): # Structured JSON display logic...
112
  st.divider(); st.subheader("πŸ“Š AI Analysis & Recommendations")
113
  cols = st.columns(2);
114
+ with cols[0]: # Assessment, DDx, Risk
115
+ st.markdown("**Assessment:**"); st.markdown(f"> {structured_output.get('assessment', 'N/A')}")
116
+ st.markdown("**Differential Diagnosis:**"); ddx = structured_output.get('differential_diagnosis', []);
117
+ if ddx: [st.expander(f"{'πŸ₯‡πŸ₯ˆπŸ₯‰'[('High','Medium','Low').index(item.get('likelihood','Low')[0])] if item.get('likelihood','?')[0] in 'HML' else '?'} {item.get('diagnosis', 'Unknown')} ({item.get('likelihood','?')})").write(f"**Rationale:** {item.get('rationale', 'N/A')}") for item in ddx]
118
+ else: st.info("No DDx provided.")
119
+
120
+ # Risk Assessment Display (CORRECTED - Separate lines)
121
+ st.markdown(f"**Risk Assessment:**")
122
+ risk = structured_output.get('risk_assessment', {})
123
+ flags = risk.get('identified_red_flags', [])
124
+ concerns = risk.get("immediate_concerns", [])
125
+ comps = risk.get("potential_complications", [])
126
+
127
+ if flags:
128
+ st.warning(f"**Flags:** {', '.join(flags)}")
129
+ if concerns:
130
+ st.warning(f"**Concerns:** {', '.join(concerns)}")
131
+ if comps:
132
+ st.info(f"**Potential Complications:** {', '.join(comps)}")
133
+ # Add a message if no risks were highlighted by the AI assessment
134
+ if not flags and not concerns and not comps:
135
+ st.success("No specific risks highlighted in this AI assessment.")
136
+
137
+ with cols[1]: # Plan
138
+ st.markdown("**Recommended Plan:**"); plan = structured_output.get('recommended_plan', {});
139
+ for section in ["investigations","therapeutics","consultations","patient_education"]: st.markdown(f"_{section.replace('_',' ').capitalize()}:_"); items = plan.get(section); [st.markdown(f"- {item}") for item in items] if items and isinstance(items, list) else (st.markdown(f"- {items}") if items else st.markdown("_None_")); st.markdown("")
140
+
141
+ # Rationale & Interaction Summary
142
+ st.markdown("**Rationale & Guideline Check:**"); st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
143
+ interaction_summary = structured_output.get("interaction_check_summary", "");
144
+ if interaction_summary: st.markdown("**Interaction Check Summary:**"); st.markdown(f"> {interaction_summary}");
145
+ st.divider()
146
 
147
  # Tool Call Display
148
  if getattr(msg, 'tool_calls', None):
 
157
  with st.chat_message(tool_name_display, avatar="πŸ› οΈ"):
158
  try: # Tool message display logic...
159
  tool_data = json.loads(msg.content); status = tool_data.get("status", "info"); message = tool_data.get("message", msg.content); details = tool_data.get("details"); warnings = tool_data.get("warnings");
160
+ if tool_name_display == "flag_risk" and status == "flagged": st.error(f"🚨 **RISK FLAGGED:** {message}", icon="🚨") # Show flag in UI too
 
 
161
  elif status == "success" or status == "clear": st.success(f"{message}", icon="βœ…")
162
  elif status == "warning": st.warning(f"{message}", icon="⚠️");
163
  if warnings and isinstance(warnings, list): st.caption("Details:"); [st.caption(f"- {warn}") for warn in warnings]
 
171
  if not st.session_state.patient_data: st.warning("Please load patient data first."); st.stop()
172
  if 'agent' not in st.session_state or not st.session_state.agent: st.error("Agent not initialized. Check logs."); st.stop()
173
 
174
+ user_message = HumanMessage(content=prompt); st.session_state.messages.append(user_message)
 
 
175
  with st.chat_message("user"): st.markdown(prompt)
176
+ current_state_dict = {"messages": st.session_state.messages, "patient_data": st.session_state.patient_data, "summary": st.session_state.get("summary"), "interaction_warnings": None}
 
 
 
 
 
 
 
 
 
177
  with st.spinner("SynapseAI is processing..."):
178
  try:
 
179
  final_state = st.session_state.agent.invoke_turn(current_state_dict)
 
 
180
  st.session_state.messages = final_state.get('messages', [])
181
  st.session_state.summary = final_state.get('summary')
182
+ except Exception as e: print(f"CRITICAL ERROR: {e}"); traceback.print_exc(); st.error(f"Error: {e}"); st.session_state.messages.append(AIMessage(content=f"Error processing request: {e}"))
 
 
 
 
 
 
 
 
183
  st.rerun()
184
 
185
  # Disclaimer