mgbam commited on
Commit
4258926
·
verified ·
1 Parent(s): 896de2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +471 -436
app.py CHANGED
@@ -1,116 +1,117 @@
1
  import streamlit as st
2
  from langchain_groq import ChatGroq
3
  from langchain_community.tools.tavily_search import TavilySearchResults
4
- from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
5
  from langchain_core.prompts import ChatPromptTemplate
6
- from langchain_core.output_parsers import StrOutputParser
7
  from langchain_core.pydantic_v1 import BaseModel, Field
8
  from langchain_core.tools import tool
9
- from typing import Optional, List, Dict, Any
 
 
 
 
10
  import json
11
- import re # For parsing vitals like BP
 
12
 
13
- # --- Configuration & Constants ---
14
  class ClinicalAppSettings:
15
- APP_TITLE = "SynapseAI: Advanced Clinical Decision Support"
16
  PAGE_LAYOUT = "wide"
17
- MODEL_NAME = "llama3-70b-8192" # Use a powerful model like Groq's Llama3 70b
18
  TEMPERATURE = 0.1
19
  MAX_SEARCH_RESULTS = 3
20
 
21
  class ClinicalPrompts:
 
22
  SYSTEM_PROMPT = """
23
- You are SynapseAI, an expert AI clinical assistant designed to support healthcare professionals.
24
- Your primary function is to analyze patient data, provide differential diagnoses, suggest evidence-based management plans, and identify potential risks according to the latest medical guidelines and safety protocols.
25
 
26
- **Core Directives:**
27
- 1. **Comprehensive Analysis:** Thoroughly analyze ALL provided patient data (demographics, HPI, PMH, PSH, Allergies, Meds, SH, FH, ROS, Vitals, Exam).
28
- 2. **Structured Output:** ALWAYS format your response using the following JSON structure:
 
29
  ```json
30
  {
31
- "assessment": "Concise summary of the patient's presentation and key findings.",
32
  "differential_diagnosis": [
33
- {"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence..."},
34
  {"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
35
  {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
36
  ],
37
  "risk_assessment": {
38
- "identified_red_flags": ["List any triggered red flags based on input"],
39
- "immediate_concerns": ["Specific urgent issues requiring attention (e.g., sepsis risk, ACS rule-out)"],
40
- "potential_complications": ["Possible future issues based on presentation"]
41
  },
42
  "recommended_plan": {
43
- "investigations": ["List specific lab tests or imaging required. Use 'order_lab_test' tool."],
44
- "therapeutics": ["Suggest specific treatments or prescriptions. Use 'prescribe_medication' tool."],
45
- "consultations": ["Recommend specialist consultations if needed."],
46
  "patient_education": ["Key points for patient communication."]
47
  },
48
- "rationale_summary": "Brief justification for the overall assessment and plan, referencing guidelines or evidence where possible. Use 'tavily_search_results' tool if needed to find supporting evidence/guidelines.",
49
- "interaction_check_summary": "Summary of findings from the 'check_drug_interactions' tool IF a new medication was considered or prescribed."
50
  }
51
  ```
52
- 3. **Safety First - Red Flags:** Immediately identify and report any conditions matching the defined RED_FLAGS. Use the `flag_risk` tool if critical.
53
- 4. **Safety First - Drug Interactions:** BEFORE suggesting *any* new prescription, you MUST use the `check_drug_interactions` tool to verify against the patient's current medications and allergies. Mention the result in `interaction_check_summary`.
54
- 5. **Tool Utilization:** Employ the provided tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) precisely when indicated by your plan. Adhere strictly to tool schemas. Do NOT hallucinate tool usage results; wait for actual tool output if required in a multi-turn scenario (though this implementation focuses on single-turn analysis with tool calls).
55
- 6. **Evidence-Based:** Briefly cite reasoning, drawing on general medical knowledge. Use Tavily Search for specific guideline checks or novel information when necessary.
56
- 7. **Clarity and Conciseness:** Be clear, avoiding ambiguity. Use standard medical terminology.
57
  """
58
 
59
- # --- Mock Data / Helpers ---
60
- # (In a real system, this would be a proper API/database)
61
  MOCK_INTERACTION_DB = {
62
- ("Lisinopril", "Spironolactone"): "High risk of hyperkalemia. Monitor potassium closely.",
63
- ("Warfarin", "Amiodarone"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
64
- ("Simvastatin", "Clarithromycin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
65
- ("Aspirin", "Ibuprofen"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding."
66
  }
67
 
68
  ALLERGY_INTERACTIONS = {
69
- "Penicillin": ["Amoxicillin", "Ampicillin", "Piperacillin"],
70
- "Sulfa": ["Sulfamethoxazole", "Sulfasalazine"],
71
- "Aspirin": ["Ibuprofen", "Naproxen"] # Cross-reactivity example for NSAIDs
72
  }
73
 
74
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
75
- """Parses BP string like '120/80' into (systolic, diastolic) integers."""
76
  match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string)
77
- if match:
78
- return int(match.group(1)), int(match.group(2))
79
  return None
80
 
81
  def check_red_flags(patient_data: dict) -> List[str]:
82
- """Checks patient data against predefined red flags."""
83
  flags = []
84
  symptoms = patient_data.get("hpi", {}).get("symptoms", [])
85
  vitals = patient_data.get("vitals", {})
86
  history = patient_data.get("pmh", {}).get("conditions", "")
 
 
 
 
 
 
87
 
88
- # Symptom Flags
89
- if "chest pain" in [s.lower() for s in symptoms]: flags.append("Red Flag: Chest Pain reported.")
90
- if "shortness of breath" in [s.lower() for s in symptoms]: flags.append("Red Flag: Shortness of Breath reported.")
91
- if "severe headache" in [s.lower() for s in symptoms]: flags.append("Red Flag: Severe Headache reported.")
92
- if "sudden vision loss" in [s.lower() for s in symptoms]: flags.append("Red Flag: Sudden Vision Loss reported.")
93
- if "weakness on one side" in [s.lower() for s in symptoms]: flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
94
-
95
- # Vital Sign Flags (add more checks as needed)
96
- if "temp_c" in vitals and vitals["temp_c"] >= 38.5: flags.append(f"Red Flag: Fever (Temperature: {vitals['temp_c']}°C).")
97
- if "hr_bpm" in vitals and vitals["hr_bpm"] >= 120: flags.append(f"Red Flag: Tachycardia (Heart Rate: {vitals['hr_bpm']} bpm).")
98
- if "rr_rpm" in vitals and vitals["rr_rpm"] >= 24: flags.append(f"Red Flag: Tachypnea (Respiratory Rate: {vitals['rr_rpm']} rpm).")
99
- if "spo2_percent" in vitals and vitals["spo2_percent"] <= 92: flags.append(f"Red Flag: Hypoxia (SpO2: {vitals['spo2_percent']}%).")
100
  if "bp_mmhg" in vitals:
101
  bp = parse_bp(vitals["bp_mmhg"])
102
- if bp:
103
- if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {vitals['bp_mmhg']} mmHg).")
104
- if bp[0] <= 90 and bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {vitals['bp_mmhg']} mmHg).")
105
-
106
- # History Flags (Simple examples)
107
- if "history of mi" in history.lower() and "chest pain" in [s.lower() for s in symptoms]: flags.append("Red Flag: History of MI with current Chest Pain.")
108
 
 
 
