mgbam commited on
Commit
b34efbf
Β·
verified Β·
1 Parent(s): 290645f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +409 -446
app.py CHANGED
@@ -1,4 +1,13 @@
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
2
  from langchain_groq import ChatGroq
3
  from langchain_community.tools.tavily_search import TavilySearchResults
4
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
@@ -7,24 +16,37 @@ from langchain_core.pydantic_v1 import BaseModel, Field
7
  from langchain_core.tools import tool
8
  from langgraph.prebuilt import ToolExecutor
9
  from langgraph.graph import StateGraph, END
10
- # from langgraph.checkpoint.memory import MemorySaverInMemory # Optional for state persistence
11
 
12
  from typing import Optional, List, Dict, Any, TypedDict, Annotated
13
- import json
14
- import re
15
- import operator
16
- import traceback # For detailed error logging
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # --- Configuration & Constants ---
19
  class ClinicalAppSettings:
20
- APP_TITLE = "SynapseAI: Interactive Clinical Decision Support"
21
  PAGE_LAYOUT = "wide"
22
  MODEL_NAME = "llama3-70b-8192" # Groq Llama3 70b
23
  TEMPERATURE = 0.1
24
  MAX_SEARCH_RESULTS = 3
25
 
26
  class ClinicalPrompts:
27
- # UPDATED SYSTEM PROMPT FOR CONVERSATIONAL FLOW, GUIDELINES & STRUCTURED OUTPUT FOCUS
 
28
  SYSTEM_PROMPT = """
29
  You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
30
  Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.
@@ -63,45 +85,137 @@ class ClinicalPrompts:
63
  8. **Conciseness & Flow:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation (asking questions, acknowledging info) until ready for the full structured JSON output.
64
  """
65
 
