mgbam commited on
Commit
31ea2bf
Β·
verified Β·
1 Parent(s): 4258926

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +516 -294
app.py CHANGED
@@ -7,31 +7,32 @@ from langchain_core.pydantic_v1 import BaseModel, Field
7
  from langchain_core.tools import tool
8
  from langgraph.prebuilt import ToolExecutor
9
  from langgraph.graph import StateGraph, END
10
- from langgraph.checkpoint.memory import MemorySaver # For state persistence (optional but good)
11
 
12
  from typing import Optional, List, Dict, Any, TypedDict, Annotated
13
  import json
14
  import re
15
  import operator
 
16
 
17
- # --- Configuration & Constants --- (Keep previous ones like ClinicalAppSettings)
18
  class ClinicalAppSettings:
19
  APP_TITLE = "SynapseAI: Interactive Clinical Decision Support"
20
  PAGE_LAYOUT = "wide"
21
- MODEL_NAME = "llama3-70b-8192"
22
  TEMPERATURE = 0.1
23
  MAX_SEARCH_RESULTS = 3
24
 
25
  class ClinicalPrompts:
26
- # UPDATED SYSTEM PROMPT FOR CONVERSATIONAL FLOW & GUIDELINES
27
  SYSTEM_PROMPT = """
28
  You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
29
  Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.
30
 
31
  **Core Directives for this Conversation:**
32
- 1. **Analyze Sequentially:** Process information turn-by-turn. You will receive initial patient data, and potentially follow-up messages or results from tools you requested. Base your responses on the *entire* conversation history.
33
  2. **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
34
- 3. **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions), provide a comprehensive assessment using the following JSON structure. Only output this structure when you believe you have a complete initial analysis or plan. Do NOT output incomplete JSON.
35
  ```json
36
  {
37
  "assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
@@ -41,137 +42,212 @@ class ClinicalPrompts:
41
  {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
42
  ],
43
  "risk_assessment": {
44
- "identified_red_flags": ["List any triggered red flags"],
45
- "immediate_concerns": ["Specific urgent issues (e.g., sepsis risk, ACS rule-out)"],
46
- "potential_complications": ["Possible future issues"]
47
  },
48
  "recommended_plan": {
49
- "investigations": ["List specific lab tests or imaging needed. Use 'order_lab_test' tool."],
50
- "therapeutics": ["Suggest specific treatments/prescriptions. Use 'prescribe_medication' tool. MUST check interactions first."],
51
- "consultations": ["Recommend specialist consultations."],
52
  "patient_education": ["Key points for patient communication."]
53
  },
54
- "rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.**",
55
  "interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
56
  }
57
  ```
58
- 4. **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions`. Report the findings. If interactions exist, modify the plan or state the contraindication.
59
- 5. **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point.
60
- 6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription).
61
  7. **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
62
- 8. **Conciseness:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation until ready for the full structured JSON output.
63
  """
64
 
65
- # --- Mock Data / Helpers --- (Keep previous ones like MOCK_INTERACTION_DB, ALLERGY_INTERACTIONS, parse_bp, check_red_flags)
66
- # (Include the helper functions from the previous response here)
67
  MOCK_INTERACTION_DB = {
68
  ("lisinopril", "spironolactone"): "High risk of hyperkalemia. Monitor potassium closely.",
69
  ("warfarin", "amiodarone"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
70
  ("simvastatin", "clarithromycin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
71
- ("aspirin", "ibuprofen"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding."
 
 
 
 
 
72
  }
73
 
74
  ALLERGY_INTERACTIONS = {
75
- "penicillin": ["amoxicillin", "ampicillin", "piperacillin"],
76
- "sulfa": ["sulfamethoxazole", "sulfasalazine"],
77
- "aspirin": ["ibuprofen", "naproxen"] # Cross-reactivity example for NSAIDs
78
  }
79
 
80
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
81
- match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string)
82
- if match: return int(match.group(1)), int(match.group(2))
 
 
 
83
  return None
84
 
85
  def check_red_flags(patient_data: dict) -> List[str]:
 
86
  flags = []
 
 
87
  symptoms = patient_data.get("hpi", {}).get("symptoms", [])
88
  vitals = patient_data.get("vitals", {})
89
  history = patient_data.get("pmh", {}).get("conditions", "")
90
- symptoms_lower = [s.lower() for s in symptoms]
 
91
 
 
92
  if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
93
  if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
94
  if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
95
- # Add other symptom checks...
96
-
97
- if "temp_c" in vitals and vitals["temp_c"] >= 38.5: flags.append(f"Red Flag: Fever ({vitals['temp_c']}Β°C).")
98
- if "hr_bpm" in vitals and vitals["hr_bpm"] >= 120: flags.append(f"Red Flag: Tachycardia ({vitals['hr_bpm']} bpm).")
99
- if "bp_mmhg" in vitals:
100
- bp = parse_bp(vitals["bp_mmhg"])
101
- if bp and (bp[0] >= 180 or bp[1] >= 110): flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {vitals['bp_mmhg']} mmHg).")
102
- if bp and (bp[0] <= 90 or bp[1] <= 60): flags.append(f"Red Flag: Hypotension (BP: {vitals['bp_mmhg']} mmHg).")
103
- # Add other vital checks...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- if "history of mi" in history.lower() and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
106
- # Add other history checks...
107
- return flags
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
- # --- Enhanced Tool Definitions --- (Keep previous Pydantic models and @tool functions)
111
- # (Include LabOrderInput, PrescriptionInput, InteractionCheckInput, FlagRiskInput
112
- # and the corresponding @tool functions: order_lab_test, prescribe_medication,
113
- # check_drug_interactions, flag_risk from the previous response here)
114
 
 
115
  class LabOrderInput(BaseModel):
116
- test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis').")
117
- reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS').")
118
  priority: str = Field("Routine", description="Priority of the test (e.g., 'STAT', 'Routine').")
119
 
120
  @tool("order_lab_test", args_schema=LabOrderInput)
121
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
122
  """Orders a specific lab test with clinical justification and priority."""
123
- return json.dumps({"status": "success", "message": f"Lab Ordered: {test_name} ({priority})", "details": f"Reason: {reason}"})
 
 
 
 
 
 
124
 
125
  class PrescriptionInput(BaseModel):
126
  medication_name: str = Field(..., description="Name of the medication.")
127
- dosage: str = Field(..., description="Dosage amount and unit (e.g., '500 mg', '10 mg').")
128
- route: str = Field(..., description="Route of administration (e.g., 'PO', 'IV', 'IM', 'Topical').")
129
- frequency: str = Field(..., description="How often the medication should be taken (e.g., 'BID', 'QDaily', 'Q4-6H PRN').")
130
- duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Until follow-up').")
131
  reason: str = Field(..., description="Clinical indication for the prescription.")
132
 
133
  @tool("prescribe_medication", args_schema=PrescriptionInput)
134
  def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
135
- """Prescribes a medication with detailed instructions and clinical indication."""
136
- # NOTE: Interaction check should have been done *before* calling this via a separate tool call
137
- return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
 
 
 
 
 
 
138
 
139
  class InteractionCheckInput(BaseModel):
140
- potential_prescription: str = Field(..., description="The name of the NEW medication being considered.")
141
- current_medications: List[str] = Field(..., description="List of the patient's CURRENT medication names.")
142
- allergies: List[str] = Field(..., description="List of the patient's known allergies.")
 
 
 
 
 
143
 
144
  @tool("check_drug_interactions", args_schema=InteractionCheckInput)
145
- def check_drug_interactions(potential_prescription: str, current_medications: List[str], allergies: List[str]) -> str:
146
  """Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing."""
 
147
  warnings = []
148
  potential_med_lower = potential_prescription.lower()
149
- current_meds_lower = [med.lower() for med in current_medications]
150
- allergies_lower = [a.lower() for a in allergies]
151
 
 
 
 
 
 
 
 
 
 
 
 
152
  for allergy in allergies_lower:
 
153
  if allergy == potential_med_lower:
154
- warnings.append(f"CRITICAL ALLERGY: Patient allergic to {allergy}. Cannot prescribe {potential_prescription}.")
155
- continue
 
156
  if allergy in ALLERGY_INTERACTIONS:
157
  for cross_reactant in ALLERGY_INTERACTIONS[allergy]:
158
  if cross_reactant.lower() == potential_med_lower:
159
- warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to {allergy}. High risk with {potential_prescription}.")
160
 
 
161
  for current_med in current_meds_lower:
 
162
  pair1 = (current_med, potential_med_lower)
163
  pair2 = (potential_med_lower, current_med)
164
- # Normalize keys for lookup if necessary (e.g., if DB keys are canonical names)
165
- key1 = tuple(sorted(pair1))
166
- key2 = tuple(sorted(pair2)) # Although redundant if always sorted
167
-
168
- if pair1 in MOCK_INTERACTION_DB:
169
- warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair1]}")
170
- elif pair2 in MOCK_INTERACTION_DB:
171
- warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair2]}")
172
 