109
  return flags
110
 
111
- # --- Enhanced Tool Definitions ---
112
 
113
- # Use Pydantic models for robust argument validation
 
 
 
 
114
  class LabOrderInput(BaseModel):
115
  test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis').")
116
  reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS').")
@@ -119,11 +120,7 @@ class LabOrderInput(BaseModel):
119
  @tool("order_lab_test", args_schema=LabOrderInput)
120
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
121
  """Orders a specific lab test with clinical justification and priority."""
122
- return json.dumps({
123
- "status": "success",
124
- "message": f"Lab Ordered: {test_name} ({priority})",
125
- "details": f"Reason: {reason}"
126
- })
127
 
128
  class PrescriptionInput(BaseModel):
129
  medication_name: str = Field(..., description="Name of the medication.")
@@ -136,12 +133,8 @@ class PrescriptionInput(BaseModel):
136
  @tool("prescribe_medication", args_schema=PrescriptionInput)
137
  def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
138
  """Prescribes a medication with detailed instructions and clinical indication."""
139
- # In a real scenario, this would trigger an e-prescription workflow
140
- return json.dumps({
141
- "status": "success",
142
- "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
143
- "details": f"Duration: {duration}. Reason: {reason}"
144
- })
145
 
146
  class InteractionCheckInput(BaseModel):
147
  potential_prescription: str = Field(..., description="The name of the NEW medication being considered.")
@@ -153,35 +146,34 @@ def check_drug_interactions(potential_prescription: str, current_medications: Li
153
  """Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing."""
154
  warnings = []
155
  potential_med_lower = potential_prescription.lower()
 
 
156
 
157
- # Check Allergies
158
- for allergy in allergies:
159
- allergy_lower = allergy.lower()
160
- # Simple direct check
161
- if allergy_lower == potential_med_lower:
162
  warnings.append(f"CRITICAL ALLERGY: Patient allergic to {allergy}. Cannot prescribe {potential_prescription}.")
163
  continue
164
- # Check cross-reactivity (using simplified mock data)
165
- if allergy_lower in ALLERGY_INTERACTIONS:
166
- for cross_reactant in ALLERGY_INTERACTIONS[allergy_lower]:
167
  if cross_reactant.lower() == potential_med_lower:
168
  warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to {allergy}. High risk with {potential_prescription}.")
169
 
170
- # Check Drug-Drug Interactions (using simplified mock data)
171
- current_meds_lower = [med.lower() for med in current_medications]
172
  for current_med in current_meds_lower:
173
- # Check pairs in both orders
174
  pair1 = (current_med, potential_med_lower)
175
  pair2 = (potential_med_lower, current_med)
 
 
 
 
176
  if pair1 in MOCK_INTERACTION_DB:
177
- warnings.append(f"Interaction Found: {potential_prescription} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair1]}")
178
  elif pair2 in MOCK_INTERACTION_DB:
179
- warnings.append(f"Interaction Found: {potential_prescription} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair2]}")
 
 
 
 
180
 
181
- if not warnings:
182
- return json.dumps({"status": "clear", "message": f"No major interactions identified for {potential_prescription} with current meds/allergies.", "warnings": []})
183
- else:
184
- return json.dumps({"status": "warning", "message": f"Potential interactions identified for {potential_prescription}.", "warnings": warnings})
185
 
186
  class FlagRiskInput(BaseModel):
187
  risk_description: str = Field(..., description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').")
@@ -190,372 +182,415 @@ class FlagRiskInput(BaseModel):
190
  @tool("flag_risk", args_schema=FlagRiskInput)
191
  def flag_risk(risk_description: str, urgency: str) -> str:
192
  """Flags a critical risk identified during analysis for immediate attention."""
193
- st.error(f"🚨 **{urgency.upper()} RISK FLAGGED:** {risk_description}", icon="🚨")
194
- return json.dumps({
195
- "status": "flagged",
196
- "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
197
- })
198
-
199
 
200
  # Initialize Search Tool
201
- search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS)
202
-
203
- # --- Core Agent Logic ---
204
- class ClinicalAgent:
205
- def __init__(self):
206
- self.model = ChatGroq(
207
- temperature=ClinicalAppSettings.TEMPERATURE,
208
- model=ClinicalAppSettings.MODEL_NAME
209
- )
210
- # Combine all tools
211
- self.tools = [
212
- order_lab_test,
213
- prescribe_medication,
214
- check_drug_interactions,
215
- flag_risk,
216
- search_tool
217
- ]
218
- # Bind tools to the model
219
- self.model_with_tools = self.model.bind_tools(self.tools)
220
- # History for context (simple implementation)
221
- self.history = []
222
-
223
- def _format_patient_data_for_prompt(self, data: dict) -> str:
224
- """Formats the patient dictionary into a readable string for the LLM."""
225
- prompt_str = "Patient Data:\n"
226
- for key, value in data.items():
227
- if isinstance(value, dict):
228
- prompt_str += f" {key.replace('_', ' ').title()}:\n"
229
- for sub_key, sub_value in value.items():
230
- if sub_value: # Only include if there's data
231
- prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
232
- elif isinstance(value, list) and value:
233
- prompt_str += f" {key.replace('_', ' ').title()}: {', '.join(map(str, value))}\n"
234
- elif value: # Only include non-empty fields
235
- prompt_str += f" {key.replace('_', ' ').title()}: {value}\n"
236
- return prompt_str.strip()
237
-
238
-
239
- def analyze(self, patient_data: dict) -> tuple[Optional[dict], List[dict]]:
240
- """Runs the analysis, handling tool calls and parsing the structured output."""
241
- try:
242
- # Add System Prompt and formatted Patient Data
243
- # Simple history management: add previous messages if any
244
- messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)]
245
- # Include history if needed - consider token limits
246
- # messages.extend(self.history)
247
- formatted_data = self._format_patient_data_for_prompt(patient_data)
248
- messages.append(HumanMessage(content=formatted_data))
249
-
250
- # Invoke the model
251
- ai_response = self.model_with_tools.invoke(messages)
252
-
253
- # Store conversation turn
254
- # self.history.append(HumanMessage(content=formatted_data))
255
- # self.history.append(ai_response) # AIMessage includes tool calls
256
-
257
- response_content = None
258
- tool_calls = []
259
-
260
- if isinstance(ai_response, AIMessage):
261
- # Check if the response contains the structured JSON output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  try:
263
- # Sometimes the JSON is embedded in the content, sometimes it's the primary content
264
- # Look for ```json ... ``` block first
265
- json_match = re.search(r"```json\n(\{.*?\})\n```", ai_response.content, re.DOTALL)
266
  if json_match:
267
- response_content = json.loads(json_match.group(1))
 
 
 
 
 
 
 
 
268
  else:
269
- # Try parsing the whole content as JSON
270
- response_content = json.loads(ai_response.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  except json.JSONDecodeError:
272
- st.warning("AI did not return valid JSON in the expected format. Displaying raw content.")
273
- st.code(ai_response.content, language=None) # Display raw if not JSON
274
- response_content = {"assessment": ai_response.content, "error": "Output format incorrect"}
275
 
276
- # Extract tool calls separately
277
- if ai_response.tool_calls:
278
- tool_calls = ai_response.tool_calls
279
 
280
- return response_content, tool_calls
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
- except Exception as e:
283
- st.error(f"Error during AI analysis: {str(e)}")
284
- return None, []
285
 
286
- def process_tool_call(self, tool_call: Dict[str, Any]) -> Any:
287
- """Executes a single tool call."""
288
- tool_name = tool_call.get("name")
289
- tool_args = tool_call.get("args", {})
290
- selected_tool = {t.name: t for t in self.tools}.get(tool_name)
 
291
 
292
- if not selected_tool:
293
- return json.dumps({"status": "error", "message": f"Unknown tool: {tool_name}"})
294
 
295
- try:
296
- # Ensure args are correctly passed (Pydantic models handle validation)
297
- return selected_tool.invoke(tool_args)
298
- except Exception as e:
299
- st.error(f"Error executing tool '{tool_name}': {str(e)}")
300
- return json.dumps({"status": "error", "message": f"Failed to execute {tool_name}: {str(e)}"})
301
 
302
- # --- Streamlit UI ---
303
- def main():
304
- st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
305
- st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
306
- st.caption(f"Powered by Langchain & Groq ({ClinicalAppSettings.MODEL_NAME})")
307
-
308
- # Initialize Agent in session state
309
- if 'agent' not in st.session_state:
310
- st.session_state.agent = ClinicalAgent()
311
- if 'analysis_complete' not in st.session_state:
312
- st.session_state.analysis_complete = False
313
- if 'analysis_result' not in st.session_state:
314
- st.session_state.analysis_result = None
315
- if 'tool_call_results' not in st.session_state:
316
- st.session_state.tool_call_results = []
317
- if 'red_flags' not in st.session_state:
318
- st.session_state.red_flags = []
319
-
320
- # --- Patient Data Input Sidebar ---
321
- with st.sidebar:
322
- st.header("📄 Patient Intake Form")
323
 
324
- # Demographics
325
- st.subheader("Demographics")
326
- age = st.number_input("Age", min_value=0, max_value=120, value=55)
327
- sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"])
328
-
329
- # History of Present Illness (HPI)
330
- st.subheader("History of Present Illness (HPI)")
331
- chief_complaint = st.text_input("Chief Complaint", "Chest pain")
332
- hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago, described as pressure, radiating to left arm. Associated with nausea and diaphoresis. Pain is 8/10 severity. No relief with rest.")
333
- symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough"], default=["Nausea", "Diaphoresis"])
334
-
335
- # Past Medical/Surgical History (PMH/PSH)
336
- st.subheader("Past History")
337
- pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2)")
338
- psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)")
339
-
340
- # Medications & Allergies
341
- st.subheader("Medications & Allergies")
342
- current_meds = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily")
343
- allergies = st.text_area("Allergies (comma separated)", "Penicillin (rash)")
344
-
345
- # Social & Family History (SH/FH)
346
- st.subheader("Social/Family History")
347
- social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.")
348
- family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.")
349
-
350
- # Review of Systems (ROS) - Simplified
351
- # st.subheader("Review of Systems (ROS)") # Keep UI cleaner for now
352
- # ros_constitutional = st.checkbox("ROS: Constitutional (Fever, Chills, Weight loss)")
353
- # ros_cardiac = st.checkbox("ROS: Cardiac (Chest pain, Palpitations)", value=True) # Pre-check based on HPI
354
-
355
- # Vitals & Basic Exam
356
- st.subheader("Vitals & Exam Findings")
357
- col1, col2 = st.columns(2)
358
- with col1:
359
- temp_c = st.number_input("Temperature (°C)", 35.0, 42.0, 36.8, format="%.1f")
360
- hr_bpm = st.number_input("Heart Rate (bpm)", 30, 250, 95)
361
- rr_rpm = st.number_input("Respiratory Rate (rpm)", 5, 50, 18)
362
- with col2:
363
- bp_mmhg = st.text_input("Blood Pressure (SYS/DIA)", "155/90")
364
- spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96)
365
- pain_scale = st.slider("Pain (0-10)", 0, 10, 8)
366
- exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3. Mild distress. Lungs clear. Cardiac exam: Regular rhythm, no murmurs/gallops. Abdomen soft. No edema.")
367
-
368
- # Clean medication list and allergies for processing
369
- current_meds_list = [med.strip() for med in current_meds.split('\n') if med.strip()]
370
- current_med_names = [med.split(' ')[0].strip() for med in current_meds_list] # Simplified name extraction
371
- allergies_list = [a.strip() for a in allergies.split(',') if a.strip()]
372
-
373
- # Compile Patient Data Dictionary
374
- patient_data = {
375
- "demographics": {"age": age, "sex": sex},
376
- "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
377
- "pmh": {"conditions": pmh},
378
- "psh": {"procedures": psh},
379
- "medications": {"current": current_meds_list, "names_only": current_med_names},
380
- "allergies": allergies_list,
381
- "social_history": {"details": social_history},
382
- "family_history": {"details": family_history},
383
- # "ros": {"constitutional": ros_constitutional, "cardiac": ros_cardiac}, # Add if using ROS inputs
384
- "vitals": {
385
- "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg,
386
- "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale
387
- },
388
- "exam_findings": {"notes": exam_notes}
389
- }
390
 