66
- # --- Mock Data / Helpers ---
67
- MOCK_INTERACTION_DB = {
68
- ("lisinopril", "spironolactone"): "High risk of hyperkalemia. Monitor potassium closely.",
69
- ("warfarin", "amiodarone"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
70
- ("simvastatin", "clarithromycin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
71
- ("aspirin", "ibuprofen"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding.",
72
- # Add lower case versions for easier lookup
73
- ("amiodarone", "warfarin"): "Increased bleeding risk. Monitor INR frequently and adjust Warfarin dose.",
74
- ("clarithromycin", "simvastatin"): "Increased risk of myopathy/rhabdomyolysis. Avoid combination or use lower statin dose.",
75
- ("ibuprofen", "aspirin"): "Concurrent use may decrease Aspirin's cardioprotective effect. Potential for increased GI bleeding.",
76
- ("spironolactone", "lisinopril"): "High risk of hyperkalemia. Monitor potassium closely.",
77
- }
78
-
79
- ALLERGY_INTERACTIONS = {
80
- "penicillin": ["amoxicillin", "ampicillin", "piperacillin", "augmentin"],
81
- "sulfa": ["sulfamethoxazole", "sulfasalazine", "bactrim"],
82
- "aspirin": ["ibuprofen", "naproxen", "nsaid"] # Cross-reactivity example for NSAIDs
83
- }
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
86
  """Parses BP string like '120/80' into (systolic, diastolic) integers."""
87
  if not isinstance(bp_string, str): return None
88
  match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
89
- if match:
90
- return int(match.group(1)), int(match.group(2))
91
  return None
92
 
93
  def check_red_flags(patient_data: dict) -> List[str]:
94
  """Checks patient data against predefined red flags."""
 
95
  flags = []
96
  if not patient_data: return flags
97
-
98
  symptoms = patient_data.get("hpi", {}).get("symptoms", [])
99
  vitals = patient_data.get("vitals", {})
100
  history = patient_data.get("pmh", {}).get("conditions", "")
101
- # Ensure symptoms are strings and lowercased
102
  symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)]
103
 
104
- # Symptom Flags
105
  if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
106
  if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
107
  if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
@@ -110,36 +224,31 @@ def check_red_flags(patient_data: dict) -> List[str]:
110
  if "hemoptysis" in symptoms_lower: flags.append("Red Flag: Hemoptysis (coughing up blood).")
111
  if "syncope" in symptoms_lower: flags.append("Red Flag: Syncope (fainting).")
112
 
113
- # Vital Sign Flags
114
  if vitals:
115
- temp = vitals.get("temp_c")
116
- hr = vitals.get("hr_bpm")
117
- rr = vitals.get("rr_rpm")
118
- spo2 = vitals.get("spo2_percent")
119
- bp_str = vitals.get("bp_mmhg")
120
-
121
- if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever (Temperature: {temp}Β°C).")
122
- if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia (Heart Rate: {hr} bpm).")
123
- if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia (Heart Rate: {hr} bpm).")
124
- if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea (Respiratory Rate: {rr} rpm).")
125
- if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia (SpO2: {spo2}%).")
126
  if bp_str:
127
  bp = parse_bp(bp_str)
128
  if bp:
129
  if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
130
  if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")
131
 
132
- # History Flags (Simple examples)
133
- if history and "history of mi" in history.lower() and "chest pain" in symptoms_lower:
134
- flags.append("Red Flag: History of MI with current Chest Pain.")
135
- if history and "history of dvt/pe" in history.lower() and "shortness of breath" in symptoms_lower:
136
- flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
 
137
 
138
- # Remove duplicates
139
- return list(set(flags))
140
 
141
  def format_patient_data_for_prompt(data: dict) -> str:
142
  """Formats the patient dictionary into a readable string for the LLM."""
 
143
  if not data: return "No patient data provided."
144
  prompt_str = ""
145
  for key, value in data.items():
@@ -149,34 +258,22 @@ def format_patient_data_for_prompt(data: dict) -> str:
149
  if has_content:
150
  prompt_str += f"**{section_title}:**\n"
151
  for sub_key, sub_value in value.items():
152
- if sub_value:
153
- prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
154
  elif isinstance(value, list) and value:
155
  prompt_str += f"**{section_title}:** {', '.join(map(str, value))}\n"
156
- elif value and not isinstance(value, dict): # Check it's not an empty dict
157
  prompt_str += f"**{section_title}:** {value}\n"
158
  return prompt_str.strip()
159
 
160
 
161
  # --- Tool Definitions ---
162
 
163
- # Pydantic models for robust argument validation
164
  class LabOrderInput(BaseModel):
165
  test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis', 'D-dimer').")
166
  reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS', 'Assess for PE').")
167
  priority: str = Field("Routine", description="Priority of the test (e.g., 'STAT', 'Routine').")
168
 
169
- @tool("order_lab_test", args_schema=LabOrderInput)
170
- def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
171
- """Orders a specific lab test with clinical justification and priority."""
172
- print(f"Executing order_lab_test: {test_name}, Reason: {reason}, Priority: {priority}")
173
- # In a real system, this would integrate with an LIS/EMR API
174
- return json.dumps({
175
- "status": "success",
176
- "message": f"Lab Ordered: {test_name} ({priority})",
177
- "details": f"Reason: {reason}"
178
- })
179
-
180
  class PrescriptionInput(BaseModel):
181
  medication_name: str = Field(..., description="Name of the medication.")
182
  dosage: str = Field(..., description="Dosage amount and unit (e.g., '500 mg', '10 mg', '81 mg').")
@@ -185,105 +282,165 @@ class PrescriptionInput(BaseModel):
185
  duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Ongoing', 'Until follow-up').")
186
  reason: str = Field(..., description="Clinical indication for the prescription.")
187
 
188
- @tool("prescribe_medication", args_schema=PrescriptionInput)
189
- def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
190
- """Prescribes a medication with detailed instructions and clinical indication. IMPORTANT: Requires prior interaction check."""
191
- print(f"Executing prescribe_medication: {medication_name} {dosage}...")
192
- # NOTE: The safety check (ensuring interaction check was requested) happens in the tool_node *before* this function is called.
193
- # In a real system, this would trigger an e-prescription workflow.
194
- return json.dumps({
195
- "status": "success",
196
- "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
197
- "details": f"Duration: {duration}. Reason: {reason}"
198
- })
199
-
200
  class InteractionCheckInput(BaseModel):
201
  potential_prescription: str = Field(..., description="The name of the NEW medication being considered for prescribing.")
202
- # These next two args are now populated by the tool_node using AgentState
203
- # current_medications: List[str] = Field(..., description="List of the patient's CURRENT medication names.")
204
- # allergies: List[str] = Field(..., description="List of the patient's known allergies.")
205
- # Make them optional in the schema, mandatory in the node logic
206
  current_medications: Optional[List[str]] = Field(None, description="List of patient's current medication names (populated from state).")
207
  allergies: Optional[List[str]] = Field(None, description="List of patient's known allergies (populated from state).")
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
 
210
  @tool("check_drug_interactions", args_schema=InteractionCheckInput)
211
  def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
212
- """Checks for potential drug-drug and drug-allergy interactions BEFORE prescribing."""
213
- print(f"Executing check_drug_interactions for: {potential_prescription}")
 
 
 
 
214
  warnings = []
215
- potential_med_lower = potential_prescription.lower()
216
 
217
- # Use provided lists or default to empty if None (should be populated by tool_node)
218
  current_meds_list = current_medications or []
219
  allergies_list = allergies or []
220
-
221
- current_meds_lower = [str(med).lower() for med in current_meds_list]
222
- allergies_lower = [str(a).lower() for a in allergies_list]
223
-
224
- print(f" Checking against Meds: {current_meds_lower}")
225
- print(f" Checking against Allergies: {allergies_lower}")
226
-
227
- # Check Allergies
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  for allergy in allergies_lower:
229
- # Direct match
230
  if allergy == potential_med_lower:
231
- warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{allergy}'. Cannot prescribe '{potential_prescription}'.")
232
- continue # Don't check cross-reactivity if direct match
233
- # Check cross-reactivity
234
- if allergy in ALLERGY_INTERACTIONS:
235
- for cross_reactant in ALLERGY_INTERACTIONS[allergy]:
236
- if cross_reactant.lower() == potential_med_lower:
237
- warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to '{allergy}'. High risk with '{potential_prescription}'.")
238
-
239
- # Check Drug-Drug Interactions
240
- for current_med in current_meds_lower:
241
- # Check pairs in both orders using the mock DB
242
- pair1 = (current_med, potential_med_lower)
243
- pair2 = (potential_med_lower, current_med)
244
- interaction_msg = MOCK_INTERACTION_DB.get(pair1) or MOCK_INTERACTION_DB.get(pair2)
245
- if interaction_msg:
246
- warnings.append(f"Interaction: {potential_prescription.capitalize()} with {current_med.capitalize()} - {interaction_msg}")
247
-
248
- status = "warning" if warnings else "clear"
249
- message = f"Interaction check for '{potential_prescription}': {len(warnings)} potential issue(s) found." if warnings else f"No major interactions identified for '{potential_prescription}' based on provided lists."
250
- print(f" Interaction Check Result: {status}, Message: {message}, Warnings: {warnings}")
251
- return json.dumps({"status": status, "message": message, "warnings": warnings})
252
-
253
-
254
- class FlagRiskInput(BaseModel):
255
- risk_description: str = Field(..., description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').")
256
- urgency: str = Field("High", description="Urgency level (e.g., 'Critical', 'High', 'Moderate').")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  @tool("flag_risk", args_schema=FlagRiskInput)
259
  def flag_risk(risk_description: str, urgency: str) -> str:
260
  """Flags a critical risk identified during analysis for immediate attention."""
261
  print(f"Executing flag_risk: {risk_description}, Urgency: {urgency}")
262
- # Display in Streamlit immediately
263
  st.error(f"🚨 **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="🚨")
264
- return json.dumps({
265
- "status": "flagged",
266
- "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
267
- })
268
 
269
  # Initialize Search Tool
270
  search_tool = TavilySearchResults(
271
  max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS,
272
- name="tavily_search_results" # Explicitly name the tool
273
  )
274
 
275
  # --- LangGraph Setup ---
276
 
277
  # Define the state structure
278
  class AgentState(TypedDict):
279
- messages: Annotated[list[Any], operator.add] # Accumulates messages (Human, AI, Tool)
280
- patient_data: Optional[dict] # Holds the structured patient data
281
 
282
  # Define Tools and Tool Executor
283
  tools = [
284
  order_lab_test,
285
  prescribe_medication,
286
- check_drug_interactions,
287
  flag_risk,
288
  search_tool
289
  ]
@@ -293,239 +450,132 @@ tool_executor = ToolExecutor(tools)
293
  model = ChatGroq(
294
  temperature=ClinicalAppSettings.TEMPERATURE,
295
  model=ClinicalAppSettings.MODEL_NAME,
296
- # Increase max_tokens if needed for large JSON output + conversation history
297
- # max_tokens=4096
298
  )
299
- # Bind tools FOR the model to know their schemas and descriptions
300
  model_with_tools = model.bind_tools(tools)
301
 
302
- # --- Graph Nodes ---
303
 
304
- # 1. Agent Node: Calls the LLM
305
  def agent_node(state: AgentState):
306
  """Invokes the LLM to decide the next action or response."""
307
  print("\n---AGENT NODE---")
308
  current_messages = state['messages']
309
- # Ensure System Prompt is present
310
  if not current_messages or not isinstance(current_messages[0], SystemMessage):
311
  print("Prepending System Prompt.")
312
  current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
313
-
314
- # Optional: Augment first human message with patient data if not already done explicitly
315
- # This helps ensure the LLM sees it early, though it's also in the state.
316
- # Be mindful of context window limits.
317
- # if len(current_messages) > 1 and isinstance(current_messages[1], HumanMessage) and state.get('patient_data'):
318
- # if "**Initial Patient Data:**" not in current_messages[1].content:
319
- # print("Augmenting first HumanMessage with patient data summary.")
320
- # formatted_data = format_patient_data_for_prompt(state['patient_data'])
321
- # current_messages[1] = HumanMessage(content=f"{current_messages[1].content}\n\n**Initial Patient Data Summary:**\n{formatted_data}")
322
-
323
  print(f"Invoking LLM with {len(current_messages)} messages.")
324
- # print(f"Messages Sent: {[m.type for m in current_messages]}") # Log message types
325
  try:
326
  response = model_with_tools.invoke(current_messages)
327
  print(f"Agent Raw Response Type: {type(response)}")
328
- # print(f"Agent Raw Response Content: {response.content}")
329
- if hasattr(response, 'tool_calls') and response.tool_calls:
330
- print(f"Agent Response Tool Calls: {response.tool_calls}")
331
- else:
332
- print("Agent Response: No tool calls.")
333
-
334
  except Exception as e:
335
  print(f"ERROR in agent_node during LLM invocation: {type(e).__name__} - {e}")
336
- traceback.print_exc() # Print full traceback for debugging
337
- # Return an error message to the graph state
338
  error_message = AIMessage(content=f"Sorry, an internal error occurred while processing the request: {type(e).__name__}")
339
  return {"messages": [error_message]}
340
-
341
  return {"messages": [response]}
342
 
343
- # 2. Tool Node: Executes tools called by the Agent (REVISED WITH ROBUST ERROR HANDLING)
344
  def tool_node(state: AgentState):
345
  """Executes tools called by the LLM and returns results."""
346
  print("\n---TOOL NODE---")
347
- tool_messages = [] # Initialize list to store results or errors
348
  last_message = state['messages'][-1]
349
 
350
- # Ensure the last message is an AIMessage with tool calls
351
  if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None):
352
- print("Warning: Tool node called unexpectedly without tool calls in the last AI message.")
353
- # If this happens, it might indicate a routing issue or the LLM hallucinating flow.
354
- # Returning empty list lets the agent proceed, potentially without needed info.
355
- # Consider adding a ToolMessage indicating the issue if needed.
356
  return {"messages": []}
357
 
358
  tool_calls = last_message.tool_calls
359
- print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}") # Log received calls
360
-
361
- # Safety Check: Identify required interaction checks before prescriptions
362
- prescriptions_requested = {} # medication_name_lower -> tool_call
363
- interaction_checks_requested = {} # medication_name_lower -> tool_call
364
 
 
 
 
365
  for call in tool_calls:
366
- tool_name = call.get('name')
367
- tool_args = call.get('args', {})
368
  if tool_name == 'prescribe_medication':
369
- med_name = tool_args.get('medication_name', '').lower()
370
- if med_name:
371
- prescriptions_requested[med_name] = call
372
  elif tool_name == 'check_drug_interactions':
373
  potential_med = tool_args.get('potential_prescription', '').lower()
374
- if potential_med:
375
- interaction_checks_requested[potential_med] = call
376
 
377
  valid_tool_calls_for_execution = []
378
-
379
- # Validate prescriptions against interaction checks
380
  for med_name, prescribe_call in prescriptions_requested.items():
381
  if med_name not in interaction_checks_requested:
382
  st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked.")
383
- error_msg = ToolMessage(
384
- content=json.dumps({"status": "error", "message": f"Interaction check for '{med_name}' must be requested *before or alongside* the prescription call."}),
385
- tool_call_id=prescribe_call['id'],
386
- name=prescribe_call['name'] # Include tool name in ToolMessage
387
- )
388
  tool_messages.append(error_msg)
389
- else:
390
- # Interaction check IS requested, allow prescription call to proceed
391
- pass # The call will be added below if it's in the original tool_calls list
392
 
393
- # Prepare list of calls to execute (all non-blocked calls)
394
- blocked_ids = {msg.tool_call_id for msg in tool_messages if msg.content and '"status": "error"' in msg.content}
395
  valid_tool_calls_for_execution = [call for call in tool_calls if call['id'] not in blocked_ids]
396
 
397
- # Augment interaction checks with patient data from state
398
- patient_meds = state.get("patient_data", {}).get("medications", {}).get("names_only", [])
399
- patient_allergies = state.get("patient_data", {}).get("allergies", [])
 
400
 
401
  for call in valid_tool_calls_for_execution:
402
  if call['name'] == 'check_drug_interactions':
403
- # Ensure args exist before modifying
404
  if 'args' not in call: call['args'] = {}
405
- call['args']['current_medications'] = patient_meds
 
 
406
  call['args']['allergies'] = patient_allergies
407
- print(f"Augmented interaction check args for call ID {call['id']}: {call['args']}")
408
-
409
 
410
- # Execute valid tool calls using batch for efficiency, capturing exceptions
411
  if valid_tool_calls_for_execution:
412
  print(f"Attempting to execute {len(valid_tool_calls_for_execution)} tools: {[c['name'] for c in valid_tool_calls_for_execution]}")
413
  try:
414
  responses = tool_executor.batch(valid_tool_calls_for_execution, return_exceptions=True)
415
-
416
- # Process responses, creating ToolMessage for each
417
  for call, resp in zip(valid_tool_calls_for_execution, responses):
418
- tool_call_id = call['id']
419
- tool_name = call['name']
420
-
421
  if isinstance(resp, Exception):
422
- # Handle exceptions returned by the batch call
423
- error_type = type(resp).__name__
424
- error_str = str(resp)
425
  print(f"ERROR executing tool '{tool_name}' (ID: {tool_call_id}): {error_type} - {error_str}")
426
- traceback.print_exc() # Log full traceback
427
  st.error(f"Error executing action '{tool_name}': {error_type}")
428
- error_content = json.dumps({
429
- "status": "error",
430
- "message": f"Failed to execute '{tool_name}': {error_type} - {error_str}"
431
- })
432
  tool_messages.append(ToolMessage(content=error_content, tool_call_id=tool_call_id, name=tool_name))
433
-
434
- # Specific check for the error mentioned by user
435
  if isinstance(resp, AttributeError) and "'dict' object has no attribute 'tool'" in error_str:
436
- print("\n *** DETECTED SPECIFIC ATTRIBUTE ERROR ('dict' object has no attribute 'tool') ***")
437
- print(f" Tool Call causing error: {json.dumps(call, indent=2)}")
438
- print(" This likely indicates an internal issue within Langchain/LangGraph or ToolExecutor expecting a different object structure.")
439
- print(" Ensure tool definitions (@tool decorators) and Pydantic schemas are correct.\n")
440
-
441
  else:
442
- # Process successful results
443
- print(f"Tool '{tool_name}' (ID: {tool_call_id}) executed successfully. Result type: {type(resp)}")
444
- # Ensure content is string for ToolMessage
445
  content_str = str(resp)
446
  tool_messages.append(ToolMessage(content=content_str, tool_call_id=tool_call_id, name=tool_name))
447
-
448
- # Display result in Streamlit right away for feedback (optional, but helpful)
449
- # This part might be better handled purely in the UI display loop later
450
- # try:
451
- # result_data = json.loads(content_str)
452
- # status = result_data.get("status", "info")
453
- # message = result_data.get("message", content_str)
454
- # if status in ["success", "clear", "flagged"]: st.success(f"Action `{tool_name}` completed: {message}", icon="βœ…" if status != "flagged" else "🚨")
455
- # elif status == "warning": st.warning(f"Action `{tool_name}` completed: {message}", icon="⚠️")
456
- # else: st.info(f"Action `{tool_name}` completed: {message}") # Info for other statuses
457
- # except json.JSONDecodeError:
458
- # st.info(f"Action `{tool_name}` completed (non-JSON output).")
459
-
460
- # Catch potential errors within the tool_node logic itself (e.g., preparing calls)
461
  except Exception as e:
462
  print(f"CRITICAL UNEXPECTED ERROR within tool_node logic: {type(e).__name__} - {e}")
463
- traceback.print_exc()
464
- st.error(f"Critical internal error processing actions: {e}")
465
- # Create generic error messages for all calls that were intended
466
  error_content = json.dumps({"status": "error", "message": f"Internal error processing tools: {e}"})
467
- # Add error messages for calls that didn't get processed yet
468
  processed_ids = {msg.tool_call_id for msg in tool_messages}
469
  for call in valid_tool_calls_for_execution:
470
- if call['id'] not in processed_ids:
471
- tool_messages.append(ToolMessage(content=error_content, tool_call_id=call['id'], name=call['name']))
472
 
473
  print(f"Returning {len(tool_messages)} tool messages.")
474
- # print(f"Tool messages content snippets: {[m.content[:100] + '...' if len(m.content)>100 else m.content for m in tool_messages]}")
475
  return {"messages": tool_messages}
476
 
477
 
478
- # --- Graph Edges (Routing Logic) ---
479
  def should_continue(state: AgentState) -> str:
480
  """Determines whether to call tools, end the conversation turn, or handle errors."""
481
  print("\n---ROUTING DECISION---")
482
  last_message = state['messages'][-1] if state['messages'] else None
 
 
 
 
483
 
484
- if not isinstance(last_message, AIMessage):
485
- # This case might happen if the graph starts with a non-AI message or after an error
486
- print("Routing: Last message not AI. Ending turn.")
487
- return "end_conversation_turn"
488
-
489
- # If the LLM produced an error message (e.g., during invocation)
490
- if "Sorry, an internal error occurred" in last_message.content:
491
- print("Routing: AI returned internal error. Ending turn.")
492
- return "end_conversation_turn"
493
-
494
- # If the LLM made tool calls, execute them
495
- if getattr(last_message, 'tool_calls', None):
496
- print("Routing: AI requested tool calls. Continue to tools node.")
497
- return "continue_tools"
498
- # Otherwise, the AI provided a response without tool calls, end the turn
499
- else:
500
- print("Routing: AI provided final response or asked question. Ending turn.")
501
- return "end_conversation_turn"
502
-
503
- # --- Graph Definition ---
504
  workflow = StateGraph(AgentState)
505
-
506
- # Add nodes
507
  workflow.add_node("agent", agent_node)
508
  workflow.add_node("tools", tool_node)
509
-
510
- # Define entry point
511
  workflow.set_entry_point("agent")
512
-
513
- # Add conditional edges
514
- workflow.add_conditional_edges(
515
- "agent", # Source node
516
- should_continue, # Function to decide the route
517
- {
518
- "continue_tools": "tools", # If tool calls exist, go to tools node
519
- "end_conversation_turn": END # Otherwise, end the graph iteration for this turn
520
- }
521
- )
522
-
523
- # Add edge from tools back to agent
524
  workflow.add_edge("tools", "agent")
525
-
526
- # Compile the graph
527
- # memory = MemorySaverInMemory() # Optional: for persisting state across runs
528
- # app = workflow.compile(checkpointer=memory)
529
  app = workflow.compile()
530
  print("LangGraph compiled successfully.")
531
 
@@ -533,41 +583,34 @@ print("LangGraph compiled successfully.")
533
  def main():
534
  st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
535
  st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
536
- st.caption(f"Interactive Assistant | Powered by Langchain/LangGraph & Groq ({ClinicalAppSettings.MODEL_NAME})")
537
 
538
- # Initialize session state
539
- if "messages" not in st.session_state:
540
- st.session_state.messages = [] # Stores full conversation history
541
- if "patient_data" not in st.session_state:
542
- st.session_state.patient_data = None
543
- if "graph_app" not in st.session_state:
544
- st.session_state.graph_app = app
545
 
546
- # --- Patient Data Input Sidebar ---
547
  with st.sidebar:
548
  st.header("πŸ“„ Patient Intake Form")
549
- # Demographics
 
550
  st.subheader("Demographics")
551
  age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
552
  sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")
553
- # HPI
554
  st.subheader("History of Present Illness (HPI)")
555
  chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
556
- hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago, described as pressure, radiating to left arm. Associated with nausea and diaphoresis. Pain is 8/10 severity. No relief with rest.", key="hpi_input", height=150)
557
  symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough", "Severe Headache", "Syncope", "Hemoptysis"], default=["Nausea", "Diaphoresis"], key="sym_input")
558
- # History
559
  st.subheader("Past History")
560
  pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2), History of MI", key="pmh_input")