173
  status = "warning" if warnings else "clear"
174
- message = f"Interaction check for {potential_prescription}: {len(warnings)} potential issue(s) found." if warnings else f"No major interactions identified for {potential_prescription}."
 
175
  return json.dumps({"status": status, "message": message, "warnings": warnings})
176
 
177
 
@@ -182,21 +258,26 @@ class FlagRiskInput(BaseModel):
182
  @tool("flag_risk", args_schema=FlagRiskInput)
183
  def flag_risk(risk_description: str, urgency: str) -> str:
184
  """Flags a critical risk identified during analysis for immediate attention."""
 
185
  # Display in Streamlit immediately
186
  st.error(f"🚨 **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="🚨")
187
- return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
 
 
 
188
 
189
  # Initialize Search Tool
190
- search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS, name="tavily_search_results")
191
-
 
 
192
 
193
  # --- LangGraph Setup ---
194
 
195
  # Define the state structure
196
  class AgentState(TypedDict):
197
  messages: Annotated[list[Any], operator.add] # Accumulates messages (Human, AI, Tool)
198
- patient_data: Optional[dict] # Holds the structured patient data (can be updated if needed)
199
- # Potentially add other state elements like 'interaction_check_needed_for': Optional[str]
200
 
201
  # Define Tools and Tool Executor
202
  tools = [
@@ -211,99 +292,212 @@ tool_executor = ToolExecutor(tools)
211
  # Define the Agent Model
212
  model = ChatGroq(
213
  temperature=ClinicalAppSettings.TEMPERATURE,
214
- model=ClinicalAppSettings.MODEL_NAME
 
 
215
  )
216
- model_with_tools = model.bind_tools(tools) # Bind tools for the LLM to know about them
 
217
 
218
  # --- Graph Nodes ---
219
 
220
  # 1. Agent Node: Calls the LLM
221
  def agent_node(state: AgentState):
222
  """Invokes the LLM to decide the next action or response."""
223
- print("---AGENT NODE---")
224
- # Make sure patient data is included in the first message if not already there
225
- # This is a basic way; more robust would be merging patient_data into context
226
  current_messages = state['messages']
227
- if len(current_messages) == 1 and isinstance(current_messages[0], HumanMessage) and state.get('patient_data'):
228
- # Augment the first human message with formatted patient data
229
- formatted_data = format_patient_data_for_prompt(state['patient_data']) # Need this helper function
230
- current_messages = [
231
- SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT), # Ensure system prompt is first
232
- HumanMessage(content=f"{current_messages[0].content}\n\n**Initial Patient Data:**\n{formatted_data}")
233
- ]
234
- elif not any(isinstance(m, SystemMessage) for m in current_messages):
235
- # Add system prompt if missing
236
- current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
237
-
238
-
239
- response = model_with_tools.invoke(current_messages)
240
- print(f"Agent response: {response}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  return {"messages": [response]}
242
 
243
- # 2. Tool Node: Executes tools called by the Agent
244
  def tool_node(state: AgentState):
245
  """Executes tools called by the LLM and returns results."""
246
- print("---TOOL NODE---")
 
247
  last_message = state['messages'][-1]
248
- if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
249
- print("No tool calls in last message.")
250
- return {} # Should not happen if routing is correct, but safety check
 
 
 
 
 
251
 
252
  tool_calls = last_message.tool_calls
253
- tool_messages = []
254
 
255
- # Safety Check: Ensure interaction check happens *before* prescribing the *same* drug
256
- prescribe_calls = {call['args'].get('medication_name'): call['id'] for call in tool_calls if call['name'] == 'prescribe_medication'}
257
- interaction_check_calls = {call['args'].get('potential_prescription'): call['id'] for call in tool_calls if call['name'] == 'check_drug_interactions'}
258
 
259
- for med_name, prescribe_call_id in prescribe_calls.items():
260
- if med_name not in interaction_check_calls:
261
- st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked for this turn.")
262
- # Create an error ToolMessage to send back to the LLM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  error_msg = ToolMessage(
264
- content=json.dumps({"status": "error", "message": f"Interaction check for {med_name} must be requested *before or alongside* the prescription call."}),
265
- tool_call_id=prescribe_call_id
 
266
  )
267
  tool_messages.append(error_msg)
268
- # Remove the invalid prescribe call to prevent execution
269
- tool_calls = [call for call in tool_calls if call['id'] != prescribe_call_id]
 
270
 
 
 
 
271
 
272
- # Add patient context to interaction checks if needed
273
  patient_meds = state.get("patient_data", {}).get("medications", {}).get("names_only", [])
274
  patient_allergies = state.get("patient_data", {}).get("allergies", [])
275
- for call in tool_calls:
 
276
  if call['name'] == 'check_drug_interactions':
 
 
277
  call['args']['current_medications'] = patient_meds
278
  call['args']['allergies'] = patient_allergies
279
- print(f"Augmented interaction check args: {call['args']}")
280
-
281
-
282
- # Execute remaining valid tool calls
283
- if tool_calls:
284
- responses = tool_executor.batch(tool_calls)
285
- # Responses is a list of tool outputs corresponding to tool_calls
286
- # We need to create ToolMessage objects
287
- tool_messages.extend([
288
- ToolMessage(content=str(resp), tool_call_id=call['id'])
289
- for call, resp in zip(tool_calls, responses)
290
- ])
291
- print(f"Tool results: {tool_messages}")
292
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  return {"messages": tool_messages}
294
 
295
 
296
  # --- Graph Edges (Routing Logic) ---
297
  def should_continue(state: AgentState) -> str:
298
- """Determines whether to continue the loop or end."""
299
- last_message = state['messages'][-1]
300
- # If the LLM made tool calls, we execute them
301
- if isinstance(last_message, AIMessage) and last_message.tool_calls:
302
- print("Routing: continue_tools")
 
 
 
 
 
 
 
 
 
 
 
 
303
  return "continue_tools"
304
- # Otherwise, we end the loop (AI provided a direct answer or finished)
305
  else:
306
- print("Routing: end_conversation_turn")
307
  return "end_conversation_turn"
308
 
309
  # --- Graph Definition ---
@@ -322,7 +516,7 @@ workflow.add_conditional_edges(
322
  should_continue, # Function to decide the route
323
  {
324
  "continue_tools": "tools", # If tool calls exist, go to tools node
325
- "end_conversation_turn": END # Otherwise, end the graph iteration
326
  }
327
  )
328
 
@@ -333,61 +527,48 @@ workflow.add_edge("tools", "agent")
333
  # memory = MemorySaverInMemory() # Optional: for persisting state across runs
334
  # app = workflow.compile(checkpointer=memory)
335
  app = workflow.compile()
 
336
 
337
- # --- Helper Function to Format Patient Data ---
338
- def format_patient_data_for_prompt(data: dict) -> str:
339
- """Formats the patient dictionary into a readable string for the LLM."""
340
- prompt_str = ""
341
- for key, value in data.items():
342
- if isinstance(value, dict):
343
- section_title = key.replace('_', ' ').title()
344
- prompt_str += f"**{section_title}:**\n"
345
- for sub_key, sub_value in value.items():
346
- if sub_value:
347
- prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
348
- elif isinstance(value, list) and value:
349
- prompt_str += f"**{key.replace('_', ' ').title()}:** {', '.join(map(str, value))}\n"
350
- elif value:
351
- prompt_str += f"**{key.replace('_', ' ').title()}:** {value}\n"
352
- return prompt_str.strip()
353
-
354
- # --- Streamlit UI (Modified for Conversation) ---
355
  def main():