391
- # --- Main Analysis Area ---
392
- st.header("🤖 AI Clinical Analysis")
393
-
394
- # Action Button
395
- if st.button("Analyze Patient Data", type="primary", use_container_width=True):
396
- st.session_state.analysis_complete = False
397
- st.session_state.analysis_result = None
398
- st.session_state.tool_call_results = []
399
- st.session_state.red_flags = []
400
-
401
- # 1. Initial Red Flag Check (Client-side before LLM)
402
- st.session_state.red_flags = check_red_flags(patient_data)
403
- if st.session_state.red_flags:
404
- st.warning("**Initial Red Flags Detected:**")
405
- for flag in st.session_state.red_flags:
406
- st.warning(f"- {flag}")
407
- st.warning("Proceeding with AI analysis, but these require immediate attention.")
408
-
409
- # 2. Call AI Agent
410
- with st.spinner("SynapseAI is processing the case... Please wait."):
411
- analysis_output, tool_calls = st.session_state.agent.analyze(patient_data)
412
-
413
- if analysis_output:
414
- st.session_state.analysis_result = analysis_output
415
- st.session_state.analysis_complete = True
416
-
417
- # 3. Process any Tool Calls requested by the AI
418
- if tool_calls:
419
- st.info(f"AI recommended {len(tool_calls)} action(s). Executing...")
420
- tool_results = []
421
- with st.spinner("Executing recommended actions..."):
422
- for call in tool_calls:
423
- st.write(f"⚙️ Requesting: `{call['name']}` with args `{call['args']}`")
424
- # Pass patient context if needed (e.g., for interaction check)
425
- if call['name'] == 'check_drug_interactions':
426
- call['args']['current_medications'] = patient_data['medications']['names_only']
427
- call['args']['allergies'] = patient_data['allergies']
428
- elif call['name'] == 'prescribe_medication':
429
- # Pre-flight check: Ensure interaction check was requested *before* this prescribe call
430
- interaction_check_requested = any(tc['name'] == 'check_drug_interactions' and tc['args'].get('potential_prescription') == call['args'].get('medication_name') for tc in tool_calls)
431
- if not interaction_check_requested:
432
- st.error(f"**Safety Violation:** AI attempted to prescribe '{call['args'].get('medication_name')}' without requesting `check_drug_interactions` first. Prescription blocked.")
433
- tool_results.append({"tool_call_id": call['id'], "name": call['name'], "output": json.dumps({"status":"error", "message": "Interaction check not performed prior to prescription attempt."})})
434
- continue # Skip this tool call
435
-
436
- result = st.session_state.agent.process_tool_call(call)
437
- tool_results.append({"tool_call_id": call['id'], "name": call['name'], "output": result}) # Store result with ID
438
-
439
- # Display tool result immediately
440
- try:
441
- result_data = json.loads(result)
442
- if result_data.get("status") == "success" or result_data.get("status") == "clear" or result_data.get("status") == "flagged":
443
- st.success(f"✅ Action `{call['name']}`: {result_data.get('message')}", icon="✅")
444
- if result_data.get("details"): st.caption(f"Details: {result_data.get('details')}")
445
- elif result_data.get("status") == "warning":
446
- st.warning(f"⚠️ Action `{call['name']}`: {result_data.get('message')}", icon="⚠️")
447
- if result_data.get("warnings"):
448
- for warn in result_data["warnings"]: st.caption(f"- {warn}")
449
- else:
450
- st.error(f"❌ Action `{call['name']}`: {result_data.get('message')}", icon="❌")
451
- except json.JSONDecodeError:
452
- st.error(f"Tool `{call['name']}` returned non-JSON: {result}") # Fallback for non-JSON results
453
-
454
- st.session_state.tool_call_results = tool_results
455
- # Optionally: Send results back to LLM for final summary (requires multi-turn agent)
456
- else:
457
- st.error("Analysis failed. Please check the input data or try again.")
458
-
459
- # --- Display Analysis Results ---
460
- if st.session_state.analysis_complete and st.session_state.analysis_result:
461
- st.divider()
462
- st.header("📊 Analysis & Recommendations")
463
-
464
- res = st.session_state.analysis_result
465
-
466
- # Layout columns for better readability
467
- col_assessment, col_plan = st.columns(2)
468
-
469
- with col_assessment:
470
- st.subheader("📋 Assessment")
471
- st.write(res.get("assessment", "N/A"))
472
-
473
- st.subheader("🤔 Differential Diagnosis")
474
- ddx = res.get("differential_diagnosis", [])
475
- if ddx:
476
- for item in ddx:
477
- likelihood = item.get('likelihood', 'Unknown').capitalize()
478
- icon = "🥇" if likelihood=="High" else ("🥈" if likelihood=="Medium" else "🥉")
479
- with st.expander(f"{icon} {item.get('diagnosis', 'Unknown Diagnosis')} ({likelihood} Likelihood)", expanded=(likelihood=="High")):
480
- st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
481
- else:
482
- st.info("No differential diagnosis provided.")
483
-
484
- st.subheader("🚨 Risk Assessment")
485
- risk = res.get("risk_assessment", {})
486
- flags = risk.get("identified_red_flags", []) + [f.replace("Red Flag: ", "") for f in st.session_state.red_flags] # Combine AI and initial flags
487
- if flags:
488
- st.warning(f"**Identified Red Flags:** {', '.join(flags)}")
489
- else:
490
- st.success("No immediate red flags identified by AI in this analysis.")
491
-
492
- if risk.get("immediate_concerns"):
493
- st.warning(f"**Immediate Concerns:** {', '.join(risk.get('immediate_concerns'))}")
494
- if risk.get("potential_complications"):
495
- st.info(f"**Potential Complications:** {', '.join(risk.get('potential_complications'))}")
496
-
497
-
498
- with col_plan:
499
- st.subheader("📝 Recommended Plan")
500
- plan = res.get("recommended_plan", {})
501
-
502
- st.markdown("**Investigations:**")
503
- if plan.get("investigations"):
504
- st.markdown("\n".join([f"- {inv}" for inv in plan.get("investigations")]))
505
- else: st.markdown("_None suggested._")
506
-
507
- st.markdown("**Therapeutics:**")
508
- if plan.get("therapeutics"):
509
- st.markdown("\n".join([f"- {thx}" for thx in plan.get("therapeutics")]))
510
- else: st.markdown("_None suggested._")
511
-
512
- st.markdown("**Consultations:**")
513
- if plan.get("consultations"):
514
- st.markdown("\n".join([f"- {con}" for con in plan.get("consultations")]))
515
- else: st.markdown("_None suggested._")
516
-
517
- st.markdown("**Patient Education:**")
518
- if plan.get("patient_education"):
519
- st.markdown("\n".join([f"- {edu}" for edu in plan.get("patient_education")]))
520
- else: st.markdown("_None specified._")
521
-
522
- # Display Rationale and Interaction Summary below the columns
523
- st.subheader("🧠 AI Rationale & Checks")
524
- with st.expander("Show AI Reasoning Summary", expanded=False):
525
- st.write(res.get("rationale_summary", "No rationale summary provided."))
526
-
527
- interaction_summary = res.get("interaction_check_summary", "")
528
- if interaction_summary: # Only show if interaction check was relevant/performed
529
- with st.expander("Drug Interaction Check Summary", expanded=True):
530
- st.write(interaction_summary)
531
- # Also show detailed results from the tool call itself if available
532
- for tool_res in st.session_state.tool_call_results:
533
- if tool_res['name'] == 'check_drug_interactions':
534
- try:
535
- data = json.loads(tool_res['output'])
536
- if data.get('warnings'):
537
- st.warning("Interaction Details:")
538
- for warn in data['warnings']:
539
- st.caption(f"- {warn}")
540
- else:
541
- st.success("Interaction Details: " + data.get('message', 'Check complete.'))
542
- except: pass # Ignore parsing errors here
543
-
544
- # Display raw JSON if needed for debugging
545
- with st.expander("Show Raw AI Output (JSON)"):
546
- st.json(res)
547
-
548
- st.divider()
549
- st.success("Analysis Complete.")
550
 
551
  # Disclaimer
552
  st.markdown("---")