561
  psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")
562
- # Meds & Allergies
563
  st.subheader("Medications & Allergies")
564
  current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
565
  allergies_str = st.text_area("Allergies (comma separated, specify reaction if known)", "Penicillin (rash), Sulfa (hives)", key="allergy_input")
566
- # Social/Family
567
  st.subheader("Social/Family History")
568
  social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
569
  family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
570
- # Vitals/Exam
571
  st.subheader("Vitals & Exam Findings")
572
  col1, col2 = st.columns(2)
573
  with col1:
@@ -578,241 +621,161 @@ def main():
578
  bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
579
  spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
580
  pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
581
- exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3. Mild distress. Lungs clear bilaterally. Cardiac exam: Regular rhythm, S1/S2 normal, no murmurs/gallops/rubs. Abdomen soft, non-tender. No lower extremity edema.", key="exam_input", height=100)
582
 
583
- # Compile Patient Data Dictionary on button press
 
584
  if st.button("Start/Update Consultation", key="start_button"):
 
585
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
586
- # Basic name extraction (first word, lowercased) for interaction check
587
- current_med_names = []
588
  for med in current_meds_list:
589
  match = re.match(r"^\s*([a-zA-Z\-]+)", med)
590
- if match:
591
- current_med_names.append(match.group(1).lower())
592
 
593
- # Basic allergy extraction (first word or phrase before parenthesis, lowercased)
594
  allergies_list = []
595
  for a in allergies_str.split(','):
596
  cleaned_allergy = a.strip()
597
  if cleaned_allergy:
598
- match = re.match(r"^\s*([a-zA-Z\-\s]+)(?:\s*\(.*\))?", cleaned_allergy)
599
- if match:
600
- allergies_list.append(match.group(1).strip().lower())
601
- else: # Fallback if no parenthesis
602
- allergies_list.append(cleaned_allergy.lower())
603
 
604
  st.session_state.patient_data = {
605
  "demographics": {"age": age, "sex": sex},
606
  "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
607
  "pmh": {"conditions": pmh}, "psh": {"procedures": psh},
608
- "medications": {"current": current_meds_list, "names_only": current_med_names},
609
- "allergies": allergies_list,
 
610
  "social_history": {"details": social_history}, "family_history": {"details": family_history},
611
  "vitals": { "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale},
612
  "exam_findings": {"notes": exam_notes}
613
  }
614
 
615
- # Initial Red Flag Check (Client-side)
616
  red_flags = check_red_flags(st.session_state.patient_data)
617
  st.sidebar.markdown("---")
618
  if red_flags:
619
  st.sidebar.warning("**Initial Red Flags Detected:**")
620
  for flag in red_flags: st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}")
621
- else:
622
- st.sidebar.success("No immediate red flags detected in initial data.")
623
 
624
- # Prepare initial message for the graph
625
- initial_prompt = f"Initiate consultation for the patient described in the intake form. Start the analysis."
626
- # Clear previous messages and start fresh
627
  st.session_state.messages = [HumanMessage(content=initial_prompt)]