356
  st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
357
  st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
358
  st.caption(f"Interactive Assistant | Powered by Langchain/LangGraph & Groq ({ClinicalAppSettings.MODEL_NAME})")
359
 
360
- # Initialize session state for conversation
361
  if "messages" not in st.session_state:
362
- st.session_state.messages = [] # Store entire conversation history (Human, AI, Tool)
363
  if "patient_data" not in st.session_state:
364
  st.session_state.patient_data = None
365
- if "initial_analysis_done" not in st.session_state:
366
- st.session_state.initial_analysis_done = False
367
  if "graph_app" not in st.session_state:
368
- st.session_state.graph_app = app # Store compiled graph
369
 
370
- # --- Patient Data Input Sidebar --- (Similar to before)
371
  with st.sidebar:
372
  st.header("πŸ“„ Patient Intake Form")
373
- # ... (Keep the input fields exactly as in the previous example) ...
374
  # Demographics
 
375
  age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
376
  sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")
377
  # HPI
 
378
  chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
379
- hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago...", key="hpi_input")
380
- symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough"], default=["Nausea", "Diaphoresis"], key="sym_input")
381
  # History
382
- pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2)", key="pmh_input")
 
383
  psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")
384
  # Meds & Allergies
 
385
  current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
386
- allergies_str = st.text_area("Allergies (comma separated)", "Penicillin (rash)", key="allergy_input")
387
  # Social/Family
 
388
  social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
389
  family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
390
  # Vitals/Exam
 
391
  col1, col2 = st.columns(2)
392
  with col1:
393
  temp_c = st.number_input("Temp (Β°C)", 35.0, 42.0, 36.8, format="%.1f", key="temp_input")
@@ -397,19 +578,28 @@ def main():
397
  bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
398
  spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
399
  pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
400
- exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3...", key="exam_input")
401
 
402
  # Compile Patient Data Dictionary on button press
403
  if st.button("Start/Update Consultation", key="start_button"):
404
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
 
405
  current_med_names = []
406
- # Improved parsing for names (still basic, assumes name is first word)
407
  for med in current_meds_list:
408
  match = re.match(r"^\s*([a-zA-Z\-]+)", med)
409
  if match:
410
- current_med_names.append(match.group(1).lower()) # Use lower case for matching
411
-
412
- allergies_list = [a.strip().lower() for a in allergies_str.split(',') if a.strip()] # Lowercase allergies
 
 
 
 
 
 
 
 
 
413
 