553
- st.warning(
554
- """**Disclaimer:** SynapseAI is an AI assistant for clinical decision support and does not replace professional medical judgment.
555
- All outputs should be critically reviewed by a qualified healthcare provider before making any clinical decisions.
556
- Verify all information, especially dosages and interactions, independently."""
557
- )
558
-
559
 
560
  if __name__ == "__main__":
561
  main()
 
1
  import streamlit as st
2
  from langchain_groq import ChatGroq
3
  from langchain_community.tools.tavily_search import TavilySearchResults
4
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
5
  from langchain_core.prompts import ChatPromptTemplate
 
6
  from langchain_core.pydantic_v1 import BaseModel, Field
7
  from langchain_core.tools import tool
8
+ from langgraph.prebuilt import ToolExecutor
9
+ from langgraph.graph import StateGraph, END
10
+ from langgraph.checkpoint.memory import MemorySaver # For state persistence (optional but good)
11
+
12
+ from typing import Optional, List, Dict, Any, TypedDict, Annotated
13
  import json
14
+ import re
15
+ import operator
16
 
17
+ # --- Configuration & Constants --- (Keep previous ones like ClinicalAppSettings)
18
  class ClinicalAppSettings:
19
+ APP_TITLE = "SynapseAI: Interactive Clinical Decision Support"
20
  PAGE_LAYOUT = "wide"
21
+ MODEL_NAME = "llama3-70b-8192"
22
  TEMPERATURE = 0.1
23
  MAX_SEARCH_RESULTS = 3
24
 
25
  class ClinicalPrompts:
26
+ # UPDATED SYSTEM PROMPT FOR CONVERSATIONAL FLOW & GUIDELINES
27
  SYSTEM_PROMPT = """
28
+ You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
29
+ Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.
30
 
31
+ **Core Directives for this Conversation:**
32
+ 1. **Analyze Sequentially:** Process information turn-by-turn. You will receive initial patient data, and potentially follow-up messages or results from tools you requested. Base your responses on the *entire* conversation history.
33
+ 2. **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
34
+ 3. **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions), provide a comprehensive assessment using the following JSON structure. Only output this structure when you believe you have a complete initial analysis or plan. Do NOT output incomplete JSON.
35
  ```json
36
  {
37
+ "assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
38
  "differential_diagnosis": [
39
+ {"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence from conversation..."},
40
  {"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
41
  {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
42
  ],
43
  "risk_assessment": {
44
+ "identified_red_flags": ["List any triggered red flags"],
45
+ "immediate_concerns": ["Specific urgent issues (e.g., sepsis risk, ACS rule-out)"],
46
+ "potential_complications": ["Possible future issues"]
47
  },
48
  "recommended_plan": {
49
+ "investigations": ["List specific lab tests or imaging needed. Use 'order_lab_test' tool."],
50
+ "therapeutics": ["Suggest specific treatments/prescriptions. Use 'prescribe_medication' tool. MUST check interactions first."],
51
+ "consultations": ["Recommend specialist consultations."],
52
  "patient_education": ["Key points for patient communication."]
53
  },
54
+ "rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.**",
55
+ "interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
56
  }
57
  ```
58
+ 4. **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions`. Report the findings. If interactions exist, modify the plan or state the contraindication.
59
+ 5. **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point.
60
+ 6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription).
61
+ 7. **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
62
+ 8. **Conciseness:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation until ready for the full structured JSON output.
63
  """
64
 
65
+ # --- Mock Data / Helpers --- (Keep previous ones like MOCK_INTERACTION_DB, ALLERGY_INTERACTIONS, parse_bp, check_red_flags)
66
+ # (Include the helper functions from the previous response here)
67
  MOCK_INTERACTION_DB = {
68
+ ("lisinopril", "spironolactone"): "High risk of hyperkalemia. Monitor potassium closely.",
69
+ ("warfarin", "amiodarone"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
70
+ ("simvastatin", "clarithromycin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
71
+ ("aspirin", "ibuprofen"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding."
72
  }
73
 
74
  ALLERGY_INTERACTIONS = {
75
+ "penicillin": ["amoxicillin", "ampicillin", "piperacillin"],
76
+ "sulfa": ["sulfamethoxazole", "sulfasalazine"],
77
+ "aspirin": ["ibuprofen", "naproxen"] # Cross-reactivity example for NSAIDs
78
  }
79
 
80
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
 
81
  match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string)
82
+ if match: return int(match.group(1)), int(match.group(2))
 
83
  return None
84
 
85
  def check_red_flags(patient_data: dict) -> List[str]:
 
86
  flags = []
87
  symptoms = patient_data.get("hpi", {}).get("symptoms", [])
88
  vitals = patient_data.get("vitals", {})
89
  history = patient_data.get("pmh", {}).get("conditions", "")
90
+ symptoms_lower = [s.lower() for s in symptoms]
91
+
92
+ if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
93
+ if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
94
+ if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
95
+ # Add other symptom checks...
96
 
97
+ if "temp_c" in vitals and vitals["temp_c"] >= 38.5: flags.append(f"Red Flag: Fever ({vitals['temp_c']}°C).")
98
+ if "hr_bpm" in vitals and vitals["hr_bpm"] >= 120: flags.append(f"Red Flag: Tachycardia ({vitals['hr_bpm']} bpm).")
 
 
 
 
 
 
 
 
 
 
99
  if "bp_mmhg" in vitals:
100
  bp = parse_bp(vitals["bp_mmhg"])
101
+ if bp and (bp[0] >= 180 or bp[1] >= 110): flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {vitals['bp_mmhg']} mmHg).")
102
+ if bp and (bp[0] <= 90 or bp[1] <= 60): flags.append(f"Red Flag: Hypotension (BP: {vitals['bp_mmhg']} mmHg).")
103
+ # Add other vital checks...
 
 
 
104
 
105
+ if "history of mi" in history.lower() and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
106
+ # Add other history checks...
107
  return flags
108
 
 
109
 
110
+ # --- Enhanced Tool Definitions --- (Keep previous Pydantic models and @tool functions)
111
+ # (Include LabOrderInput, PrescriptionInput, InteractionCheckInput, FlagRiskInput
112
+ # and the corresponding @tool functions: order_lab_test, prescribe_medication,
113
+ # check_drug_interactions, flag_risk from the previous response here)
114
+
115
  class LabOrderInput(BaseModel):
116
  test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis').")
117
  reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS').")
 
120
  @tool("order_lab_test", args_schema=LabOrderInput)
121
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
122
  """Orders a specific lab test with clinical justification and priority."""
123
+ return json.dumps({"status": "success", "message": f"Lab Ordered: {test_name} ({priority})", "details": f"Reason: {reason}"})
 
 
 
 
124
 
125
  class PrescriptionInput(BaseModel):
126
  medication_name: str = Field(..., description="Name of the medication.")
 
133
  @tool("prescribe_medication", args_schema=PrescriptionInput)
134
  def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
135
  """Prescribes a medication with detailed instructions and clinical indication."""
136
+ # NOTE: Interaction check should have been done *before* calling this via a separate tool call
137
+ return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
 
 
 
 
138
 
139
  class InteractionCheckInput(BaseModel):
140
  potential_prescription: str = Field(..., description="The name of the NEW medication being considered.")
 
146
  """Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing."""
147
  warnings = []
148
  potential_med_lower = potential_prescription.lower()
149
+ current_meds_lower = [med.lower() for med in current_medications]
150
+ allergies_lower = [a.lower() for a in allergies]
151
 
152
+ for allergy in allergies_lower:
153
+ if allergy == potential_med_lower:
 
 
 
154
  warnings.append(f"CRITICAL ALLERGY: Patient allergic to {allergy}. Cannot prescribe {potential_prescription}.")
155
  continue
156
+ if allergy in ALLERGY_INTERACTIONS:
157
+ for cross_reactant in ALLERGY_INTERACTIONS[allergy]:
 
158
  if cross_reactant.lower() == potential_med_lower:
159
  warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to {allergy}. High risk with {potential_prescription}.")
160
 
 
 
161
  for current_med in current_meds_lower:
 
162
  pair1 = (current_med, potential_med_lower)
163
  pair2 = (potential_med_lower, current_med)
164
+ # Normalize keys for lookup if necessary (e.g., if DB keys are canonical names)
165
+ key1 = tuple(sorted(pair1))
166
+ key2 = tuple(sorted(pair2)) # Although redundant if always sorted
167
+
168
  if pair1 in MOCK_INTERACTION_DB:
169
+ warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair1]}")
170
  elif pair2 in MOCK_INTERACTION_DB:
171
+ warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {MOCK_INTERACTION_DB[pair2]}")
172
+
173
+ status = "warning" if warnings else "clear"
174
+ message = f"Interaction check for {potential_prescription}: {len(warnings)} potential issue(s) found." if warnings else f"No major interactions identified for {potential_prescription}."
175
+ return json.dumps({"status": status, "message": message, "warnings": warnings})
176
 
 
 
 
 
177
 
178
  class FlagRiskInput(BaseModel):
179
  risk_description: str = Field(..., description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').")
 
182
  @tool("flag_risk", args_schema=FlagRiskInput)
183
  def flag_risk(risk_description: str, urgency: str) -> str:
184
  """Flags a critical risk identified during analysis for immediate attention."""
185
+ # Display in Streamlit immediately
186
+ st.error(f"🚨 **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="🚨")
187
+ return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
 
 
 
188
 
189
  # Initialize Search Tool
190
+ search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS, name="tavily_search_results")
191
+
192
+
193
+ # --- LangGraph Setup ---
194
+
195
+ # Define the state structure
196
+ class AgentState(TypedDict):
197
+ messages: Annotated[list[Any], operator.add] # Accumulates messages (Human, AI, Tool)
198
+ patient_data: Optional[dict] # Holds the structured patient data (can be updated if needed)
199
+ # Potentially add other state elements like 'interaction_check_needed_for': Optional[str]
200
+
201
+ # Define Tools and Tool Executor
202
+ tools = [
203
+ order_lab_test,
204
+ prescribe_medication,
205
+ check_drug_interactions,
206
+ flag_risk,
207
+ search_tool
208
+ ]
209
+ tool_executor = ToolExecutor(tools)
210
+
211
+ # Define the Agent Model
212
+ model = ChatGroq(
213
+ temperature=ClinicalAppSettings.TEMPERATURE,
214
+ model=ClinicalAppSettings.MODEL_NAME
215
+ )
216
+ model_with_tools = model.bind_tools(tools) # Bind tools for the LLM to know about them
217
+
218
+ # --- Graph Nodes ---
219
+
220
+ # 1. Agent Node: Calls the LLM
221
+ def agent_node(state: AgentState):
222
+ """Invokes the LLM to decide the next action or response."""
223
+ print("---AGENT NODE---")
224
+ # Make sure patient data is included in the first message if not already there
225
+ # This is a basic way; more robust would be merging patient_data into context
226
+ current_messages = state['messages']
227
+ if len(current_messages) == 1 and isinstance(current_messages[0], HumanMessage) and state.get('patient_data'):
228
+ # Augment the first human message with formatted patient data
229
+ formatted_data = format_patient_data_for_prompt(state['patient_data']) # Need this helper function
230
+ current_messages = [
231
+ SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT), # Ensure system prompt is first
232
+ HumanMessage(content=f"{current_messages[0].content}\n\n**Initial Patient Data:**\n{formatted_data}")
233
+ ]
234
+ elif not any(isinstance(m, SystemMessage) for m in current_messages):
235
+ # Add system prompt if missing
236
+ current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
237
+
238
+
239
+ response = model_with_tools.invoke(current_messages)
240
+ print(f"Agent response: {response}")
241
+ return {"messages": [response]}
242
+
243
+ # 2. Tool Node: Executes tools called by the Agent
244
+ def tool_node(state: AgentState):
245
+ """Executes tools called by the LLM and returns results."""
246
+ print("---TOOL NODE---")
247
+ last_message = state['messages'][-1]
248
+ if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
249
+ print("No tool calls in last message.")
250
+ return {} # Should not happen if routing is correct, but safety check
251
+
252
+ tool_calls = last_message.tool_calls
253
+ tool_messages = []
254
+
255
+ # Safety Check: Ensure interaction check happens *before* prescribing the *same* drug
256
+ prescribe_calls = {call['args'].get('medication_name'): call['id'] for call in tool_calls if call['name'] == 'prescribe_medication'}
257
+ interaction_check_calls = {call['args'].get('potential_prescription'): call['id'] for call in tool_calls if call['name'] == 'check_drug_interactions'}
258
+
259
+ for med_name, prescribe_call_id in prescribe_calls.items():
260
+ if med_name not in interaction_check_calls:
261
+ st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked for this turn.")
262
+ # Create an error ToolMessage to send back to the LLM
263
+ error_msg = ToolMessage(
264
+ content=json.dumps({"status": "error", "message": f"Interaction check for {med_name} must be requested *before or alongside* the prescription call."}),
265
+ tool_call_id=prescribe_call_id
266
+ )
267
+ tool_messages.append(error_msg)
268
+ # Remove the invalid prescribe call to prevent execution
269
+ tool_calls = [call for call in tool_calls if call['id'] != prescribe_call_id]
270
+
271
+
272
+ # Add patient context to interaction checks if needed
273
+ patient_meds = state.get("patient_data", {}).get("medications", {}).get("names_only", [])
274
+ patient_allergies = state.get("patient_data", {}).get("allergies", [])
275
+ for call in tool_calls:
276
+ if call['name'] == 'check_drug_interactions':
277
+ call['args']['current_medications'] = patient_meds
278
+ call['args']['allergies'] = patient_allergies
279
+ print(f"Augmented interaction check args: {call['args']}")
280
+
281
+
282
+ # Execute remaining valid tool calls
283
+ if tool_calls:
284
+ responses = tool_executor.batch(tool_calls)
285
+ # Responses is a list of tool outputs corresponding to tool_calls
286
+ # We need to create ToolMessage objects
287
+ tool_messages.extend([
288
+ ToolMessage(content=str(resp), tool_call_id=call['id'])
289
+ for call, resp in zip(tool_calls, responses)
290
+ ])
291
+ print(f"Tool results: {tool_messages}")
292
+
293
+ return {"messages": tool_messages}
294
+
295
+
296
+ # --- Graph Edges (Routing Logic) ---
297
+ def should_continue(state: AgentState) -> str:
298
+ """Determines whether to continue the loop or end."""
299
+ last_message = state['messages'][-1]
300
+ # If the LLM made tool calls, we execute them
301
+ if isinstance(last_message, AIMessage) and last_message.tool_calls:
302
+ print("Routing: continue_tools")
303
+ return "continue_tools"
304
+ # Otherwise, we end the loop (AI provided a direct answer or finished)
305
+ else:
306
+ print("Routing: end_conversation_turn")
307
+ return "end_conversation_turn"
308
+
309
+ # --- Graph Definition ---
310
+ workflow = StateGraph(AgentState)
311
+
312
+ # Add nodes
313
+ workflow.add_node("agent", agent_node)
314
+ workflow.add_node("tools", tool_node)
315
+
316
+ # Define entry point
317
+ workflow.set_entry_point("agent")
318
+
319
+ # Add conditional edges
320
+ workflow.add_conditional_edges(
321
+ "agent", # Source node
322
+ should_continue, # Function to decide the route
323
+ {
324
+ "continue_tools": "tools", # If tool calls exist, go to tools node
325
+ "end_conversation_turn": END # Otherwise, end the graph iteration
326
+ }
327
+ )
328
+
329
+ # Add edge from tools back to agent
330
+ workflow.add_edge("tools", "agent")
331
+
332
+ # Compile the graph
333
+ # memory = MemorySaverInMemory() # Optional: for persisting state across runs
334
+ # app = workflow.compile(checkpointer=memory)
335
+ app = workflow.compile()
336
+
337
+ # --- Helper Function to Format Patient Data ---
338
+ def format_patient_data_for_prompt(data: dict) -> str:
339
+ """Formats the patient dictionary into a readable string for the LLM."""
340
+ prompt_str = ""
341
+ for key, value in data.items():
342
+ if isinstance(value, dict):
343
+ section_title = key.replace('_', ' ').title()
344
+ prompt_str += f"**{section_title}:**\n"
345
+ for sub_key, sub_value in value.items():
346
+ if sub_value:
347
+ prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
348
+ elif isinstance(value, list) and value:
349
+ prompt_str += f"**{key.replace('_', ' ').title()}:** {', '.join(map(str, value))}\n"
350
+ elif value:
351
+ prompt_str += f"**{key.replace('_', ' ').title()}:** {value}\n"
352
+ return prompt_str.strip()
353
+
354
+ # --- Streamlit UI (Modified for Conversation) ---
355
+ def main():
356
+ st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
357
+ st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
358
+ st.caption(f"Interactive Assistant | Powered by Langchain/LangGraph & Groq ({ClinicalAppSettings.MODEL_NAME})")
359
+
360
+ # Initialize session state for conversation
361
+ if "messages" not in st.session_state:
362
+ st.session_state.messages = [] # Store entire conversation history (Human, AI, Tool)
363
+ if "patient_data" not in st.session_state:
364
+ st.session_state.patient_data = None
365
+ if "initial_analysis_done" not in st.session_state:
366
+ st.session_state.initial_analysis_done = False
367
+ if "graph_app" not in st.session_state:
368
+ st.session_state.graph_app = app # Store compiled graph
369
+
370
+ # --- Patient Data Input Sidebar --- (Similar to before)
371
+ with st.sidebar:
372
+ st.header("📄 Patient Intake Form")
373
+ # ... (Keep the input fields exactly as in the previous example) ...
374
+ # Demographics
375
+ age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
376
+ sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")
377
+ # HPI
378
+ chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
379
+ hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago...", key="hpi_input")
380
+ symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough"], default=["Nausea", "Diaphoresis"], key="sym_input")
381
+ # History
382
+ pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2)", key="pmh_input")
383
+ psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")
384
+ # Meds & Allergies
385
+ current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
386
+ allergies_str = st.text_area("Allergies (comma separated)", "Penicillin (rash)", key="allergy_input")
387
+ # Social/Family
388
+ social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
389
+ family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
390
+ # Vitals/Exam
391
+ col1, col2 = st.columns(2)
392
+ with col1:
393
+ temp_c = st.number_input("Temp (°C)", 35.0, 42.0, 36.8, format="%.1f", key="temp_input")
394
+ hr_bpm = st.number_input("HR (bpm)", 30, 250, 95, key="hr_input")
395
+ rr_rpm = st.number_input("RR (rpm)", 5, 50, 18, key="rr_input")
396
+ with col2:
397
+ bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
398
+ spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
399
+ pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
400
+ exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3...", key="exam_input")
401
+
402
+ # Compile Patient Data Dictionary on button press
403
+ if st.button("Start/Update Consultation", key="start_button"):
404
+ current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
405
+ current_med_names = []
406
+ # Improved parsing for names (still basic, assumes name is first word)
407
+ for med in current_meds_list:
408
+ match = re.match(r"^\s*([a-zA-Z\-]+)", med)
409
+ if match:
410
+ current_med_names.append(match.group(1).lower()) # Use lower case for matching
411
+
412
+ allergies_list = [a.strip().lower() for a in allergies_str.split(',') if a.strip()] # Lowercase allergies
413
+
414
+ st.session_state.patient_data = {
415
+ "demographics": {"age": age, "sex": sex},
416
+ "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
417
+ "pmh": {"conditions": pmh}, "psh": {"procedures": psh},
418
+ "medications": {"current": current_meds_list, "names_only": current_med_names},
419
+ "allergies": allergies_list,
420
+ "social_history": {"details": social_history}, "family_history": {"details": family_history},
421
+ "vitals": { "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale},
422
+ "exam_findings": {"notes": exam_notes}
423
+ }
424
+
425
+ # Initial Red Flag Check (Client-side)
426
+ red_flags = check_red_flags(st.session_state.patient_data)
427
+ if red_flags:
428
+ st.warning("**Initial Red Flags Detected:**")
429
+ for flag in red_flags: st.warning(f"- {flag}")
430
+
431
+ # Prepare initial message for the graph
432
+ initial_prompt = f"Analyze the following patient case:\nChief Complaint: {chief_complaint}\nSummary: {age} y/o {sex} presenting with..." # Keep it brief, full data is in state
433
+ st.session_state.messages = [HumanMessage(content=initial_prompt)]
434
+ st.session_state.initial_analysis_done = False # Reset analysis state
435
+ st.success("Patient data loaded. Ready for analysis.")
436
+ st.rerun() # Refresh main area to show chat
437
+
438
+
439
+ # --- Main Chat Interface Area ---
440
+ st.header("💬 Clinical Consultation")
441
+
442
+ # Display chat messages
443
+ for msg in st.session_state.messages:
444
+ if isinstance(msg, HumanMessage):
445
+ with st.chat_message("user"):
446
+ st.markdown(msg.content)
447
+ elif isinstance(msg, AIMessage):
448
+ with st.chat_message("assistant"):
449
+ # Check for structured JSON output
450
+ structured_output = None
451
  try:
452
+ # Try to find JSON block first
453
+ json_match = re.search(r"```json\n(\{.*?\})\n```", msg.content, re.DOTALL)
 
454
  if json_match:
455
+ structured_output = json.loads(json_match.group(1))
456
+ # Display non-JSON parts if any
457
+ non_json_content = msg.content.replace(json_match.group(0), "").strip()
458
+ if non_json_content:
459
+ st.markdown(non_json_content)
460
+ st.divider() # Separate text from structured output visually
461
+ elif msg.content.strip().startswith("{") and msg.content.strip().endswith("}"):
462
+ # Maybe the whole message is JSON
463
+ structured_output = json.loads(msg.content)
464
  else:
465
+ # No JSON found, display raw content
466
+ st.markdown(msg.content)
467
+
468
+ if structured_output:
469
+ # Display the structured data nicely (reuse parts of previous UI display logic)
470
+ st.subheader("📊 AI Analysis & Recommendations")
471
+ # ... (Add logic here to display assessment, ddx, plan etc. from structured_output)
472
+ # Example:
473
+ st.write(f"**Assessment:** {structured_output.get('assessment', 'N/A')}")
474
+ # Display DDx, Plan etc. using expanders or tabs
475
+ # ...
476
+ # Display Rationale & Interaction Summary
477
+ with st.expander("Rationale & Guideline Check"):
478
+ st.write(structured_output.get("rationale_summary", "N/A"))
479
+ if structured_output.get("interaction_check_summary"):
480
+ with st.expander("Interaction Check"):
481
+ st.write(structured_output.get("interaction_check_summary"))
482
+
483
+
484
+ except json.JSONDecodeError:
485
+ st.markdown(msg.content) # Display raw if JSON parsing fails
486
+
487
+ # Display tool calls if any were made in this AI turn
488
+ if msg.tool_calls:
489
+ with st.expander("🛠️ AI requested actions", expanded=False):
490
+ for tc in msg.tool_calls:
491
+ st.code(f"{tc['name']}(args={tc['args']})", language="python")
492
+
493
+ elif isinstance(msg, ToolMessage):
494
+ with st.chat_message("tool", avatar="🛠️"):
495
+ try:
496
+ tool_data = json.loads(msg.content)
497
+ status = tool_data.get("status", "info")
498
+ message = tool_data.get("message", msg.content)
499
+ details = tool_data.get("details")
500
+ warnings = tool_data.get("warnings")
501
+
502
+ if status == "success" or status == "clear" or status == "flagged":
503
+ st.success(f"Tool Result ({msg.name}): {message}", icon="✅" if status != "flagged" else "🚨")
504
+ elif status == "warning":
505
+ st.warning(f"Tool Result ({msg.name}): {message}", icon="⚠️")
506
+ if warnings:
507
+ for warn in warnings: st.caption(f"- {warn}")
508
+ else: # Error or unknown status
509
+ st.error(f"Tool Result ({msg.name}): {message}", icon="❌")
510
+
511
+ if details: st.caption(f"Details: {details}")
512
+
513
  except json.JSONDecodeError:
514
+ st.info(f"Tool Result ({msg.name}): {msg.content}") # Display raw if not JSON
 
 
515
 
 
 
 
516
 
517
+ # Chat input for user
518
+ if prompt := st.chat_input("Your message or follow-up query..."):
519
+ if not st.session_state.patient_data:
520
+ st.warning("Please load patient data using the sidebar first.")
521
+ else:
522
+ # Add user message to state
523
+ st.session_state.messages.append(HumanMessage(content=prompt))
524
+ with st.chat_message("user"):
525
+ st.markdown(prompt)
526
+
527
+ # Prepare state for graph invocation
528
+ current_state = AgentState(
529
+ messages=st.session_state.messages,
530
+ patient_data=st.session_state.patient_data
531
+ )
532
+
533
+ # Stream graph execution
534
+ with st.chat_message("assistant"):
535
+ message_placeholder = st.empty()
536
+ full_response = ""
537
+
538
+ # Use stream to get intermediate steps (optional but good for UX)
539
+ # This shows AI thinking and tool calls/results progressively
540
+ try:
541
+ for event in st.session_state.graph_app.stream(current_state, {"recursion_limit": 15}):
542
+ # event is a dictionary, keys are node names
543
+ if "agent" in event:
544
+ ai_msg = event["agent"]["messages"][-1] # Get the latest AI message
545
+ if isinstance(ai_msg, AIMessage):
546
+ full_response += ai_msg.content # Append content for final display
547
+ message_placeholder.markdown(full_response + "▌") # Show typing indicator
548
 
549
+ # Display tool calls as they happen (optional)
550
+ # if ai_msg.tool_calls:
551
+ # st.info(f"Requesting tools: {[tc['name'] for tc in ai_msg.tool_calls]}")
552
 
553
+ elif "tools" in event:
554
+ # Display tool results as they come back (optional, already handled by message display loop)
555
+ pass
556
+ # tool_msgs = event["tools"]["messages"]
557
+ # for tool_msg in tool_msgs:
558
+ # st.info(f"Tool {tool_msg.name} result received.")
559
 
 
 
560
 
561
+ # Final display after streaming
562
+ message_placeholder.markdown(full_response)
 
 
 
 
563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
 
565
+ # Update session state with the final messages from the graph run
566
+ # The graph state itself isn't directly accessible after streaming finishes easily this way
567
+ # We need to get the final state if we used invoke, or reconstruct from stream events
568
+ # A simpler way for now: just append the *last* AI message and any Tool messages from the stream
569
+ # This assumes the stream provides the final state implicitly. For robust state, use invoke or checkpointer.
570
+
571
+ # A more robust way: invoke and get final state
572
+ # final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
573
+ # st.session_state.messages = final_state['messages']
574
+ # --- Let's stick to appending for simplicity in this example ---
575
+ # Find the last AI message and tool messages from the stream (needs careful event parsing)
576
+ # Or, re-run invoke non-streamed just to get final state (less efficient)
577
+ final_state_capture = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
578
+ st.session_state.messages = final_state_capture['messages']
579
+
580
+
581
+ except Exception as e:
582
+ st.error(f"An error occurred during analysis: {e}")
583
+ # Attempt to add the error message to the history
584
+ st.session_state.messages.append(AIMessage(content=f"Sorry, an error occurred: {e}"))
585
+
586
+
587
+ # Rerun to display the updated chat history correctly
588
+ st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
  # Disclaimer
592
  st.markdown("---")
593
+ st.warning("**Disclaimer:** SynapseAI is for clinical decision support...") # Keep disclaimer
 
 
 
 
 
594
 
595
  if __name__ == "__main__":
596
  main()