628
- st.success("Patient data loaded. Ready for analysis.")
629
- # No rerun needed here, chat input will trigger the graph
630
 
631
- # --- Main Chat Interface Area ---
632
  st.header("πŸ’¬ Clinical Consultation")
633
 
634
  # Display chat messages from history
 
635
  for msg_index, msg in enumerate(st.session_state.messages):
636
- unique_key = f"msg_{msg_index}" # Basic unique key
637
  if isinstance(msg, HumanMessage):
638
- with st.chat_message("user", key=f"{unique_key}_user"):
639
- st.markdown(msg.content)
640
  elif isinstance(msg, AIMessage):
641
  with st.chat_message("assistant", key=f"{unique_key}_ai"):
642
- # Display AI text content
643
- ai_content = msg.content
644
- structured_output = None
645
-
646
- # Attempt to parse structured JSON if present
647
  try:
648
- # Look for ```json ... ``` block
649
  json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
650
  if json_match:
651
- json_str = json_match.group(1)
652
- # Display content before/after the JSON block if any
653
- prefix = ai_content[:json_match.start()].strip()
654
- suffix = ai_content[json_match.end():].strip()
655
  if prefix: st.markdown(prefix)
656
  structured_output = json.loads(json_str)
657
  if suffix: st.markdown(suffix)
658
- # Check if the entire message might be JSON
659
  elif ai_content.strip().startswith("{") and ai_content.strip().endswith("}"):
660
- structured_output = json.loads(ai_content)
661
- ai_content = "" # Don't display raw JSON if parsed ok
662
- else:
663
- # No JSON found, display content as is
664
- st.markdown(ai_content)
665
-
666
- except json.JSONDecodeError:
667
- # Failed to parse, display raw content
668
- st.markdown(ai_content)
669
- st.warning("Note: Could not parse structured JSON in AI response.", icon="⚠️")
670
- except Exception as e:
671
- st.markdown(ai_content) # Display raw on other errors
672
- st.error(f"Error processing AI message display: {e}", icon="❌")
673
-
674
- # Display structured data nicely if parsed
675
  if structured_output and isinstance(structured_output, dict):
676
- st.divider()
677
- st.subheader("πŸ“Š AI Analysis & Recommendations")
678
  cols = st.columns(2)
679
  with cols[0]:
680
- st.markdown(f"**Assessment:**")
681
- st.markdown(f"> {structured_output.get('assessment', 'N/A')}")
682
-
683
  st.markdown(f"**Differential Diagnosis:**")
684
- ddx = structured_output.get('differential_diagnosis', [])
685
  if ddx:
686
  for item in ddx:
687
- likelihood = item.get('likelihood', 'Unknown').capitalize()
688
- icon = "πŸ₯‡" if likelihood=="High" else ("πŸ₯ˆ" if likelihood=="Medium" else "πŸ₯‰")
689
- with st.expander(f"{icon} {item.get('diagnosis', 'Unknown')} ({likelihood})"):
690
- st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
691
- else: st.info("No differential diagnosis provided.")
692
-
693
- st.markdown(f"**Risk Assessment:**")
694
- risk = structured_output.get('risk_assessment', {})
695
- flags = risk.get('identified_red_flags', [])
696
  if flags: st.warning(f"**Flags:** {', '.join(flags)}")
697
- if risk.get("immediate_concerns"): st.warning(f"**Concerns:** {', '.join(risk.get('immediate_concerns'))}")
698
- if risk.get("potential_complications"): st.info(f"**Potential Complications:** {', '.join(risk.get('potential_complications'))}")
699
- if not flags and not risk.get("immediate_concerns"): st.success("No major risks highlighted in this assessment.")
700
-
701
  with cols[1]:
702
- st.markdown(f"**Recommended Plan:**")
703
- plan = structured_output.get('recommended_plan', {})
704
- sub_sections = ["investigations", "therapeutics", "consultations", "patient_education"]
705
- for section in sub_sections:
706
- st.markdown(f"_{section.replace('_',' ').capitalize()}:_")
707
- items = plan.get(section)
708
- if items and isinstance(items, list):
709
- for item in items: st.markdown(f"- {item}")
710
- elif items: # Handle if it's just a string
711
- st.markdown(f"- {items}")
712
  else: st.markdown("_None suggested._")
713
- st.markdown("") # Add space
714
-
715
- # Display Rationale and Interaction Summary below columns
716
- st.markdown(f"**Rationale & Guideline Check:**")
717
- st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
718
  interaction_summary = structured_output.get("interaction_check_summary", "")
719
- if interaction_summary:
720
- st.markdown(f"**Interaction Check Summary:**")
721
- st.markdown(f"> {interaction_summary}")
722
-
723
  st.divider()
724
 
725
- # Display tool calls requested in this AI turn
726
  if getattr(msg, 'tool_calls', None):
727
  with st.expander("πŸ› οΈ AI requested actions", expanded=False):
728
  for tc in msg.tool_calls:
729
- try:
730
- # Safely display, default args to empty dict if missing
731
- st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
732
- except Exception as display_e:
733
- st.error(f"Could not display tool call: {display_e}")
734
- st.code(str(tc)) # Raw display as fallback
735
 
736
  elif isinstance(msg, ToolMessage):
737
- # Safely get tool name
738
- tool_name_display = getattr(msg, 'name', 'tool_execution') # Use 'name' attribute added in tool_node
739
  with st.chat_message(tool_name_display, avatar="πŸ› οΈ", key=f"{unique_key}_tool"):
 
740
  try:
741
- # Attempt to parse content as JSON for structured display
742
- tool_data = json.loads(msg.content)
743
- status = tool_data.get("status", "info")
744
- message = tool_data.get("message", msg.content)
745
- details = tool_data.get("details")
746
- warnings = tool_data.get("warnings")
747
-
748
- if status == "success" or status == "clear" or status == "flagged":
749
- st.success(f"{message}", icon="βœ…" if status != "flagged" else "🚨")
750
  elif status == "warning":
751
  st.warning(f"{message}", icon="⚠️")
752
  if warnings and isinstance(warnings, list):
753
  st.caption("Details:")
754
- for warn in warnings: st.caption(f"- {warn}")
755
- else: # Error or unknown status
756
- st.error(f"{message}", icon="❌")
757
-
758
  if details: st.caption(f"Details: {details}")
 
 
759
 
760
- except json.JSONDecodeError:
761
- # If content is not JSON, display it plainly
762
- st.info(f"{msg.content}")
763
- except Exception as e:
764
- st.error(f"Error displaying tool message: {e}", icon="❌")
765
- st.caption(f"Raw content: {msg.content}")
766
 
767
- # --- Chat Input Logic ---
768
  if prompt := st.chat_input("Your message or follow-up query..."):
769
  if not st.session_state.patient_data:
770
- st.warning("Please load patient data using the sidebar first.")
771
- st.stop() # Prevent execution if no patient data
772
 
773
- # Add user message to state and display it immediately
774
  user_message = HumanMessage(content=prompt)
775
  st.session_state.messages.append(user_message)
776
- with st.chat_message("user"):
777
- st.markdown(prompt)
778
 
779
- # Prepare state for graph invocation
780
- current_state = AgentState(
781
- messages=st.session_state.messages,
782
- patient_data=st.session_state.patient_data
783
- )
784
 
785
- # Invoke the graph
786
  with st.spinner("SynapseAI is thinking..."):
787
  try:
788
- # Use invoke to run the graph until it ends for this turn
789
- final_state = st.session_state.graph_app.invoke(
790
- current_state,
791
- {"recursion_limit": 15} # Add recursion limit for safety
792
- )
793
- # Update the session state messages with the final list from the graph run
794
  st.session_state.messages = final_state['messages']
795
-
796
  except Exception as e:
797
- print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}")
798
- traceback.print_exc()
799
  st.error(f"An error occurred during the conversation turn: {e}", icon="❌")
800
- # Attempt to add an error message to the history for visibility
801
- error_ai_msg = AIMessage(content=f"Sorry, a critical error occurred: {type(e).__name__}. Please check logs or try again.")
802
- # Avoid modifying state directly during exception handling if possible,
803
- # but appending might be okay for display purposes.
804
- # st.session_state.messages.append(error_ai_msg) # Be cautious with state modification here
805
 
806
- # Rerun the script to display the updated chat history, including AI response and tool results
807
- st.rerun()
808
 
809
- # Disclaimer at the bottom
810
  st.markdown("---")
811
- st.warning(
812
- """**Disclaimer:** SynapseAI is an AI assistant for clinical decision support and does not replace professional medical judgment.
813
- All outputs must be critically reviewed and verified by a qualified healthcare provider before making any clinical decisions.
814
- Validate all information, especially diagnoses, dosages, and interactions, independently using standard medical resources."""
815
- )
816
 
817
  if __name__ == "__main__":
818
  main()
 
1
  import streamlit as st
2
+ import requests
3
+ import json
4
+ import re
5
+ import os
6
+ import operator
7
+ import traceback
8
+ from functools import lru_cache
9
+ from dotenv import load_dotenv
10
+
11
  from langchain_groq import ChatGroq
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
 
16
  from langchain_core.tools import tool
17
  from langgraph.prebuilt import ToolExecutor
18
  from langgraph.graph import StateGraph, END
 
19
 
20
  from typing import Optional, List, Dict, Any, TypedDict, Annotated