414
  st.session_state.patient_data = {
415
  "demographics": {"age": age, "sex": sex},
@@ -424,75 +614,131 @@ def main():
424
 
425
  # Initial Red Flag Check (Client-side)
426
  red_flags = check_red_flags(st.session_state.patient_data)
 
427
  if red_flags:
428
- st.warning("**Initial Red Flags Detected:**")
429
- for flag in red_flags: st.warning(f"- {flag}")
 
 
430
 
431
  # Prepare initial message for the graph
432
- initial_prompt = f"Analyze the following patient case:\nChief Complaint: {chief_complaint}\nSummary: {age} y/o {sex} presenting with..." # Keep it brief, full data is in state
 
433
  st.session_state.messages = [HumanMessage(content=initial_prompt)]
434
- st.session_state.initial_analysis_done = False # Reset analysis state
435
  st.success("Patient data loaded. Ready for analysis.")
436
- st.rerun() # Refresh main area to show chat
437
-
438
 
439
  # --- Main Chat Interface Area ---
440
  st.header("πŸ’¬ Clinical Consultation")
441
 
442
- # Display chat messages
443
- for msg in st.session_state.messages:
 
444
  if isinstance(msg, HumanMessage):
445
- with st.chat_message("user"):
446
  st.markdown(msg.content)
447
  elif isinstance(msg, AIMessage):
448
- with st.chat_message("assistant"):
449
- # Check for structured JSON output
 
450
  structured_output = None
 
 
451
  try:
452
- # Try to find JSON block first
453
- json_match = re.search(r"```json\n(\{.*?\})\n```", msg.content, re.DOTALL)
454
  if json_match:
455
- structured_output = json.loads(json_match.group(1))
456
- # Display non-JSON parts if any
457
- non_json_content = msg.content.replace(json_match.group(0), "").strip()
458
- if non_json_content:
459
- st.markdown(non_json_content)
460
- st.divider() # Separate text from structured output visually
461
- elif msg.content.strip().startswith("{") and msg.content.strip().endswith("}"):
462
- # Maybe the whole message is JSON
463
- structured_output = json.loads(msg.content)
 
 
464
  else:
465
- # No JSON found, display raw content
466
- st.markdown(msg.content)
467
-
468
- if structured_output:
469
- # Display the structured data nicely (reuse parts of previous UI display logic)
470
- st.subheader("πŸ“Š AI Analysis & Recommendations")
471
- # ... (Add logic here to display assessment, ddx, plan etc. from structured_output)
472
- # Example:
473
- st.write(f"**Assessment:** {structured_output.get('assessment', 'N/A')}")
474
- # Display DDx, Plan etc. using expanders or tabs
475
- # ...
476
- # Display Rationale & Interaction Summary
477
- with st.expander("Rationale & Guideline Check"):
478
- st.write(structured_output.get("rationale_summary", "N/A"))
479
- if structured_output.get("interaction_check_summary"):
480
- with st.expander("Interaction Check"):
481
- st.write(structured_output.get("interaction_check_summary"))
482
-
483
 
484
  except json.JSONDecodeError:
485
- st.markdown(msg.content) # Display raw if JSON parsing fails
486
-
487
- # Display tool calls if any were made in this AI turn
488
- if msg.tool_calls:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
  with st.expander("πŸ› οΈ AI requested actions", expanded=False):
490
  for tc in msg.tool_calls:
491
- st.code(f"{tc['name']}(args={tc['args']})", language="python")
 
 
 
 
 
492
 
493
  elif isinstance(msg, ToolMessage):
494
- with st.chat_message("tool", avatar="πŸ› οΈ"):
 
 
495
  try:
 
496
  tool_data = json.loads(msg.content)
497
  status = tool_data.get("status", "info")
498
  message = tool_data.get("message", msg.content)
@@ -500,97 +746,73 @@ def main():
500
  warnings = tool_data.get("warnings")
501
 
502
  if status == "success" or status == "clear" or status == "flagged":
503
- st.success(f"Tool Result ({msg.name}): {message}", icon="βœ…" if status != "flagged" else "🚨")
504
  elif status == "warning":
505
- st.warning(f"Tool Result ({msg.name}): {message}", icon="⚠️")
506
- if warnings:
 
507
  for warn in warnings: st.caption(f"- {warn}")
508
  else: # Error or unknown status
509
- st.error(f"Tool Result ({msg.name}): {message}", icon="❌")
510
 
511
  if details: st.caption(f"Details: {details}")
512
 
513
  except json.JSONDecodeError:
514
- st.info(f"Tool Result ({msg.name}): {msg.content}") # Display raw if not JSON
515
-
 
 
 
516
 
517
- # Chat input for user
518
  if prompt := st.chat_input("Your message or follow-up query..."):
519
  if not st.session_state.patient_data:
520
  st.warning("Please load patient data using the sidebar first.")
521
- else:
522
- # Add user message to state
523
- st.session_state.messages.append(HumanMessage(content=prompt))
524
- with st.chat_message("user"):
525
- st.markdown(prompt)
526
-
527
- # Prepare state for graph invocation
528
- current_state = AgentState(
529
- messages=st.session_state.messages,
530
- patient_data=st.session_state.patient_data
531
- )
532
-
533
- # Stream graph execution
534
- with st.chat_message("assistant"):
535
- message_placeholder = st.empty()
536
- full_response = ""
537
-
538
- # Use stream to get intermediate steps (optional but good for UX)
539
- # This shows AI thinking and tool calls/results progressively
540
- try:
541
- for event in st.session_state.graph_app.stream(current_state, {"recursion_limit": 15}):
542
- # event is a dictionary, keys are node names
543
- if "agent" in event:
544
- ai_msg = event["agent"]["messages"][-1] # Get the latest AI message
545
- if isinstance(ai_msg, AIMessage):
546
- full_response += ai_msg.content # Append content for final display
547
- message_placeholder.markdown(full_response + "β–Œ") # Show typing indicator
548
-
549
- # Display tool calls as they happen (optional)
550
- # if ai_msg.tool_calls:
551
- # st.info(f"Requesting tools: {[tc['name'] for tc in ai_msg.tool_calls]}")
552
-
553
- elif "tools" in event:
554
- # Display tool results as they come back (optional, already handled by message display loop)
555
- pass
556
- # tool_msgs = event["tools"]["messages"]
557
- # for tool_msg in tool_msgs:
558
- # st.info(f"Tool {tool_msg.name} result received.")
559
-
560
-
561
- # Final display after streaming
562
- message_placeholder.markdown(full_response)
563
-
564
-
565
- # Update session state with the final messages from the graph run
566
- # The graph state itself isn't directly accessible after streaming finishes easily this way
567
- # We need to get the final state if we used invoke, or reconstruct from stream events
568
- # A simpler way for now: just append the *last* AI message and any Tool messages from the stream
569
- # This assumes the stream provides the final state implicitly. For robust state, use invoke or checkpointer.
570
-
571
- # A more robust way: invoke and get final state
572
- # final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
573
- # st.session_state.messages = final_state['messages']
574
- # --- Let's stick to appending for simplicity in this example ---
575
- # Find the last AI message and tool messages from the stream (needs careful event parsing)
576
- # Or, re-run invoke non-streamed just to get final state (less efficient)
577
- final_state_capture = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
578
- st.session_state.messages = final_state_capture['messages']
579
-
580
-
581
- except Exception as e:
582
- st.error(f"An error occurred during analysis: {e}")
583
- # Attempt to add the error message to the history
584
- st.session_state.messages.append(AIMessage(content=f"Sorry, an error occurred: {e}"))
585
-
586
-
587
- # Rerun to display the updated chat history correctly
588
- st.rerun()
589
-
590
-
591
- # Disclaimer
592
  st.markdown("---")
593
- st.warning("**Disclaimer:** SynapseAI is for clinical decision support...") # Keep disclaimer
 
 
 
 
594
 
595
  if __name__ == "__main__":
596
  main()
 
7
  from langchain_core.tools import tool
8
  from langgraph.prebuilt import ToolExecutor
9
  from langgraph.graph import StateGraph, END
10
+ # from langgraph.checkpoint.memory import MemorySaverInMemory # Optional for state persistence
11
 
12
  from typing import Optional, List, Dict, Any, TypedDict, Annotated
13
  import json
14
  import re
15
  import operator
16
+ import traceback # For detailed error logging
17
 
18
+ # --- Configuration & Constants ---
19
  class ClinicalAppSettings:
20
  APP_TITLE = "SynapseAI: Interactive Clinical Decision Support"
21
  PAGE_LAYOUT = "wide"
22
+ MODEL_NAME = "llama3-70b-8192" # Groq Llama3 70b
23
  TEMPERATURE = 0.1
24
  MAX_SEARCH_RESULTS = 3
25
 
26
  class ClinicalPrompts:
27
+ # UPDATED SYSTEM PROMPT FOR CONVERSATIONAL FLOW, GUIDELINES & STRUCTURED OUTPUT FOCUS
28
  SYSTEM_PROMPT = """
29
  You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
30
  Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.
31
 
32
  **Core Directives for this Conversation:**
33
+ 1. **Analyze Sequentially:** Process information turn-by-turn. Base your responses on the *entire* conversation history.
34
  2. **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
35
+ 3. **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions, guideline searches), provide a comprehensive assessment using the following JSON structure. Output this JSON structure as the primary content of your response when you are providing the full analysis. Do NOT output incomplete JSON. If you need to ask a question or perform a tool call first, do that instead of outputting this structure.
36
  ```json
37
  {
38
  "assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
 
42
  {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
43
  ],
44
  "risk_assessment": {
45
+ "identified_red_flags": ["List any triggered red flags based on input and analysis"],
46
+ "immediate_concerns": ["Specific urgent issues requiring attention (e.g., sepsis risk, ACS rule-out)"],
47
+ "potential_complications": ["Possible future issues based on presentation"]
48
  },
49
  "recommended_plan": {
50
+ "investigations": ["List specific lab tests or imaging required. Use 'order_lab_test' tool."],
51
+ "therapeutics": ["Suggest specific treatments or prescriptions. Use 'prescribe_medication' tool. MUST check interactions first using 'check_drug_interactions'."],
52
+ "consultations": ["Recommend specialist consultations if needed."],
53
  "patient_education": ["Key points for patient communication."]
54
  },
55
+ "rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.** Include summary of guideline findings here.",
56
  "interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
57
  }
58
  ```
59
+ 4. **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions` in a preceding or concurrent tool call. Report the findings from the interaction check. If significant interactions exist, modify the plan or state the contraindication clearly.
60
+ 5. **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point in the conversation.
61
+ 6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription in the structured JSON).
62
  7. **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
63
+ 8. **Conciseness & Flow:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation (asking questions, acknowledging info) until ready for the full structured JSON output.
64
  """
65
 
66
+ # --- Mock Data / Helpers ---
 
67
  MOCK_INTERACTION_DB = {
68
  ("lisinopril", "spironolactone"): "High risk of hyperkalemia. Monitor potassium closely.",
69
  ("warfarin", "amiodarone"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
70
  ("simvastatin", "clarithromycin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
71
+ ("aspirin", "ibuprofen"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding.",
72
+ # Add lower case versions for easier lookup
73
+ ("amiodarone", "warfarin"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
74
+ ("clarithromycin", "simvastatin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
75
+ ("ibuprofen", "aspirin"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding.",
76
+ ("spironolactone", "lisinopril"): "High risk of hyperkalemia. Monitor potassium closely.",
77
  }
78
 
79
  ALLERGY_INTERACTIONS = {
80
+ "penicillin": ["amoxicillin", "ampicillin", "piperacillin", "augmentin"],
81
+ "sulfa": ["sulfamethoxazole", "sulfasalazine", "bactrim"],
82
+ "aspirin": ["ibuprofen", "naproxen", "nsaid"] # Cross-reactivity example for NSAIDs
83
  }
84
 
85
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
86
+ """Parses BP string like '120/80' into (systolic, diastolic) integers."""
87
+ if not isinstance(bp_string, str): return None
88
+ match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
89
+ if match:
90
+ return int(match.group(1)), int(match.group(2))
91
  return None
92
 
93
  def check_red_flags(patient_data: dict) -> List[str]:
94
+ """Checks patient data against predefined red flags."""
95
  flags = []
96
+ if not patient_data: return flags
97
+
98
  symptoms = patient_data.get("hpi", {}).get("symptoms", [])
99
  vitals = patient_data.get("vitals", {})
100
  history = patient_data.get("pmh", {}).get("conditions", "")
101
+ # Ensure symptoms are strings and lowercased
102
+ symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)]
103
 
104
+ # Symptom Flags
105
  if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
106
  if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
107
  if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
108
+ if "sudden vision loss" in symptoms_lower: flags.append("Red Flag: Sudden Vision Loss reported.")
109
+ if "weakness on one side" in symptoms_lower: flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
110
+ if "hemoptysis" in symptoms_lower: flags.append("Red Flag: Hemoptysis (coughing up blood).")
111
+ if "syncope" in symptoms_lower: flags.append("Red Flag: Syncope (fainting).")
112
+
113
+ # Vital Sign Flags
114
+ if vitals:
115
+ temp = vitals.get("temp_c")
116
+ hr = vitals.get("hr_bpm")
117
+ rr = vitals.get("rr_rpm")
118
+ spo2 = vitals.get("spo2_percent")
119
+ bp_str = vitals.get("bp_mmhg")
120
+
121
+ if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever (Temperature: {temp}Β°C).")
122
+ if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia (Heart Rate: {hr} bpm).")
123
+ if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia (Heart Rate: {hr} bpm).")
124
+ if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea (Respiratory Rate: {rr} rpm).")
125
+ if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia (SpO2: {spo2}%).")
126
+ if bp_str:
127
+ bp = parse_bp(bp_str)
128
+ if bp:
129
+ if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
130
+ if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")
131
+
132
+ # History Flags (Simple examples)
133
+ if history and "history of mi" in history.lower() and "chest pain" in symptoms_lower:
134
+ flags.append("Red Flag: History of MI with current Chest Pain.")
135
+ if history and "history of dvt/pe" in history.lower() and "shortness of breath" in symptoms_lower:
136
+ flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
137
+
138
+ # Remove duplicates
139
+ return list(set(flags))
140
 
141
+ def format_patient_data_for_prompt(data: dict) -> str:
142
+ """Formats the patient dictionary into a readable string for the LLM."""
143
+ if not data: return "No patient data provided."
144
+ prompt_str = ""
145
+ for key, value in data.items():
146
+ section_title = key.replace('_', ' ').title()
147
+ if isinstance(value, dict) and value:
148
+ has_content = any(sub_value for sub_value in value.values())
149
+ if has_content:
150
+ prompt_str += f"**{section_title}:**\n"
151
+ for sub_key, sub_value in value.items():
152
+ if sub_value:
153
+ prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
154
+ elif isinstance(value, list) and value:
155
+ prompt_str += f"**{section_title}:** {', '.join(map(str, value))}\n"
156
+ elif value and not isinstance(value, dict): # Check it's not an empty dict
157
+ prompt_str += f"**{section_title}:** {value}\n"
158
+ return prompt_str.strip()
159
 
160
 
161
+ # --- Tool Definitions ---
 
 
 
162
 
163
+ # Pydantic models for robust argument validation
164
  class LabOrderInput(BaseModel):
165
+ test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis', 'D-dimer').")
166
+ reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS', 'Assess for PE').")
167
  priority: str = Field("Routine", description="Priority of the test (e.g., 'STAT', 'Routine').")
168
 
169
  @tool("order_lab_test", args_schema=LabOrderInput)
170
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
171
  """Orders a specific lab test with clinical justification and priority."""
172
+ print(f"Executing order_lab_test: {test_name}, Reason: {reason}, Priority: {priority}")
173
+ # In a real system, this would integrate with an LIS/EMR API
174
+ return json.dumps({
175
+ "status": "success",
176
+ "message": f"Lab Ordered: {test_name} ({priority})",
177
+ "details": f"Reason: {reason}"
178
+ })
179
 
180
  class PrescriptionInput(BaseModel):
181
  medication_name: str = Field(..., description="Name of the medication.")
182
+ dosage: str = Field(..., description="Dosage amount and unit (e.g., '500 mg', '10 mg', '81 mg').")
183
+ route: str = Field(..., description="Route of administration (e.g., 'PO', 'IV', 'IM', 'Topical', 'SL').")
184
+ frequency: str = Field(..., description="How often the medication should be taken (e.g., 'BID', 'QDaily', 'Q4-6H PRN', 'once').")
185
+ duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Ongoing', 'Until follow-up').")
186
  reason: str = Field(..., description="Clinical indication for the prescription.")
187
 
188
  @tool("prescribe_medication", args_schema=PrescriptionInput)
189
  def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
190
+ """Prescribes a medication with detailed instructions and clinical indication. IMPORTANT: Requires prior interaction check."""
191
+ print(f"Executing prescribe_medication: {medication_name} {dosage}...")
192
+ # NOTE: The safety check (ensuring interaction check was requested) happens in the tool_node *before* this function is called.
193
+ # In a real system, this would trigger an e-prescription workflow.
194
+ return json.dumps({
195
+ "status": "success",
196
+ "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
197
+ "details": f"Duration: {duration}. Reason: {reason}"
198
+ })
199
 
200
  class InteractionCheckInput(BaseModel):
201
+ potential_prescription: str = Field(..., description="The name of the NEW medication being considered for prescribing.")
202
+ # These next two args are now populated by the tool_node using AgentState
203
+ # current_medications: List[str] = Field(..., description="List of the patient's CURRENT medication names.")
204
+ # allergies: List[str] = Field(..., description="List of the patient's known allergies.")
205
+ # Make them optional in the schema, mandatory in the node logic
206
+ current_medications: Optional[List[str]] = Field(None, description="List of patient's current medication names (populated from state).")
207
+ allergies: Optional[List[str]] = Field(None, description="List of patient's known allergies (populated from state).")
208
+
209
 
210
  @tool("check_drug_interactions", args_schema=InteractionCheckInput)
211
+ def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
212
  """Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing."""
213
+ print(f"Executing check_drug_interactions for: {potential_prescription}")
214
  warnings = []
215
  potential_med_lower = potential_prescription.lower()
 
 
216
 
217
+ # Use provided lists or default to empty if None (should be populated by tool_node)
218
+ current_meds_list = current_medications or []
219
+ allergies_list = allergies or []
220
+
221
+ current_meds_lower = [str(med).lower() for med in current_meds_list]
222
+ allergies_lower = [str(a).lower() for a in allergies_list]
223
+
224
+ print(f" Checking against Meds: {current_meds_lower}")
225
+ print(f" Checking against Allergies: {allergies_lower}")
226
+
227
+ # Check Allergies
228
  for allergy in allergies_lower:
229
+ # Direct match
230
  if allergy == potential_med_lower:
231
+ warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{allergy}'. Cannot prescribe '{potential_prescription}'.")
232
+ continue # Don't check cross-reactivity if direct match
233
+ # Check cross-reactivity
234
  if allergy in ALLERGY_INTERACTIONS:
235
  for cross_reactant in ALLERGY_INTERACTIONS[allergy]:
236
  if cross_reactant.lower() == potential_med_lower:
237
+ warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to '{allergy}'. High risk with '{potential_prescription}'.")
238
 
239
+ # Check Drug-Drug Interactions
240
  for current_med in current_meds_lower:
241
+ # Check pairs in both orders using the mock DB
242
  pair1 = (current_med, potential_med_lower)
243
  pair2 = (potential_med_lower, current_med)
244
+ interaction_msg = MOCK_INTERACTION_DB.get(pair1) or MOCK_INTERACTION_DB.get(pair2)
245
+ if interaction_msg:
246
+ warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {interaction_msg}")
 
 
 
 
 
247
 
248
  status = "warning" if warnings else "clear"
249
+ message = f"Interaction check for '{potential_prescription}': {len(warnings)} potential issue(s) found." if warnings else f"No major interactions identified for '{potential_prescription}' based on provided lists."
250
+ print(f" Interaction Check Result: {status}, Message: {message}, Warnings: {warnings}")
251
  return json.dumps({"status": status, "message": message, "warnings": warnings})
252
 
253
 
 
258
  @tool("flag_risk", args_schema=FlagRiskInput)
259
  def flag_risk(risk_description: str, urgency: str) -> str:
260
  """Flags a critical risk identified during analysis for immediate attention."""
261
+ print(f"Executing flag_risk: {risk_description}, Urgency: {urgency}")
262
  # Display in Streamlit immediately
263
  st.error(f"🚨 **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="🚨")
264
+ return json.dumps({
265
+ "status": "flagged",
266
+ "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
267
+ })
268
 
269
  # Initialize Search Tool
270
+ search_tool = TavilySearchResults(
271
+ max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS,
272
+ name="tavily_search_results" # Explicitly name the tool
273
+ )
274
 
275
  # --- LangGraph Setup ---
276
 
277
  # Define the state structure
278
  class AgentState(TypedDict):
279
  messages: Annotated[list[Any], operator.add] # Accumulates messages (Human, AI, Tool)
280
+ patient_data: Optional[dict] # Holds the structured patient data
 
281
 
282
  # Define Tools and Tool Executor
283
  tools = [
 
292
  # Define the Agent Model
293
  model = ChatGroq(
294
  temperature=ClinicalAppSettings.TEMPERATURE,
295
+ model=ClinicalAppSettings.MODEL_NAME,
296
+ # Increase max_tokens if needed for large JSON output + conversation history
297
+ # max_tokens=4096
298
  )
299
+ # Bind tools FOR the model to know their schemas and descriptions
300
+ model_with_tools = model.bind_tools(tools)
301
 
302
  # --- Graph Nodes ---
303
 
304
  # 1. Agent Node: Calls the LLM
305
  def agent_node(state: AgentState):
306
  """Invokes the LLM to decide the next action or response."""
307
+ print("\n---AGENT NODE---")
 
 
308
  current_messages = state['messages']
309
+ # Ensure System Prompt is present
310
+ if not current_messages or not isinstance(current_messages[0], SystemMessage):
311
+ print("Prepending System Prompt.")
312
+ current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
313
+
314
+ # Optional: Augment first human message with patient data if not already done explicitly
315
+ # This helps ensure the LLM sees it early, though it's also in the state.
316
+ # Be mindful of context window limits.
317
+ # if len(current_messages) > 1 and isinstance(current_messages[1], HumanMessage) and state.get('patient_data'):
318
+ # if "**Initial Patient Data:**" not in current_messages[1].content:
319
+ # print("Augmenting first HumanMessage with patient data summary.")
320
+ # formatted_data = format_patient_data_for_prompt(state['patient_data'])
321
+ # current_messages[1] = HumanMessage(content=f"{current_messages[1].content}\n\n**Initial Patient Data Summary:**\n{formatted_data}")
322
+
323
+ print(f"Invoking LLM with {len(current_messages)} messages.")
324
+ # print(f"Messages Sent: {[m.type for m in current_messages]}") # Log message types
325
+ try:
326
+ response = model_with_tools.invoke(current_messages)
327
+ print(f"Agent Raw Response Type: {type(response)}")
328
+ # print(f"Agent Raw Response Content: {response.content}")
329
+ if hasattr(response, 'tool_calls') and response.tool_calls:
330
+ print(f"Agent Response Tool Calls: {response.tool_calls}")
331
+ else:
332
+ print("Agent Response: No tool calls.")
333
+
334
+ except Exception as e:
335
+ print(f"ERROR in agent_node during LLM invocation: {type(e).__name__} - {e}")
336
+ traceback.print_exc() # Print full traceback for debugging
337
+ # Return an error message to the graph state
338
+ error_message = AIMessage(content=f"Sorry, an internal error occurred while processing the request: {type(e).__name__}")
339
+ return {"messages": [error_message]}
340
+
341
  return {"messages": [response]}
342
 
343
+ # 2. Tool Node: Executes tools called by the Agent (REVISED WITH ROBUST ERROR HANDLING)
344
  def tool_node(state: AgentState):
345
  """Executes tools called by the LLM and returns results."""
346
+ print("\n---TOOL NODE---")
347
+ tool_messages = [] # Initialize list to store results or errors
348
  last_message = state['messages'][-1]
349
+
350
+ # Ensure the last message is an AIMessage with tool calls
351
+ if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None):
352
+ print("Warning: Tool node called unexpectedly without tool calls in the last AI message.")
353
+ # If this happens, it might indicate a routing issue or the LLM hallucinating flow.
354
+ # Returning empty list lets the agent proceed, potentially without needed info.
355
+ # Consider adding a ToolMessage indicating the issue if needed.
356
+ return {"messages": []}
357
 
358
  tool_calls = last_message.tool_calls
359
+ print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}") # Log received calls
360
 
361
+ # Safety Check: Identify required interaction checks before prescriptions
362
+ prescriptions_requested = {} # medication_name_lower -> tool_call
363
+ interaction_checks_requested = {} # medication_name_lower -> tool_call
364
 
365
+ for call in tool_calls:
366
+ tool_name = call.get('name')
367
+ tool_args = call.get('args', {})
368
+ if tool_name == 'prescribe_medication':
369
+ med_name = tool_args.get('medication_name', '').lower()
370
+ if med_name:
371
+ prescriptions_requested[med_name] = call
372
+ elif tool_name == 'check_drug_interactions':
373
+ potential_med = tool_args.get('potential_prescription', '').lower()
374
+ if potential_med:
375
+ interaction_checks_requested[potential_med] = call
376
+
377
+ valid_tool_calls_for_execution = []
378
+
379
+ # Validate prescriptions against interaction checks
380
+ for med_name, prescribe_call in prescriptions_requested.items():
381
+ if med_name not in interaction_checks_requested:
382
+ st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked.")
383
  error_msg = ToolMessage(
384
+ content=json.dumps({"status": "error", "message": f"Interaction check for '{med_name}' must be requested *before or alongside* the prescription call."}),
385
+ tool_call_id=prescribe_call['id'],
386
+ name=prescribe_call['name'] # Include tool name in ToolMessage
387
  )
388
  tool_messages.append(error_msg)
389
+ else:
390
+ # Interaction check IS requested, allow prescription call to proceed
391
+ pass # The call will be added below if it's in the original tool_calls list
392
 
393
+ # Prepare list of calls to execute (all non-blocked calls)
394
+ blocked_ids = {msg.tool_call_id for msg in tool_messages if msg.content and '"status": "error"' in msg.content}
395
+ valid_tool_calls_for_execution = [call for call in tool_calls if call['id'] not in blocked_ids]
396
 
397
+ # Augment interaction checks with patient data from state
398
  patient_meds = state.get("patient_data", {}).get("medications", {}).get("names_only", [])
399
  patient_allergies = state.get("patient_data", {}).get("allergies", [])
400
+
401
+ for call in valid_tool_calls_for_execution:
402
  if call['name'] == 'check_drug_interactions':
403
+ # Ensure args exist before modifying
404
+ if 'args' not in call: call['args'] = {}
405
  call['args']['current_medications'] = patient_meds
406
  call['args']['allergies'] = patient_allergies
407
+ print(f"Augmented interaction check args for call ID {call['id']}: {call['args']}")
408
+
409
+
410
+ # Execute valid tool calls using batch for efficiency, capturing exceptions
411
+ if valid_tool_calls_for_execution:
412
+ print(f"Attempting to execute {len(valid_tool_calls_for_execution)} tools: {[c['name'] for c in valid_tool_calls_for_execution]}")
413
+ try:
414
+ responses = tool_executor.batch(valid_tool_calls_for_execution, return_exceptions=True)
415
+
416
+ # Process responses, creating ToolMessage for each
417
+ for call, resp in zip(valid_tool_calls_for_execution, responses):
418
+ tool_call_id = call['id']
419
+ tool_name = call['name']
420
+
421
+ if isinstance(resp, Exception):
422
+ # Handle exceptions returned by the batch call
423
+ error_type = type(resp).__name__
424
+ error_str = str(resp)
425
+ print(f"ERROR executing tool '{tool_name}' (ID: {tool_call_id}): {error_type} - {error_str}")
426
+ traceback.print_exc() # Log full traceback
427
+ st.error(f"Error executing action '{tool_name}': {error_type}")
428
+ error_content = json.dumps({
429
+ "status": "error",
430
+ "message": f"Failed to execute '{tool_name}': {error_type} - {error_str}"
431
+ })
432
+ tool_messages.append(ToolMessage(content=error_content, tool_call_id=tool_call_id, name=tool_name))
433
+
434
+ # Specific check for the error mentioned by user
435
+ if isinstance(resp, AttributeError) and "'dict' object has no attribute 'tool'" in error_str:
436
+ print("\n *** DETECTED SPECIFIC ATTRIBUTE ERROR ('dict' object has no attribute 'tool') ***")
437
+ print(f" Tool Call causing error: {json.dumps(call, indent=2)}")
438
+ print(" This likely indicates an internal issue within Langchain/LangGraph or ToolExecutor expecting a different object structure.")
439
+ print(" Ensure tool definitions (@tool decorators) and Pydantic schemas are correct.\n")
440
+
441
+ else:
442
+ # Process successful results
443
+ print(f"Tool '{tool_name}' (ID: {tool_call_id}) executed successfully. Result type: {type(resp)}")
444
+ # Ensure content is string for ToolMessage
445
+ content_str = str(resp)
446
+ tool_messages.append(ToolMessage(content=content_str, tool_call_id=tool_call_id, name=tool_name))
447
+
448
+ # Display result in Streamlit right away for feedback (optional, but helpful)
449
+ # This part might be better handled purely in the UI display loop later
450
+ # try:
451
+ # result_data = json.loads(content_str)
452
+ # status = result_data.get("status", "info")
453
+ # message = result_data.get("message", content_str)
454
+ # if status in ["success", "clear", "flagged"]: st.success(f"Action `{tool_name}` completed: {message}", icon="βœ…" if status != "flagged" else "🚨")
455
+ # elif status == "warning": st.warning(f"Action `{tool_name}` completed: {message}", icon="⚠️")
456
+ # else: st.info(f"Action `{tool_name}` completed: {message}") # Info for other statuses
457
+ # except json.JSONDecodeError:
458
+ # st.info(f"Action `{tool_name}` completed (non-JSON output).")
459
+
460
+ # Catch potential errors within the tool_node logic itself (e.g., preparing calls)
461
+ except Exception as e:
462
+ print(f"CRITICAL UNEXPECTED ERROR within tool_node logic: {type(e).__name__} - {e}")
463
+ traceback.print_exc()
464
+ st.error(f"Critical internal error processing actions: {e}")
465
+ # Create generic error messages for all calls that were intended
466
+ error_content = json.dumps({"status": "error", "message": f"Internal error processing tools: {e}"})
467
+ # Add error messages for calls that didn't get processed yet
468
+ processed_ids = {msg.tool_call_id for msg in tool_messages}
469
+ for call in valid_tool_calls_for_execution:
470
+ if call['id'] not in processed_ids:
471
+ tool_messages.append(ToolMessage(content=error_content, tool_call_id=call['id'], name=call['name']))
472
+
473
+ print(f"Returning {len(tool_messages)} tool messages.")
474
+ # print(f"Tool messages content snippets: {[m.content[:100] + '...' if len(m.content)>100 else m.content for m in tool_messages]}")
475
  return {"messages": tool_messages}
476
 
477
 
478
  # --- Graph Edges (Routing Logic) ---
479
  def should_continue(state: AgentState) -> str:
480
+ """Determines whether to call tools, end the conversation turn, or handle errors."""
481
+ print("\n---ROUTING DECISION---")
482
+ last_message = state['messages'][-1] if state['messages'] else None
483
+
484
+ if not isinstance(last_message, AIMessage):
485
+ # This case might happen if the graph starts with a non-AI message or after an error
486
+ print("Routing: Last message not AI. Ending turn.")
487
+ return "end_conversation_turn"
488
+
489
+ # If the LLM produced an error message (e.g., during invocation)
490
+ if "Sorry, an internal error occurred" in last_message.content:
491
+ print("Routing: AI returned internal error. Ending turn.")
492
+ return "end_conversation_turn"
493
+
494
+ # If the LLM made tool calls, execute them
495
+ if getattr(last_message, 'tool_calls', None):
496
+ print("Routing: AI requested tool calls. Continue to tools node.")
497
  return "continue_tools"
498
+ # Otherwise, the AI provided a response without tool calls, end the turn
499
  else:
500
+ print("Routing: AI provided final response or asked question. Ending turn.")
501
  return "end_conversation_turn"
502
 
503
  # --- Graph Definition ---
 
516
  should_continue, # Function to decide the route
517
  {
518
  "continue_tools": "tools", # If tool calls exist, go to tools node
519
+ "end_conversation_turn": END # Otherwise, end the graph iteration for this turn
520
  }
521
  )
522
 
 
527
  # memory = MemorySaverInMemory() # Optional: for persisting state across runs
528
  # app = workflow.compile(checkpointer=memory)
529
  app = workflow.compile()
530
+ print("LangGraph compiled successfully.")
531
 
532
+ # --- Streamlit UI ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  def main():
534
  st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
535
  st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
536
  st.caption(f"Interactive Assistant | Powered by Langchain/LangGraph & Groq ({ClinicalAppSettings.MODEL_NAME})")
537
 
538
+ # Initialize session state
539
  if "messages" not in st.session_state:
540
+ st.session_state.messages = [] # Stores full conversation history
541
  if "patient_data" not in st.session_state:
542
  st.session_state.patient_data = None
 
 
543
  if "graph_app" not in st.session_state:
544
+ st.session_state.graph_app = app
545
 
546
+ # --- Patient Data Input Sidebar ---
547
  with st.sidebar:
548
  st.header("πŸ“„ Patient Intake Form")
 
549
  # Demographics
550
+ st.subheader("Demographics")
551
  age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
552
  sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")
553
  # HPI
554
+ st.subheader("History of Present Illness (HPI)")
555
  chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
556
+ hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago, described as pressure, radiating to left arm. Associated with nausea and diaphoresis. Pain is 8/10 severity. No relief with rest.", key="hpi_input", height=150)
557
+ symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough", "Severe Headache", "Syncope", "Hemoptysis"], default=["Nausea", "Diaphoresis"], key="sym_input")
558
  # History
559
+ st.subheader("Past History")
560
+ pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2), History of MI", key="pmh_input")
561
  psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")
562
  # Meds & Allergies
563
+ st.subheader("Medications & Allergies")
564
  current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
565
+ allergies_str = st.text_area("Allergies (comma separated, specify reaction if known)", "Penicillin (rash), Sulfa (hives)", key="allergy_input")
566
  # Social/Family
567
+ st.subheader("Social/Family History")
568
  social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
569
  family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
570
  # Vitals/Exam
571
+ st.subheader("Vitals & Exam Findings")
572
  col1, col2 = st.columns(2)
573
  with col1:
574
  temp_c = st.number_input("Temp (Β°C)", 35.0, 42.0, 36.8, format="%.1f", key="temp_input")
 
578
  bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
579
  spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
580
  pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
581
+ exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3. Mild distress. Lungs clear bilaterally. Cardiac exam: Regular rhythm, S1/S2 normal, no murmurs/gallops/rubs. Abdomen soft, non-tender. No lower extremity edema.", key="exam_input", height=100)
582
 
583
  # Compile Patient Data Dictionary on button press
584
  if st.button("Start/Update Consultation", key="start_button"):
585
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
586
+ # Basic name extraction (first word, lowercased) for interaction check
587
  current_med_names = []
 
588
  for med in current_meds_list:
589
  match = re.match(r"^\s*([a-zA-Z\-]+)", med)
590
  if match:
591
+ current_med_names.append(match.group(1).lower())
592
+
593
+ # Basic allergy extraction (first word or phrase before parenthesis, lowercased)
594
+ allergies_list = []
595
+ for a in allergies_str.split(','):
596
+ cleaned_allergy = a.strip()
597
+ if cleaned_allergy:
598
+ match = re.match(r"^\s*([a-zA-Z\-\s]+)(?:\s*\(.*\))?", cleaned_allergy)
599
+ if match:
600
+ allergies_list.append(match.group(1).strip().lower())
601
+ else: # Fallback if no parenthesis
602
+ allergies_list.append(cleaned_allergy.lower())
603
 
604
  st.session_state.patient_data = {
605
  "demographics": {"age": age, "sex": sex},
 
614
 
615
  # Initial Red Flag Check (Client-side)
616
  red_flags = check_red_flags(st.session_state.patient_data)
617
+ st.sidebar.markdown("---")
618
  if red_flags:
619
+ st.sidebar.warning("**Initial Red Flags Detected:**")
620
+ for flag in red_flags: st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}")
621
+ else:
622
+ st.sidebar.success("No immediate red flags detected in initial data.")
623
 
624
  # Prepare initial message for the graph
625
+ initial_prompt = f"Initiate consultation for the patient described in the intake form. Start the analysis."
626
+ # Clear previous messages and start fresh
627
  st.session_state.messages = [HumanMessage(content=initial_prompt)]
 
628
  st.success("Patient data loaded. Ready for analysis.")
629
+ # No rerun needed here, chat input will trigger the graph
 
630
 
631
  # --- Main Chat Interface Area ---
632
  st.header("πŸ’¬ Clinical Consultation")
633
 
634
+ # Display chat messages from history
635
+ for msg_index, msg in enumerate(st.session_state.messages):
636
+ unique_key = f"msg_{msg_index}" # Basic unique key
637
  if isinstance(msg, HumanMessage):
638
+ with st.chat_message("user", key=f"{unique_key}_user"):
639
  st.markdown(msg.content)
640
  elif isinstance(msg, AIMessage):
641
+ with st.chat_message("assistant", key=f"{unique_key}_ai"):
642
+ # Display AI text content
643
+ ai_content = msg.content
644
  structured_output = None
645
+
646
+ # Attempt to parse structured JSON if present
647
  try:
648
+ # Look for ```json ... ``` block
649
+ json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
650
  if json_match:
651
+ json_str = json_match.group(1)
652
+ # Display content before/after the JSON block if any
653
+ prefix = ai_content[:json_match.start()].strip()
654
+ suffix = ai_content[json_match.end():].strip()
655
+ if prefix: st.markdown(prefix)
656
+ structured_output = json.loads(json_str)
657
+ if suffix: st.markdown(suffix)
658
+ # Check if the entire message might be JSON
659
+ elif ai_content.strip().startswith("{") and ai_content.strip().endswith("}"):
660
+ structured_output = json.loads(ai_content)
661
+ ai_content = "" # Don't display raw JSON if parsed ok
662
  else:
663
+ # No JSON found, display content as is
664
+ st.markdown(ai_content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
 
666
  except json.JSONDecodeError:
667
+ # Failed to parse, display raw content
668
+ st.markdown(ai_content)
669
+ st.warning("Note: Could not parse structured JSON in AI response.", icon="⚠️")
670
+ except Exception as e:
671
+ st.markdown(ai_content) # Display raw on other errors
672
+ st.error(f"Error processing AI message display: {e}", icon="❌")
673
+
674
+ # Display structured data nicely if parsed
675
+ if structured_output and isinstance(structured_output, dict):
676
+ st.divider()
677
+ st.subheader("πŸ“Š AI Analysis & Recommendations")
678
+ cols = st.columns(2)
679
+ with cols[0]:
680
+ st.markdown(f"**Assessment:**")
681
+ st.markdown(f"> {structured_output.get('assessment', 'N/A')}")
682
+
683
+ st.markdown(f"**Differential Diagnosis:**")
684
+ ddx = structured_output.get('differential_diagnosis', [])
685
+ if ddx:
686
+ for item in ddx:
687
+ likelihood = item.get('likelihood', 'Unknown').capitalize()
688
+ icon = "πŸ₯‡" if likelihood=="High" else ("πŸ₯ˆ" if likelihood=="Medium" else "πŸ₯‰")
689
+ with st.expander(f"{icon} {item.get('diagnosis', 'Unknown')} ({likelihood})"):
690
+ st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
691
+ else: st.info("No differential diagnosis provided.")
692
+
693
+ st.markdown(f"**Risk Assessment:**")
694
+ risk = structured_output.get('risk_assessment', {})
695
+ flags = risk.get('identified_red_flags', [])
696
+ if flags: st.warning(f"**Flags:** {', '.join(flags)}")
697
+ if risk.get("immediate_concerns"): st.warning(f"**Concerns:** {', '.join(risk.get('immediate_concerns'))}")
698
+ if risk.get("potential_complications"): st.info(f"**Potential Complications:** {', '.join(risk.get('potential_complications'))}")
699
+ if not flags and not risk.get("immediate_concerns"): st.success("No major risks highlighted in this assessment.")
700
+
701
+ with cols[1]:
702
+ st.markdown(f"**Recommended Plan:**")
703
+ plan = structured_output.get('recommended_plan', {})
704
+ sub_sections = ["investigations", "therapeutics", "consultations", "patient_education"]
705
+ for section in sub_sections:
706
+ st.markdown(f"_{section.replace('_',' ').capitalize()}:_")
707
+ items = plan.get(section)
708
+ if items and isinstance(items, list):
709
+ for item in items: st.markdown(f"- {item}")
710
+ elif items: # Handle if it's just a string
711
+ st.markdown(f"- {items}")
712
+ else: st.markdown("_None suggested._")
713
+ st.markdown("") # Add space
714
+
715
+ # Display Rationale and Interaction Summary below columns
716
+ st.markdown(f"**Rationale & Guideline Check:**")
717
+ st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
718
+ interaction_summary = structured_output.get("interaction_check_summary", "")
719
+ if interaction_summary:
720
+ st.markdown(f"**Interaction Check Summary:**")
721
+ st.markdown(f"> {interaction_summary}")
722
+
723
+ st.divider()
724
+
725
+ # Display tool calls requested in this AI turn
726
+ if getattr(msg, 'tool_calls', None):
727
  with st.expander("πŸ› οΈ AI requested actions", expanded=False):
728
  for tc in msg.tool_calls:
729
+ try:
730
+ # Safely display, default args to empty dict if missing
731
+ st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
732
+ except Exception as display_e:
733
+ st.error(f"Could not display tool call: {display_e}")
734
+ st.code(str(tc)) # Raw display as fallback
735
 
736
  elif isinstance(msg, ToolMessage):
737
+ # Safely get tool name
738
+ tool_name_display = getattr(msg, 'name', 'tool_execution') # Use 'name' attribute added in tool_node
739
+ with st.chat_message(tool_name_display, avatar="πŸ› οΈ", key=f"{unique_key}_tool"):
740
  try:
741
+ # Attempt to parse content as JSON for structured display
742
  tool_data = json.loads(msg.content)
743
  status = tool_data.get("status", "info")
744
  message = tool_data.get("message", msg.content)
 
746
  warnings = tool_data.get("warnings")
747
 
748
  if status == "success" or status == "clear" or status == "flagged":
749
+ st.success(f"{message}", icon="βœ…" if status != "flagged" else "🚨")
750
  elif status == "warning":
751
+ st.warning(f"{message}", icon="⚠️")
752
+ if warnings and isinstance(warnings, list):
753
+ st.caption("Details:")
754
  for warn in warnings: st.caption(f"- {warn}")
755
  else: # Error or unknown status
756
+ st.error(f"{message}", icon="❌")
757
 
758
  if details: st.caption(f"Details: {details}")
759
 
760
  except json.JSONDecodeError:
761
+ # If content is not JSON, display it plainly
762
+ st.info(f"{msg.content}")
763
+ except Exception as e:
764
+ st.error(f"Error displaying tool message: {e}", icon="❌")
765
+ st.caption(f"Raw content: {msg.content}")
766
 
767
+ # --- Chat Input Logic ---
768
  if prompt := st.chat_input("Your message or follow-up query..."):
769
  if not st.session_state.patient_data:
770
  st.warning("Please load patient data using the sidebar first.")
771
+ st.stop() # Prevent execution if no patient data
772
+
773
+ # Add user message to state and display it immediately
774
+ user_message = HumanMessage(content=prompt)
775
+ st.session_state.messages.append(user_message)
776
+ with st.chat_message("user"):
777
+ st.markdown(prompt)
778
+
779
+ # Prepare state for graph invocation
780
+ current_state = AgentState(
781
+ messages=st.session_state.messages,
782
+ patient_data=st.session_state.patient_data
783
+ )
784
+
785
+ # Invoke the graph
786
+ with st.spinner("SynapseAI is thinking..."):
787
+ try:
788
+ # Use invoke to run the graph until it ends for this turn
789
+ final_state = st.session_state.graph_app.invoke(
790
+ current_state,
791
+ {"recursion_limit": 15} # Add recursion limit for safety
792
+ )
793
+ # Update the session state messages with the final list from the graph run
794
+ st.session_state.messages = final_state['messages']
795
+
796
+ except Exception as e:
797
+ print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}")
798
+ traceback.print_exc()
799
+ st.error(f"An error occurred during the conversation turn: {e}", icon="❌")
800
+ # Attempt to add an error message to the history for visibility
801
+ error_ai_msg = AIMessage(content=f"Sorry, a critical error occurred: {type(e).__name__}. Please check logs or try again.")
802
+ # Avoid modifying state directly during exception handling if possible,
803
+ # but appending might be okay for display purposes.
804
+ # st.session_state.messages.append(error_ai_msg) # Be cautious with state modification here
805
+
806
+ # Rerun the script to display the updated chat history, including AI response and tool results
807
+ st.rerun()
808
+
809
+ # Disclaimer at the bottom
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
810
  st.markdown("---")
811
+ st.warning(
812
+ """**Disclaimer:** SynapseAI is an AI assistant for clinical decision support and does not replace professional medical judgment.
813
+ All outputs must be critically reviewed and verified by a qualified healthcare provider before making any clinical decisions.
814
+ Validate all information, especially diagnoses, dosages, and interactions, independently using standard medical resources."""
815
+ )
816
 
817
  if __name__ == "__main__":
818
  main()