21
+
22
+ # --- Environment Variable Loading & Validation ---
23
+ load_dotenv() # Load .env file if present (for local development)
24
+
25
+ UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
26
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
27
+ TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
28
+
29
+ # Stop execution if essential keys are missing (crucial for HF Spaces)
30
+ missing_keys = []
31
+ if not UMLS_API_KEY: missing_keys.append("UMLS_API_KEY")
32
+ if not GROQ_API_KEY: missing_keys.append("GROQ_API_KEY")
33
+ if not TAVILY_API_KEY: missing_keys.append("TAVILY_API_KEY")
34
+
35
+ if missing_keys:
36
+ st.error(f"Missing required API Key(s): {', '.join(missing_keys)}. Please set them in Hugging Face Space Secrets or your environment variables.")
37
+ st.stop()
38
 
39
  # --- Configuration & Constants ---
40
  class ClinicalAppSettings:
41
+ APP_TITLE = "SynapseAI: Interactive Clinical Decision Support (UMLS/FDA Integrated)"
42
  PAGE_LAYOUT = "wide"
43
  MODEL_NAME = "llama3-70b-8192" # Groq Llama3 70b
44
  TEMPERATURE = 0.1
45
  MAX_SEARCH_RESULTS = 3
46
 
47
  class ClinicalPrompts:
48
+ # System prompt remains the same as the previous version, emphasizing structured output,
49
+ # safety checks, guideline search, and conversational flow.
50
  SYSTEM_PROMPT = """
51
  You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
52
  Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.
 
85
  8. **Conciseness & Flow:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation (asking questions, acknowledging info) until ready for the full structured JSON output.
86
  """
87
 
88
+ # --- UMLS/RxNorm & OpenFDA API Helper Functions ---
89
+ UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key" # May not be needed if using apiKey directly
90
+ RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
91
+ OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ @lru_cache(maxsize=256) # Cache RxCUI lookups
94
+ def get_rxcui(drug_name: str) -> Optional[str]:
95
+ """Uses RxNorm API to find the RxCUI for a given drug name."""
96
+ if not drug_name or not isinstance(drug_name, str): return None
97
+ drug_name = drug_name.strip()
98
+ if not drug_name: return None
99
+
100
+ print(f"RxNorm Lookup for: '{drug_name}'")
101
+ try:
102
+ params = {"name": drug_name, "search": 1} # Search for concepts related to the name
103
+ response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
104
+ response.raise_for_status()
105
+ data = response.json()
106
+ # Extract RxCUI - prioritize exact matches or common types
107
+ if data and "idGroup" in data and "rxnormId" in data["idGroup"]:
108
+ # Select the first one, assuming it's the most relevant by default.
109
+ # More sophisticated logic could check TTYs (Term Types) if needed.
110
+ rxcui = data["idGroup"]["rxnormId"][0]
111
+ print(f" Found RxCUI: {rxcui} for '{drug_name}'")
112
+ return rxcui
113
+ else:
114
+ # Fallback: Search /drugs endpoint if direct rxcui lookup fails
115
+ params = {"name": drug_name}
116
+ response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
117
+ response.raise_for_status()
118
+ data = response.json()
119
+ if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
120
+ for group in data["drugGroup"]["conceptGroup"]:
121
+ # Prioritize Semantic Types like Brand/Clinical Drug/Ingredient
122
+ if group.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
123
+ if "conceptProperties" in group and group["conceptProperties"]:
124
+ rxcui = group["conceptProperties"][0].get("rxcui")
125
+ if rxcui:
126
+ print(f" Found RxCUI (via /drugs): {rxcui} for '{drug_name}'")
127
+ return rxcui
128
+ print(f" RxCUI not found for '{drug_name}'.")
129
+ return None
130
+ except requests.exceptions.RequestException as e:
131
+ print(f" Error fetching RxCUI for '{drug_name}': {e}")
132
+ return None
133
+ except json.JSONDecodeError as e:
134
+ print(f" Error decoding RxNorm JSON response for '{drug_name}': {e}")
135
+ return None
136
+ except Exception as e: # Catch any other unexpected error
137
+ print(f" Unexpected error in get_rxcui for '{drug_name}': {e}")
138
+ return None
139
+
140
+ @lru_cache(maxsize=128) # Cache OpenFDA lookups
141
+ def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
142
+ """Fetches drug label information from OpenFDA using RxCUI or drug name."""
143
+ if not rxcui and not drug_name: return None
144
+ print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}")
145
+
146
+ search_terms = []
147
+ # Prioritize RxCUI lookup using multiple potential fields
148
+ if rxcui:
149
+ search_terms.append(f'spl_rxnorm_code:"{rxcui}"')
150
+ search_terms.append(f'openfda.rxcui:"{rxcui}"')
151
+ # Add name search as fallback or supplement
152
+ if drug_name:
153
+ search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
154
+
155
+ search_query = " OR ".join(search_terms)
156
+ params = {"search": search_query, "limit": 1} # Get only the most relevant label
157
+
158
+ try:
159
+ response = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
160
+ response.raise_for_status()
161
+ data = response.json()
162
+ if data and "results" in data and data["results"]:
163
+ print(f" Found OpenFDA label for query: {search_query}")
164
+ return data["results"][0] # Return the first label found
165
+ print(f" No OpenFDA label found for query: {search_query}")
166
+ return None
167
+ except requests.exceptions.RequestException as e:
168
+ print(f" Error fetching OpenFDA label: {e}")
169
+ return None
170
+ except json.JSONDecodeError as e:
171
+ print(f" Error decoding OpenFDA JSON response: {e}")
172
+ return None
173
+ except Exception as e:
174
+ print(f" Unexpected error in get_openfda_label: {e}")
175
+ return None
176
+
177
+ def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
178
+ """ Case-insensitive search for any search_term within a list of text strings. Returns snippets. """
179
+ found_snippets = []
180
+ if not text_list or not search_terms: return found_snippets
181
+ # Ensure search terms are lowercased strings
182
+ search_terms_lower = [str(term).lower() for term in search_terms if term]
183
+
184
+ for text_item in text_list:
185
+ if not isinstance(text_item, str): continue # Skip non-string items
186
+ text_item_lower = text_item.lower()
187
+ for term in search_terms_lower:
188
+ if term in text_item_lower:
189
+ # Find the start index of the term
190
+ start_index = text_item_lower.find(term)
191
+ # Define snippet boundaries (e.g., 50 chars before, 100 after)
192
+ snippet_start = max(0, start_index - 50)
193
+ snippet_end = min(len(text_item), start_index + len(term) + 100)
194
+ snippet = text_item[snippet_start:snippet_end]
195
+ # Add indication of where the match is
196
+ snippet = snippet.replace(term, f"**{term}**", 1) # Highlight first match
197
+ found_snippets.append(f"...{snippet}...")
198
+ break # Move to the next text item once a match is found
199
+ return found_snippets
200
+
201
+ # --- Other Helper Functions ---
202
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
203
  """Parses BP string like '120/80' into (systolic, diastolic) integers."""
204
  if not isinstance(bp_string, str): return None
205
  match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
206
+ if match: return int(match.group(1)), int(match.group(2))
 
207
  return None
208
 
209
  def check_red_flags(patient_data: dict) -> List[str]:
210
  """Checks patient data against predefined red flags."""
211
+ # (Keep the implementation from the previous full code listing)
212
  flags = []
213
  if not patient_data: return flags
 
214
  symptoms = patient_data.get("hpi", {}).get("symptoms", [])
215
  vitals = patient_data.get("vitals", {})
216
  history = patient_data.get("pmh", {}).get("conditions", "")
 
217
  symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)]
218
 
 
219
  if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
220
  if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
221
  if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
 
224
  if "hemoptysis" in symptoms_lower: flags.append("Red Flag: Hemoptysis (coughing up blood).")
225
  if "syncope" in symptoms_lower: flags.append("Red Flag: Syncope (fainting).")
226
 
 
227
  if vitals:
228
+ temp = vitals.get("temp_c"); hr = vitals.get("hr_bpm"); rr = vitals.get("rr_rpm")
229
+ spo2 = vitals.get("spo2_percent"); bp_str = vitals.get("bp_mmhg")
230
+ if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever ({temp}Β°C).")
231
+ if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
232
+ if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
233
+ if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
234
+ if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia ({spo2}%).")
 
 
 
 
235
  if bp_str:
236
  bp = parse_bp(bp_str)
237
  if bp:
238
  if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
239
  if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")
240
 
241
+ if history and isinstance(history, str):
242
+ history_lower = history.lower()
243
+ if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
244
+ if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
245
+
246
+ return list(set(flags)) # Unique flags
247
 
 
 
248
 
249
  def format_patient_data_for_prompt(data: dict) -> str:
250
  """Formats the patient dictionary into a readable string for the LLM."""
251
+ # (Keep the implementation from the previous full code listing)
252
  if not data: return "No patient data provided."
253
  prompt_str = ""
254
  for key, value in data.items():
 
258
  if has_content:
259
  prompt_str += f"**{section_title}:**\n"
260
  for sub_key, sub_value in value.items():
261
+ if sub_value: prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
 
262
  elif isinstance(value, list) and value:
263
  prompt_str += f"**{section_title}:** {', '.join(map(str, value))}\n"
264
+ elif value and not isinstance(value, dict):
265
  prompt_str += f"**{section_title}:** {value}\n"
266
  return prompt_str.strip()
267
 
268
 
269
  # --- Tool Definitions ---
270
 
271
+ # Pydantic models for tool inputs
272
  class LabOrderInput(BaseModel):
273
  test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis', 'D-dimer').")
274
  reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS', 'Assess for PE').")
275
  priority: str = Field("Routine", description="Priority of the test (e.g., 'STAT', 'Routine').")
276
 
 
 
 
 
 
 
 
 
 
 
 
277
  class PrescriptionInput(BaseModel):
278
  medication_name: str = Field(..., description="Name of the medication.")
279
  dosage: str = Field(..., description="Dosage amount and unit (e.g., '500 mg', '10 mg', '81 mg').")
 
282
  duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Ongoing', 'Until follow-up').")
283
  reason: str = Field(..., description="Clinical indication for the prescription.")
284
 
285
+ # Updated InteractionCheckInput - Note: current_medications/allergies are Optional here
286
+ # because they are populated by the tool_node from state *before* execution.
 
 
 
 
 
 
 
 
 
 
287
  class InteractionCheckInput(BaseModel):
288
  potential_prescription: str = Field(..., description="The name of the NEW medication being considered for prescribing.")
 
 
 
 
289
  current_medications: Optional[List[str]] = Field(None, description="List of patient's current medication names (populated from state).")
290
  allergies: Optional[List[str]] = Field(None, description="List of patient's known allergies (populated from state).")
291
 
292
+ class FlagRiskInput(BaseModel):
293
+ risk_description: str = Field(..., description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').")
294
+ urgency: str = Field("High", description="Urgency level (e.g., 'Critical', 'High', 'Moderate').")
295
+
296
+ # Tool functions
297
+ @tool("order_lab_test", args_schema=LabOrderInput)
298
+ def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
299
+ """Orders a specific lab test with clinical justification and priority."""
300
+ print(f"Executing order_lab_test: {test_name}, Reason: {reason}, Priority: {priority}")
301
+ return json.dumps({"status": "success", "message": f"Lab Ordered: {test_name} ({priority})", "details": f"Reason: {reason}"})
302
+
303
+ @tool("prescribe_medication", args_schema=PrescriptionInput)
304
+ def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
305
+ """Prescribes a medication with detailed instructions and clinical indication. IMPORTANT: Requires prior interaction check."""
306
+ print(f"Executing prescribe_medication: {medication_name} {dosage}...")
307
+ # Safety check happens in tool_node *before* this is called.
308
+ return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
309
 
310
+ # --- NEW Interaction Check Tool using UMLS/RxNorm & OpenFDA ---
311
  @tool("check_drug_interactions", args_schema=InteractionCheckInput)
312
  def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
313
+ """
314
+ Checks for potential drug-drug and drug-allergy interactions using RxNorm API for normalization
315
+ and OpenFDA drug labels for interaction/warning text. REQUIRES UMLS_API_KEY environment variable.
316
+ """
317
+ print(f"\n--- Executing REAL check_drug_interactions ---")
318
+ print(f"Checking potential prescription: '{potential_prescription}'")
319
  warnings = []
320
+ potential_med_lower = potential_prescription.lower().strip()
321
 
322
+ # Use provided lists or default to empty
323
  current_meds_list = current_medications or []
324
  allergies_list = allergies or []
325
+ # Clean and lowercase current med names (basic extraction: first word)
326
+ current_med_names_lower = []
327
+ for med in current_meds_list:
328
+ match = re.match(r"^\s*([a-zA-Z\-]+)", str(med))
329
+ if match: current_med_names_lower.append(match.group(1).lower())
330
+ # Clean and lowercase allergies
331
+ allergies_lower = [str(a).lower().strip() for a in allergies_list if a]
332
+
333
+ print(f" Against Current Meds (names): {current_med_names_lower}")
334
+ print(f" Against Allergies: {allergies_lower}")
335
+
336
+ # --- Step 1: Normalize potential prescription ---
337
+ print(f" Step 1: Normalizing '{potential_prescription}'...")
338
+ potential_rxcui = get_rxcui(potential_prescription)
339
+ potential_label = get_openfda_label(rxcui=potential_rxcui, drug_name=potential_prescription)
340
+ if not potential_rxcui and not potential_label:
341
+ print(f" Warning: Could not find RxCUI or OpenFDA label for '{potential_prescription}'. Interaction check will be limited.")
342
+ warnings.append(f"INFO: Could not reliably identify '{potential_prescription}' in standard terminologies/databases. Checks may be incomplete.")
343
+
344
+ # --- Step 2: Allergy Check ---
345
+ print(" Step 2: Performing Allergy Check...")
346
+ # Direct name match against patient's allergy list
347
  for allergy in allergies_lower:
 
348
  if allergy == potential_med_lower:
349
+ warnings.append(f"CRITICAL ALLERGY (Name Match): Patient allergic to '{allergy}'. Potential prescription is '{potential_prescription}'.")
350
+ # Basic cross-reactivity check (can be expanded)
351
+ elif allergy in ["penicillin", "pcns"] and potential_med_lower in ["amoxicillin", "ampicillin", "augmentin", "piperacillin"]:
352
+ warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Penicillin. High risk with '{potential_prescription}'.")
353
+ elif allergy == "sulfa" and potential_med_lower in ["sulfamethoxazole", "bactrim", "sulfasalazine"]:
354
+ warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Sulfa. High risk with '{potential_prescription}'.")
355
+ elif allergy in ["nsaids", "aspirin"] and potential_med_lower in ["ibuprofen", "naproxen", "ketorolac", "diclofenac"]:
356
+ warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to NSAIDs/Aspirin. Risk with '{potential_prescription}'.")
357
+
358
+ # Check OpenFDA Label for Contraindications/Warnings related to ALLERGIES
359
+ if potential_label:
360
+ contraindications = potential_label.get("contraindications")
361
+ warnings_section = potential_label.get("warnings_and_cautions") or potential_label.get("warnings")
362
+
363
+ if contraindications:
364
+ allergy_mentions_ci = search_text_list(contraindications, allergies_lower)
365
+ if allergy_mentions_ci:
366
+ warnings.append(f"ALLERGY RISK (Contraindication Found): Label for '{potential_prescription}' mentions contraindication potentially related to patient allergies: {'; '.join(allergy_mentions_ci)}")
367
+
368
+ if warnings_section:
369
+ allergy_mentions_warn = search_text_list(warnings_section, allergies_lower)
370
+ if allergy_mentions_warn:
371
+ warnings.append(f"ALLERGY RISK (Warning Found): Label for '{potential_prescription}' mentions warnings potentially related to patient allergies: {'; '.join(allergy_mentions_warn)}")
372
+
373
+ # --- Step 3: Drug-Drug Interaction Check ---
374
+ print(" Step 3: Performing Drug-Drug Interaction Check...")
375
+ if potential_rxcui or potential_label: # Proceed only if we have info on the potential drug
376
+ for current_med_name in current_med_names_lower:
377
+ if not current_med_name or current_med_name == potential_med_lower: continue # Skip empty or self-interaction
378
+
379
+ print(f" Checking interaction between '{potential_prescription}' and '{current_med_name}'...")
380
+ current_rxcui = get_rxcui(current_med_name)
381
+ current_label = get_openfda_label(rxcui=current_rxcui, drug_name=current_med_name)
382
+
383
+ # Terms to search for in interaction text
384
+ search_terms_for_current = [current_med_name]
385
+ if current_rxcui: search_terms_for_current.append(current_rxcui) # Add RxCUI if found
386
+
387
+ search_terms_for_potential = [potential_med_lower]
388
+ if potential_rxcui: search_terms_for_potential.append(potential_rxcui) # Add RxCUI if found
389
+
390
+ interaction_found_flag = False
391
+ # Check Potential Drug's Label ('drug_interactions' section) for mentions of Current Drug
392
+ if potential_label and potential_label.get("drug_interactions"):
393
+ interaction_mentions = search_text_list(potential_label.get("drug_interactions"), search_terms_for_current)
394
+ if interaction_mentions:
395
+ warnings.append(f"Potential Interaction ({potential_prescription.capitalize()} Label): Mentions '{current_med_name.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")
396
+ interaction_found_flag = True
397
+
398
+ # Check Current Drug's Label ('drug_interactions' section) for mentions of Potential Drug
399
+ if current_label and current_label.get("drug_interactions") and not interaction_found_flag: # Avoid duplicate warnings if already found
400
+ interaction_mentions = search_text_list(current_label.get("drug_interactions"), search_terms_for_potential)
401
+ if interaction_mentions:
402
+ warnings.append(f"Potential Interaction ({current_med_name.capitalize()} Label): Mentions '{potential_prescription.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")
403
+
404
+ else: # Case where potential drug wasn't identified
405
+ warnings.append(f"INFO: Drug-drug interaction check skipped for '{potential_prescription}' as it could not be identified via RxNorm/OpenFDA.")
406
+
407
+
408
+ # --- Step 4: Format Output ---
409
+ final_warnings = list(set(warnings)) # Remove duplicates
410
+ status = "warning" if any("CRITICAL" in w or "Interaction" in w or "RISK" in w for w in final_warnings) else "clear"
411
+ if not final_warnings: status = "clear" # Ensure clear if no warnings remain
412
+
413
+ message = f"Interaction/Allergy check for '{potential_prescription}': {len(final_warnings)} potential issue(s) identified using RxNorm/OpenFDA." if final_warnings else f"No major interactions or allergy issues identified for '{potential_prescription}' based on RxNorm/OpenFDA lookup."
414
+ print(f"--- Interaction Check Complete for '{potential_prescription}' ---")
415
+
416
+ return json.dumps({"status": status, "message": message, "warnings": final_warnings})
417
+ # --- End of NEW Interaction Check Tool ---
418
 
419
  @tool("flag_risk", args_schema=FlagRiskInput)
420
  def flag_risk(risk_description: str, urgency: str) -> str:
421
  """Flags a critical risk identified during analysis for immediate attention."""
422
  print(f"Executing flag_risk: {risk_description}, Urgency: {urgency}")
 
423
  st.error(f"🚨 **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="🚨")
424
+ return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
 
 
 
425
 
426
  # Initialize Search Tool
427
  search_tool = TavilySearchResults(
428
  max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS,
429
+ name="tavily_search_results"
430
  )
431
 
432
  # --- LangGraph Setup ---
433
 
434
  # Define the state structure
435
  class AgentState(TypedDict):
436
+ messages: Annotated[list[Any], operator.add]
437
+ patient_data: Optional[dict]
438
 
439
  # Define Tools and Tool Executor
440
  tools = [
441
  order_lab_test,
442
  prescribe_medication,
443
+ check_drug_interactions, # Using the new implementation
444
  flag_risk,
445
  search_tool
446
  ]
 
450
  model = ChatGroq(
451
  temperature=ClinicalAppSettings.TEMPERATURE,
452
  model=ClinicalAppSettings.MODEL_NAME,
 
 
453
  )
 
454
  model_with_tools = model.bind_tools(tools)
455
 
456
+ # --- Graph Nodes (agent_node, tool_node remain mostly the same structurally) ---
457
 
458
+ # 1. Agent Node: Calls the LLM (No change needed from previous version)
459
  def agent_node(state: AgentState):
460
  """Invokes the LLM to decide the next action or response."""
461
  print("\n---AGENT NODE---")
462
  current_messages = state['messages']
 
463
  if not current_messages or not isinstance(current_messages[0], SystemMessage):
464
  print("Prepending System Prompt.")
465
  current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
 
 
 
 
 
 
 
 
 
 
466
  print(f"Invoking LLM with {len(current_messages)} messages.")
 
467
  try:
468
  response = model_with_tools.invoke(current_messages)
469
  print(f"Agent Raw Response Type: {type(response)}")
470
+ if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}")
471
+ else: print("Agent Response: No tool calls.")
 
 
 
 
472
  except Exception as e:
473
  print(f"ERROR in agent_node during LLM invocation: {type(e).__name__} - {e}")
474
+ traceback.print_exc()
 
475
  error_message = AIMessage(content=f"Sorry, an internal error occurred while processing the request: {type(e).__name__}")
476
  return {"messages": [error_message]}
 
477
  return {"messages": [response]}
478
 
479
+ # 2. Tool Node: Executes tools (Mostly the same, ensures context injection)
480
  def tool_node(state: AgentState):
481
  """Executes tools called by the LLM and returns results."""
482
  print("\n---TOOL NODE---")
483
+ tool_messages = []
484
  last_message = state['messages'][-1]
485
 
 
486
  if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None):
487
+ print("Warning: Tool node called unexpectedly without tool calls.")
 
 
 
488
  return {"messages": []}
489
 
490
  tool_calls = last_message.tool_calls
491
+ print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}")
 
 
 
 
492
 
493
+ # Safety Check Logic (No change needed from previous version)
494
+ prescriptions_requested = {}
495
+ interaction_checks_requested = {}
496
  for call in tool_calls:
497
+ tool_name = call.get('name'); tool_args = call.get('args', {})
 
498
  if tool_name == 'prescribe_medication':
499
+ med_name = tool_args.get('medication_name', '').lower();
500
+ if med_name: prescriptions_requested[med_name] = call
 
501
  elif tool_name == 'check_drug_interactions':
502
  potential_med = tool_args.get('potential_prescription', '').lower()
503
+ if potential_med: interaction_checks_requested[potential_med] = call
 
504
 
505
  valid_tool_calls_for_execution = []
506
+ blocked_ids = set()
 
507
  for med_name, prescribe_call in prescriptions_requested.items():
508
  if med_name not in interaction_checks_requested:
509
  st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked.")
510
+ error_msg = ToolMessage(content=json.dumps({"status": "error", "message": f"Interaction check for '{med_name}' must be requested *before or alongside* the prescription call."}), tool_call_id=prescribe_call['id'], name=prescribe_call['name'])
 
 
 
 
511
  tool_messages.append(error_msg)
512
+ blocked_ids.add(prescribe_call['id'])
 
 
513
 
 
 
514
  valid_tool_calls_for_execution = [call for call in tool_calls if call['id'] not in blocked_ids]
515
 
516
+ # Augment interaction checks with patient data (Crucial part - no change needed here)
517
+ patient_data = state.get("patient_data", {})
518
+ patient_meds_full = patient_data.get("medications", {}).get("current", []) # Pass full med list if needed by tool
519
+ patient_allergies = patient_data.get("allergies", [])
520
 
521
  for call in valid_tool_calls_for_execution:
522
  if call['name'] == 'check_drug_interactions':
 
523
  if 'args' not in call: call['args'] = {}
524
+ # Pass the necessary context from patient_data to the tool arguments
525
+ # The tool function expects 'current_medications' (list of names) and 'allergies'
526
+ call['args']['current_medications'] = patient_meds_full # Pass the full strings
527
  call['args']['allergies'] = patient_allergies
528
+ print(f"Augmented interaction check args for call ID {call['id']}") # Removed args content for brevity
 
529
 
530
+ # Execute valid tool calls (No change needed from previous version)
531
  if valid_tool_calls_for_execution:
532
  print(f"Attempting to execute {len(valid_tool_calls_for_execution)} tools: {[c['name'] for c in valid_tool_calls_for_execution]}")
533
  try:
534
  responses = tool_executor.batch(valid_tool_calls_for_execution, return_exceptions=True)
 
 
535
  for call, resp in zip(valid_tool_calls_for_execution, responses):
536
+ tool_call_id = call['id']; tool_name = call['name']
 
 
537
  if isinstance(resp, Exception):
538
+ error_type = type(resp).__name__; error_str = str(resp)
 
 
539
  print(f"ERROR executing tool '{tool_name}' (ID: {tool_call_id}): {error_type} - {error_str}")
540
+ traceback.print_exc()
541
  st.error(f"Error executing action '{tool_name}': {error_type}")
542
+ error_content = json.dumps({"status": "error", "message": f"Failed to execute '{tool_name}': {error_type} - {error_str}"})
 
 
 
543
  tool_messages.append(ToolMessage(content=error_content, tool_call_id=tool_call_id, name=tool_name))
 
 
544
  if isinstance(resp, AttributeError) and "'dict' object has no attribute 'tool'" in error_str:
545
+ print("\n *** DETECTED SPECIFIC ATTRIBUTE ERROR ('dict' object has no attribute 'tool') *** \n")
 
 
 
 
546
  else:
547
+ print(f"Tool '{tool_name}' (ID: {tool_call_id}) executed successfully.")
 
 
548
  content_str = str(resp)
549
  tool_messages.append(ToolMessage(content=content_str, tool_call_id=tool_call_id, name=tool_name))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  except Exception as e:
551
  print(f"CRITICAL UNEXPECTED ERROR within tool_node logic: {type(e).__name__} - {e}")
552
+ traceback.print_exc(); st.error(f"Critical internal error processing actions: {e}")
 
 
553
  error_content = json.dumps({"status": "error", "message": f"Internal error processing tools: {e}"})
 
554
  processed_ids = {msg.tool_call_id for msg in tool_messages}
555
  for call in valid_tool_calls_for_execution:
556
+ if call['id'] not in processed_ids: tool_messages.append(ToolMessage(content=error_content, tool_call_id=call['id'], name=call['name']))
 
557
 
558
  print(f"Returning {len(tool_messages)} tool messages.")
 
559
  return {"messages": tool_messages}
560
 
561
 
562
+ # --- Graph Edges (Routing Logic) --- (No change needed)
563
  def should_continue(state: AgentState) -> str:
564
  """Determines whether to call tools, end the conversation turn, or handle errors."""
565
  print("\n---ROUTING DECISION---")
566
  last_message = state['messages'][-1] if state['messages'] else None
567
+ if not isinstance(last_message, AIMessage): return "end_conversation_turn"
568
+ if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn"
569
+ if getattr(last_message, 'tool_calls', None): return "continue_tools"
570
+ else: return "end_conversation_turn"
571
 
572
+ # --- Graph Definition & Compilation --- (No change needed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
  workflow = StateGraph(AgentState)
 
 
574
  workflow.add_node("agent", agent_node)
575
  workflow.add_node("tools", tool_node)
 
 
576
  workflow.set_entry_point("agent")
577
+ workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
 
 
 
 
 
 
 
 
 
 
 
578
  workflow.add_edge("tools", "agent")
 
 
 
 
579
  app = workflow.compile()
580
  print("LangGraph compiled successfully.")
581
 
 
583
  def main():
584
  st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
585
  st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
586
+ st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME}")
587
 
588
+ # Initialize session state (No change needed)
589
+ if "messages" not in st.session_state: st.session_state.messages = []
590
+ if "patient_data" not in st.session_state: st.session_state.patient_data = None
591
+ if "graph_app" not in st.session_state: st.session_state.graph_app = app
 
 
 
592
 
593
+ # --- Patient Data Input Sidebar --- (Adjusted allergy/med extraction slightly)
594
  with st.sidebar:
595
  st.header("πŸ“„ Patient Intake Form")
596
+ # Demographics, HPI, History, Social/Family, Vitals/Exam sections remain the same input fields
597
+ # ... (Copy input fields from previous full code version) ...
598
  st.subheader("Demographics")
599
  age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
600
  sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")
 
601
  st.subheader("History of Present Illness (HPI)")
602
  chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
603
+ hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago...", key="hpi_input", height=150)
604
  symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough", "Severe Headache", "Syncope", "Hemoptysis"], default=["Nausea", "Diaphoresis"], key="sym_input")
 
605
  st.subheader("Past History")
606
  pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2), History of MI", key="pmh_input")
607
  psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")
 
608
  st.subheader("Medications & Allergies")
609
  current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
610
  allergies_str = st.text_area("Allergies (comma separated, specify reaction if known)", "Penicillin (rash), Sulfa (hives)", key="allergy_input")
 
611
  st.subheader("Social/Family History")
612
  social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
613
  family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
 
614
  st.subheader("Vitals & Exam Findings")
615
  col1, col2 = st.columns(2)
616
  with col1:
 
621
  bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
622
  spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
623
  pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
624
+ exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3. Mild distress. Lungs clear bilaterally...", key="exam_input", height=100)
625
 
626
+
627
+ # Compile Patient Data Dictionary (Refined Extraction for Tool Use)
628
  if st.button("Start/Update Consultation", key="start_button"):
629
+ # Store full medication strings for display/context
630
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
631
+ # Extract just the names (simplified) for the interaction check tool's state population
632
+ current_med_names_only = []
633
  for med in current_meds_list:
634
  match = re.match(r"^\s*([a-zA-Z\-]+)", med)
635
+ if match: current_med_names_only.append(match.group(1).lower())
 
636
 
637
+ # Extract allergy names (simplified, before parenthesis)
638
  allergies_list = []
639
  for a in allergies_str.split(','):
640
  cleaned_allergy = a.strip()
641
  if cleaned_allergy:
642
+ match = re.match(r"^\s*([a-zA-Z\-\s/]+)(?:\s*\(.*\))?", cleaned_allergy)
643
+ name_part = match.group(1).strip().lower() if match else cleaned_allergy.lower()
644
+ allergies_list.append(name_part)
 
 
645
 
646
  st.session_state.patient_data = {
647
  "demographics": {"age": age, "sex": sex},
648
  "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
649
  "pmh": {"conditions": pmh}, "psh": {"procedures": psh},
650
+ # Store both full list and names_only list
651
+ "medications": {"current": current_meds_list, "names_only": current_med_names_only},
652
+ "allergies": allergies_list, # Store cleaned list
653
  "social_history": {"details": social_history}, "family_history": {"details": family_history},
654
  "vitals": { "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale},
655
  "exam_findings": {"notes": exam_notes}
656
  }
657
 
658
+ # Initial Red Flag Check
659
  red_flags = check_red_flags(st.session_state.patient_data)
660
  st.sidebar.markdown("---")
661
  if red_flags:
662
  st.sidebar.warning("**Initial Red Flags Detected:**")
663
  for flag in red_flags: st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}")
664
+ else: st.sidebar.success("No immediate red flags detected.")
 
665
 
666
+ # Prepare initial message & reset history
667
+ initial_prompt = "Initiate consultation for the patient described in the intake form. Review the data and begin analysis."
 
668
  st.session_state.messages = [HumanMessage(content=initial_prompt)]
669
+ st.success("Patient data loaded/updated. Ready for analysis.")
670
+
671
 
672
+ # --- Main Chat Interface Area --- (No change needed in display logic)
673
  st.header("πŸ’¬ Clinical Consultation")
674
 
675
  # Display chat messages from history
676
+ # (Copy the message display loop from the previous full code version)
677
  for msg_index, msg in enumerate(st.session_state.messages):
678
+ unique_key = f"msg_{msg_index}"
679
  if isinstance(msg, HumanMessage):
680
+ with st.chat_message("user", key=f"{unique_key}_user"): st.markdown(msg.content)
 
681
  elif isinstance(msg, AIMessage):
682
  with st.chat_message("assistant", key=f"{unique_key}_ai"):
683
+ ai_content = msg.content; structured_output = None
 
 
 
 
684
  try:
 
685
  json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
686
  if json_match:
687
+ json_str = json_match.group(1); prefix = ai_content[:json_match.start()].strip(); suffix = ai_content[json_match.end():].strip()
 
 
 
688
  if prefix: st.markdown(prefix)
689
  structured_output = json.loads(json_str)
690
  if suffix: st.markdown(suffix)
 
691
  elif ai_content.strip().startswith("{") and ai_content.strip().endswith("}"):
692
+ structured_output = json.loads(ai_content); ai_content = ""
693
+ else: st.markdown(ai_content)
694
+ except Exception as e: st.markdown(ai_content); print(f"Error parsing/displaying AI JSON: {e}")
695
+
 
 
 
 
 
 
 
 
 
 
 
696
  if structured_output and isinstance(structured_output, dict):
697
+ # (Copy the structured JSON display logic from previous full code)
698
+ st.divider(); st.subheader("πŸ“Š AI Analysis & Recommendations")
699
  cols = st.columns(2)
700
  with cols[0]:
701
+ st.markdown(f"**Assessment:**"); st.markdown(f"> {structured_output.get('assessment', 'N/A')}")
 
 
702
  st.markdown(f"**Differential Diagnosis:**")
703
+ ddx = structured_output.get('differential_diagnosis', []);
704
  if ddx:
705
  for item in ddx:
706
+ likelihood = item.get('likelihood', '?').capitalize(); icon = "πŸ₯‡" if likelihood=="High" else ("πŸ₯ˆ" if likelihood=="Medium" else "πŸ₯‰")
707
+ with st.expander(f"{icon} {item.get('diagnosis', 'Unknown')} ({likelihood})"): st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
708
+ else: st.info("No DDx provided.")
709
+ st.markdown(f"**Risk Assessment:**"); risk = structured_output.get('risk_assessment', {})
710
+ flags = risk.get('identified_red_flags', []); concerns = risk.get("immediate_concerns", []); comps = risk.get("potential_complications", [])
 
 
 
 
711
  if flags: st.warning(f"**Flags:** {', '.join(flags)}")
712
+ if concerns: st.warning(f"**Concerns:** {', '.join(concerns)}")
713
+ if comps: st.info(f"**Potential Complications:** {', '.join(comps)}")
714
+ if not flags and not concerns: st.success("No major risks highlighted.")
 
715
  with cols[1]:
716
+ st.markdown(f"**Recommended Plan:**"); plan = structured_output.get('recommended_plan', {})
717
+ for section in ["investigations", "therapeutics", "consultations", "patient_education"]:
718
+ st.markdown(f"_{section.replace('_',' ').capitalize()}:_"); items = plan.get(section)
719
+ if items: [st.markdown(f"- {item}") for item in items] if isinstance(items, list) else st.markdown(f"- {items}")
 
 
 
 
 
 
720
  else: st.markdown("_None suggested._")
721
+ st.markdown("") # Space
722
+ st.markdown(f"**Rationale & Guideline Check:**"); st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
 
 
 
723
  interaction_summary = structured_output.get("interaction_check_summary", "")
724
+ if interaction_summary: st.markdown(f"**Interaction Check Summary:**"); st.markdown(f"> {interaction_summary}")
 
 
 
725
  st.divider()
726
 
 
727
  if getattr(msg, 'tool_calls', None):
728
  with st.expander("πŸ› οΈ AI requested actions", expanded=False):
729
  for tc in msg.tool_calls:
730
+ try: st.code(f"Action: {tc.get('name', 'Unknown')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
731
+ except Exception as display_e: st.error(f"Could not display tool call: {display_e}"); st.code(str(tc))
 
 
 
 
732
 
733
  elif isinstance(msg, ToolMessage):
734
+ tool_name_display = getattr(msg, 'name', 'tool_execution')
 
735
  with st.chat_message(tool_name_display, avatar="πŸ› οΈ", key=f"{unique_key}_tool"):
736
+ # (Copy the ToolMessage display logic from previous full code)
737
  try:
738
+ tool_data = json.loads(msg.content); status = tool_data.get("status", "info"); message = tool_data.get("message", msg.content)
739
+ details = tool_data.get("details"); warnings = tool_data.get("warnings")
740
+ if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="βœ…" if status != "flagged" else "🚨")
 
 
 
 
 
 
741
  elif status == "warning":
742
  st.warning(f"{message}", icon="⚠️")
743
  if warnings and isinstance(warnings, list):
744
  st.caption("Details:")
745
+ for warn in warnings: st.caption(f"- {warn}") # Display warnings from the tool output JSON
746
+ else: st.error(f"{message}", icon="❌")
 
 
747
  if details: st.caption(f"Details: {details}")
748
+ except json.JSONDecodeError: st.info(f"{msg.content}") # Display raw if not JSON
749
+ except Exception as e: st.error(f"Error displaying tool message: {e}", icon="❌"); st.caption(f"Raw content: {msg.content}")
750
 
 
 
 
 
 
 
751
 
752
+ # --- Chat Input Logic --- (No change needed)
753
  if prompt := st.chat_input("Your message or follow-up query..."):
754
  if not st.session_state.patient_data:
755
+ st.warning("Please load patient data using the sidebar first."); st.stop()
 
756
 
 
757
  user_message = HumanMessage(content=prompt)
758
  st.session_state.messages.append(user_message)
759
+ with st.chat_message("user"): st.markdown(prompt)
 
760
 
761
+ current_state = AgentState(messages=st.session_state.messages, patient_data=st.session_state.patient_data)
 
 
 
 
762
 
 
763
  with st.spinner("SynapseAI is thinking..."):
764
  try:
765
+ final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
 
 
 
 
 
766
  st.session_state.messages = final_state['messages']
 
767
  except Exception as e:
768
+ print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}"); traceback.print_exc()
 
769
  st.error(f"An error occurred during the conversation turn: {e}", icon="❌")
770
+ # Optionally add error to history for user visibility
771
+ # error_ai_msg = AIMessage(content=f"Sorry, a critical error occurred: {type(e).__name__}. Please check logs or try again.")
772
+ # st.session_state.messages.append(error_ai_msg)
 
 
773
 
774
+ st.rerun() # Refresh display
 
775
 
776
+ # Disclaimer (No change needed)
777
  st.markdown("---")
778
+ st.warning("""**Disclaimer:** SynapseAI is an AI assistant... (Verify all outputs)""")
 
 
 
 
779
 
780
  if __name__ == "__main__":
781
  main()