mgbam commited on
Commit
9988477
Β·
verified Β·
1 Parent(s): b34efbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -390
app.py CHANGED
@@ -11,7 +11,7 @@ from dotenv import load_dotenv
11
  from langchain_groq import ChatGroq
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
14
- from langchain_core.prompts import ChatPromptTemplate
15
  from langchain_core.pydantic_v1 import BaseModel, Field
16
  from langchain_core.tools import tool
17
  from langgraph.prebuilt import ToolExecutor
@@ -33,9 +33,12 @@ if not GROQ_API_KEY: missing_keys.append("GROQ_API_KEY")
33
  if not TAVILY_API_KEY: missing_keys.append("TAVILY_API_KEY")
34
 
35
  if missing_keys:
 
36
  st.error(f"Missing required API Key(s): {', '.join(missing_keys)}. Please set them in Hugging Face Space Secrets or your environment variables.")
 
37
  st.stop()
38
 
 
39
  # --- Configuration & Constants ---
40
  class ClinicalAppSettings:
41
  APP_TITLE = "SynapseAI: Interactive Clinical Decision Support (UMLS/FDA Integrated)"
@@ -103,99 +106,57 @@ def get_rxcui(drug_name: str) -> Optional[str]:
103
  response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
104
  response.raise_for_status()
105
  data = response.json()
106
- # Extract RxCUI - prioritize exact matches or common types
107
  if data and "idGroup" in data and "rxnormId" in data["idGroup"]:
108
- # Select the first one, assuming it's the most relevant by default.
109
- # More sophisticated logic could check TTYs (Term Types) if needed.
110
  rxcui = data["idGroup"]["rxnormId"][0]
111
  print(f" Found RxCUI: {rxcui} for '{drug_name}'")
112
  return rxcui
113
- else:
114
- # Fallback: Search /drugs endpoint if direct rxcui lookup fails
115
- params = {"name": drug_name}
116
- response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
117
- response.raise_for_status()
118
- data = response.json()
119
  if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
120
  for group in data["drugGroup"]["conceptGroup"]:
121
- # Prioritize Semantic Types like Brand/Clinical Drug/Ingredient
122
  if group.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
123
  if "conceptProperties" in group and group["conceptProperties"]:
124
  rxcui = group["conceptProperties"][0].get("rxcui")
125
- if rxcui:
126
- print(f" Found RxCUI (via /drugs): {rxcui} for '{drug_name}'")
127
- return rxcui
128
  print(f" RxCUI not found for '{drug_name}'.")
129
  return None
130
- except requests.exceptions.RequestException as e:
131
- print(f" Error fetching RxCUI for '{drug_name}': {e}")
132
- return None
133
- except json.JSONDecodeError as e:
134
- print(f" Error decoding RxNorm JSON response for '{drug_name}': {e}")
135
- return None
136
- except Exception as e: # Catch any other unexpected error
137
- print(f" Unexpected error in get_rxcui for '{drug_name}': {e}")
138
- return None
139
 
140
  @lru_cache(maxsize=128) # Cache OpenFDA lookups
141
  def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
142
  """Fetches drug label information from OpenFDA using RxCUI or drug name."""
143
  if not rxcui and not drug_name: return None
144
  print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}")
145
-
146
  search_terms = []
147
- # Prioritize RxCUI lookup using multiple potential fields
148
- if rxcui:
149
- search_terms.append(f'spl_rxnorm_code:"{rxcui}"')
150
- search_terms.append(f'openfda.rxcui:"{rxcui}"')
151
- # Add name search as fallback or supplement
152
- if drug_name:
153
- search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
154
-
155
- search_query = " OR ".join(search_terms)
156
- params = {"search": search_query, "limit": 1} # Get only the most relevant label
157
-
158
  try:
159
  response = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
160
- response.raise_for_status()
161
- data = response.json()
162
- if data and "results" in data and data["results"]:
163
- print(f" Found OpenFDA label for query: {search_query}")
164
- return data["results"][0] # Return the first label found
165
- print(f" No OpenFDA label found for query: {search_query}")
166
- return None
167
- except requests.exceptions.RequestException as e:
168
- print(f" Error fetching OpenFDA label: {e}")
169
- return None
170
- except json.JSONDecodeError as e:
171
- print(f" Error decoding OpenFDA JSON response: {e}")
172
- return None
173
- except Exception as e:
174
- print(f" Unexpected error in get_openfda_label: {e}")
175
- return None
176
 
177
  def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
178
  """ Case-insensitive search for any search_term within a list of text strings. Returns snippets. """
179
  found_snippets = []
180
  if not text_list or not search_terms: return found_snippets
181
- # Ensure search terms are lowercased strings
182
  search_terms_lower = [str(term).lower() for term in search_terms if term]
183
-
184
  for text_item in text_list:
185
- if not isinstance(text_item, str): continue # Skip non-string items
186
  text_item_lower = text_item.lower()
187
  for term in search_terms_lower:
188
  if term in text_item_lower:
189
- # Find the start index of the term
190
- start_index = text_item_lower.find(term)
191
- # Define snippet boundaries (e.g., 50 chars before, 100 after)
192
- snippet_start = max(0, start_index - 50)
193
- snippet_end = min(len(text_item), start_index + len(term) + 100)
194
- snippet = text_item[snippet_start:snippet_end]
195
- # Add indication of where the match is
196
- snippet = snippet.replace(term, f"**{term}**", 1) # Highlight first match
197
- found_snippets.append(f"...{snippet}...")
198
- break # Move to the next text item once a match is found
199
  return found_snippets
200
 
201
  # --- Other Helper Functions ---
@@ -208,7 +169,6 @@ def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
208
 
209
  def check_red_flags(patient_data: dict) -> List[str]:
210
  """Checks patient data against predefined red flags."""
211
- # (Keep the implementation from the previous full code listing)
212
  flags = []
213
  if not patient_data: return flags
214
  symptoms = patient_data.get("hpi", {}).get("symptoms", [])
@@ -243,12 +203,10 @@ def check_red_flags(patient_data: dict) -> List[str]:
243
  if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
244
  if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
245
 
246
- return list(set(flags)) # Unique flags
247
-
248
 
249
  def format_patient_data_for_prompt(data: dict) -> str:
250
  """Formats the patient dictionary into a readable string for the LLM."""
251
- # (Keep the implementation from the previous full code listing)
252
  if not data: return "No patient data provided."
253
  prompt_str = ""
254
  for key, value in data.items():
@@ -268,7 +226,7 @@ def format_patient_data_for_prompt(data: dict) -> str:
268
 
269
  # --- Tool Definitions ---
270
 
271
- # Pydantic models for tool inputs
272
  class LabOrderInput(BaseModel):
273
  test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis', 'D-dimer').")
274
  reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS', 'Assess for PE').")
@@ -282,8 +240,6 @@ class PrescriptionInput(BaseModel):
282
  duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Ongoing', 'Until follow-up').")
283
  reason: str = Field(..., description="Clinical indication for the prescription.")
284
 
285
- # Updated InteractionCheckInput - Note: current_medications/allergies are Optional here
286
- # because they are populated by the tool_node from state *before* execution.
287
  class InteractionCheckInput(BaseModel):
288
  potential_prescription: str = Field(..., description="The name of the NEW medication being considered for prescribing.")
289
  current_medications: Optional[List[str]] = Field(None, description="List of patient's current medication names (populated from state).")
@@ -304,10 +260,8 @@ def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> st
304
  def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
305
  """Prescribes a medication with detailed instructions and clinical indication. IMPORTANT: Requires prior interaction check."""
306
  print(f"Executing prescribe_medication: {medication_name} {dosage}...")
307
- # Safety check happens in tool_node *before* this is called.
308
  return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
309
 
310
- # --- NEW Interaction Check Tool using UMLS/RxNorm & OpenFDA ---
311
  @tool("check_drug_interactions", args_schema=InteractionCheckInput)
312
  def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
313
  """
@@ -318,103 +272,57 @@ def check_drug_interactions(potential_prescription: str, current_medications: Op
318
  print(f"Checking potential prescription: '{potential_prescription}'")
319
  warnings = []
320
  potential_med_lower = potential_prescription.lower().strip()
321
-
322
- # Use provided lists or default to empty
323
- current_meds_list = current_medications or []
324
- allergies_list = allergies or []
325
- # Clean and lowercase current med names (basic extraction: first word)
326
  current_med_names_lower = []
327
  for med in current_meds_list:
328
- match = re.match(r"^\s*([a-zA-Z\-]+)", str(med))
329
  if match: current_med_names_lower.append(match.group(1).lower())
330
- # Clean and lowercase allergies
331
  allergies_lower = [str(a).lower().strip() for a in allergies_list if a]
 
332
 
333
- print(f" Against Current Meds (names): {current_med_names_lower}")
334
- print(f" Against Allergies: {allergies_lower}")
335
-
336
- # --- Step 1: Normalize potential prescription ---
337
- print(f" Step 1: Normalizing '{potential_prescription}'...")
338
- potential_rxcui = get_rxcui(potential_prescription)
339
  potential_label = get_openfda_label(rxcui=potential_rxcui, drug_name=potential_prescription)
340
- if not potential_rxcui and not potential_label:
341
- print(f" Warning: Could not find RxCUI or OpenFDA label for '{potential_prescription}'. Interaction check will be limited.")
342
- warnings.append(f"INFO: Could not reliably identify '{potential_prescription}' in standard terminologies/databases. Checks may be incomplete.")
343
 
344
- # --- Step 2: Allergy Check ---
345
- print(" Step 2: Performing Allergy Check...")
346
- # Direct name match against patient's allergy list
347
  for allergy in allergies_lower:
348
- if allergy == potential_med_lower:
349
- warnings.append(f"CRITICAL ALLERGY (Name Match): Patient allergic to '{allergy}'. Potential prescription is '{potential_prescription}'.")
350
- # Basic cross-reactivity check (can be expanded)
351
- elif allergy in ["penicillin", "pcns"] and potential_med_lower in ["amoxicillin", "ampicillin", "augmentin", "piperacillin"]:
352
- warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Penicillin. High risk with '{potential_prescription}'.")
353
- elif allergy == "sulfa" and potential_med_lower in ["sulfamethoxazole", "bactrim", "sulfasalazine"]:
354
- warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Sulfa. High risk with '{potential_prescription}'.")
355
- elif allergy in ["nsaids", "aspirin"] and potential_med_lower in ["ibuprofen", "naproxen", "ketorolac", "diclofenac"]:
356
- warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to NSAIDs/Aspirin. Risk with '{potential_prescription}'.")
357
-
358
- # Check OpenFDA Label for Contraindications/Warnings related to ALLERGIES
359
  if potential_label:
360
- contraindications = potential_label.get("contraindications")
361
- warnings_section = potential_label.get("warnings_and_cautions") or potential_label.get("warnings")
362
-
363
  if contraindications:
364
  allergy_mentions_ci = search_text_list(contraindications, allergies_lower)
365
- if allergy_mentions_ci:
366
- warnings.append(f"ALLERGY RISK (Contraindication Found): Label for '{potential_prescription}' mentions contraindication potentially related to patient allergies: {'; '.join(allergy_mentions_ci)}")
367
-
368
  if warnings_section:
369
  allergy_mentions_warn = search_text_list(warnings_section, allergies_lower)
370
- if allergy_mentions_warn:
371
- warnings.append(f"ALLERGY RISK (Warning Found): Label for '{potential_prescription}' mentions warnings potentially related to patient allergies: {'; '.join(allergy_mentions_warn)}")
372
 
373
- # --- Step 3: Drug-Drug Interaction Check ---
374
  print(" Step 3: Performing Drug-Drug Interaction Check...")
375
- if potential_rxcui or potential_label: # Proceed only if we have info on the potential drug
376
  for current_med_name in current_med_names_lower:
377
- if not current_med_name or current_med_name == potential_med_lower: continue # Skip empty or self-interaction
378
-
379
  print(f" Checking interaction between '{potential_prescription}' and '{current_med_name}'...")
380
- current_rxcui = get_rxcui(current_med_name)
381
- current_label = get_openfda_label(rxcui=current_rxcui, drug_name=current_med_name)
382
-
383
- # Terms to search for in interaction text
384
- search_terms_for_current = [current_med_name]
385
- if current_rxcui: search_terms_for_current.append(current_rxcui) # Add RxCUI if found
386
-
387
- search_terms_for_potential = [potential_med_lower]
388
- if potential_rxcui: search_terms_for_potential.append(potential_rxcui) # Add RxCUI if found
389
-
390
  interaction_found_flag = False
391
- # Check Potential Drug's Label ('drug_interactions' section) for mentions of Current Drug
392
  if potential_label and potential_label.get("drug_interactions"):
393
  interaction_mentions = search_text_list(potential_label.get("drug_interactions"), search_terms_for_current)
394
- if interaction_mentions:
395
- warnings.append(f"Potential Interaction ({potential_prescription.capitalize()} Label): Mentions '{current_med_name.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")
396
- interaction_found_flag = True
397
-
398
- # Check Current Drug's Label ('drug_interactions' section) for mentions of Potential Drug
399
- if current_label and current_label.get("drug_interactions") and not interaction_found_flag: # Avoid duplicate warnings if already found
400
  interaction_mentions = search_text_list(current_label.get("drug_interactions"), search_terms_for_potential)
401
- if interaction_mentions:
402
- warnings.append(f"Potential Interaction ({current_med_name.capitalize()} Label): Mentions '{potential_prescription.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")
403
-
404
- else: # Case where potential drug wasn't identified
405
- warnings.append(f"INFO: Drug-drug interaction check skipped for '{potential_prescription}' as it could not be identified via RxNorm/OpenFDA.")
406
-
407
-
408
- # --- Step 4: Format Output ---
409
- final_warnings = list(set(warnings)) # Remove duplicates
410
- status = "warning" if any("CRITICAL" in w or "Interaction" in w or "RISK" in w for w in final_warnings) else "clear"
411
- if not final_warnings: status = "clear" # Ensure clear if no warnings remain
412
 
 
 
413
  message = f"Interaction/Allergy check for '{potential_prescription}': {len(final_warnings)} potential issue(s) identified using RxNorm/OpenFDA." if final_warnings else f"No major interactions or allergy issues identified for '{potential_prescription}' based on RxNorm/OpenFDA lookup."
414
  print(f"--- Interaction Check Complete for '{potential_prescription}' ---")
415
-
416
  return json.dumps({"status": status, "message": message, "warnings": final_warnings})
417
- # --- End of NEW Interaction Check Tool ---
418
 
419
  @tool("flag_risk", args_schema=FlagRiskInput)
420
  def flag_risk(risk_description: str, urgency: str) -> str:
@@ -424,45 +332,22 @@ def flag_risk(risk_description: str, urgency: str) -> str:
424
  return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
425
 
426
  # Initialize Search Tool
427
- search_tool = TavilySearchResults(
428
- max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS,
429
- name="tavily_search_results"
430
- )
431
 
432
  # --- LangGraph Setup ---
433
-
434
- # Define the state structure
435
  class AgentState(TypedDict):
436
- messages: Annotated[list[Any], operator.add]
437
- patient_data: Optional[dict]
438
-
439
- # Define Tools and Tool Executor
440
- tools = [
441
- order_lab_test,
442
- prescribe_medication,
443
- check_drug_interactions, # Using the new implementation
444
- flag_risk,
445
- search_tool
446
- ]
447
  tool_executor = ToolExecutor(tools)
448
-
449
- # Define the Agent Model
450
- model = ChatGroq(
451
- temperature=ClinicalAppSettings.TEMPERATURE,
452
- model=ClinicalAppSettings.MODEL_NAME,
453
- )
454
  model_with_tools = model.bind_tools(tools)
455
 
456
- # --- Graph Nodes (agent_node, tool_node remain mostly the same structurally) ---
457
-
458
- # 1. Agent Node: Calls the LLM (No change needed from previous version)
459
  def agent_node(state: AgentState):
460
- """Invokes the LLM to decide the next action or response."""
461
  print("\n---AGENT NODE---")
462
  current_messages = state['messages']
463
  if not current_messages or not isinstance(current_messages[0], SystemMessage):
464
- print("Prepending System Prompt.")
465
- current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
466
  print(f"Invoking LLM with {len(current_messages)} messages.")
467
  try:
468
  response = model_with_tools.invoke(current_messages)
@@ -470,64 +355,36 @@ def agent_node(state: AgentState):
470
  if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}")
471
  else: print("Agent Response: No tool calls.")
472
  except Exception as e:
473
- print(f"ERROR in agent_node during LLM invocation: {type(e).__name__} - {e}")
474
- traceback.print_exc()
475
  error_message = AIMessage(content=f"Sorry, an internal error occurred while processing the request: {type(e).__name__}")
476
  return {"messages": [error_message]}
477
  return {"messages": [response]}
478
 
479
- # 2. Tool Node: Executes tools (Mostly the same, ensures context injection)
480
  def tool_node(state: AgentState):
481
- """Executes tools called by the LLM and returns results."""
482
  print("\n---TOOL NODE---")
483
- tool_messages = []
484
- last_message = state['messages'][-1]
485
-
486
  if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None):
487
- print("Warning: Tool node called unexpectedly without tool calls.")
488
- return {"messages": []}
489
-
490
- tool_calls = last_message.tool_calls
491
- print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}")
492
-
493
- # Safety Check Logic (No change needed from previous version)
494
- prescriptions_requested = {}
495
- interaction_checks_requested = {}
496
  for call in tool_calls:
497
  tool_name = call.get('name'); tool_args = call.get('args', {})
498
- if tool_name == 'prescribe_medication':
499
- med_name = tool_args.get('medication_name', '').lower();
500
- if med_name: prescriptions_requested[med_name] = call
501
- elif tool_name == 'check_drug_interactions':
502
- potential_med = tool_args.get('potential_prescription', '').lower()
503
- if potential_med: interaction_checks_requested[potential_med] = call
504
-
505
- valid_tool_calls_for_execution = []
506
- blocked_ids = set()
507
  for med_name, prescribe_call in prescriptions_requested.items():
508
  if med_name not in interaction_checks_requested:
509
  st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked.")
510
  error_msg = ToolMessage(content=json.dumps({"status": "error", "message": f"Interaction check for '{med_name}' must be requested *before or alongside* the prescription call."}), tool_call_id=prescribe_call['id'], name=prescribe_call['name'])
511
- tool_messages.append(error_msg)
512
- blocked_ids.add(prescribe_call['id'])
513
-
514
  valid_tool_calls_for_execution = [call for call in tool_calls if call['id'] not in blocked_ids]
515
-
516
- # Augment interaction checks with patient data (Crucial part - no change needed here)
517
- patient_data = state.get("patient_data", {})
518
- patient_meds_full = patient_data.get("medications", {}).get("current", []) # Pass full med list if needed by tool
519
- patient_allergies = patient_data.get("allergies", [])
520
-
521
  for call in valid_tool_calls_for_execution:
522
  if call['name'] == 'check_drug_interactions':
523
  if 'args' not in call: call['args'] = {}
524
- # Pass the necessary context from patient_data to the tool arguments
525
- # The tool function expects 'current_medications' (list of names) and 'allergies'
526
- call['args']['current_medications'] = patient_meds_full # Pass the full strings
527
- call['args']['allergies'] = patient_allergies
528
- print(f"Augmented interaction check args for call ID {call['id']}") # Removed args content for brevity
529
-
530
- # Execute valid tool calls (No change needed from previous version)
531
  if valid_tool_calls_for_execution:
532
  print(f"Attempting to execute {len(valid_tool_calls_for_execution)} tools: {[c['name'] for c in valid_tool_calls_for_execution]}")
533
  try:
@@ -535,151 +392,79 @@ def tool_node(state: AgentState):
535
  for call, resp in zip(valid_tool_calls_for_execution, responses):
536
  tool_call_id = call['id']; tool_name = call['name']
537
  if isinstance(resp, Exception):
538
- error_type = type(resp).__name__; error_str = str(resp)
539
- print(f"ERROR executing tool '{tool_name}' (ID: {tool_call_id}): {error_type} - {error_str}")
540
- traceback.print_exc()
541
- st.error(f"Error executing action '{tool_name}': {error_type}")
542
- error_content = json.dumps({"status": "error", "message": f"Failed to execute '{tool_name}': {error_type} - {error_str}"})
543
  tool_messages.append(ToolMessage(content=error_content, tool_call_id=tool_call_id, name=tool_name))
544
- if isinstance(resp, AttributeError) and "'dict' object has no attribute 'tool'" in error_str:
545
- print("\n *** DETECTED SPECIFIC ATTRIBUTE ERROR ('dict' object has no attribute 'tool') *** \n")
546
  else:
547
- print(f"Tool '{tool_name}' (ID: {tool_call_id}) executed successfully.")
548
- content_str = str(resp)
549
- tool_messages.append(ToolMessage(content=content_str, tool_call_id=tool_call_id, name=tool_name))
550
  except Exception as e:
551
- print(f"CRITICAL UNEXPECTED ERROR within tool_node logic: {type(e).__name__} - {e}")
552
- traceback.print_exc(); st.error(f"Critical internal error processing actions: {e}")
553
- error_content = json.dumps({"status": "error", "message": f"Internal error processing tools: {e}"})
554
- processed_ids = {msg.tool_call_id for msg in tool_messages}
555
  for call in valid_tool_calls_for_execution:
556
  if call['id'] not in processed_ids: tool_messages.append(ToolMessage(content=error_content, tool_call_id=call['id'], name=call['name']))
 
557
 
558
- print(f"Returning {len(tool_messages)} tool messages.")
559
- return {"messages": tool_messages}
560
-
561
-
562
- # --- Graph Edges (Routing Logic) --- (No change needed)
563
  def should_continue(state: AgentState) -> str:
564
- """Determines whether to call tools, end the conversation turn, or handle errors."""
565
- print("\n---ROUTING DECISION---")
566
- last_message = state['messages'][-1] if state['messages'] else None
567
  if not isinstance(last_message, AIMessage): return "end_conversation_turn"
568
  if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn"
569
  if getattr(last_message, 'tool_calls', None): return "continue_tools"
570
  else: return "end_conversation_turn"
571
 
572
- # --- Graph Definition & Compilation --- (No change needed)
573
- workflow = StateGraph(AgentState)
574
- workflow.add_node("agent", agent_node)
575
- workflow.add_node("tools", tool_node)
576
- workflow.set_entry_point("agent")
577
- workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
578
- workflow.add_edge("tools", "agent")
579
- app = workflow.compile()
580
- print("LangGraph compiled successfully.")
581
 
582
  # --- Streamlit UI ---
583
  def main():
584
  st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
585
  st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
586
  st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME}")
587
-
588
- # Initialize session state (No change needed)
589
  if "messages" not in st.session_state: st.session_state.messages = []
590
  if "patient_data" not in st.session_state: st.session_state.patient_data = None
591
  if "graph_app" not in st.session_state: st.session_state.graph_app = app
592
 
593
- # --- Patient Data Input Sidebar --- (Adjusted allergy/med extraction slightly)
594
  with st.sidebar:
595
  st.header("πŸ“„ Patient Intake Form")
596
- # Demographics, HPI, History, Social/Family, Vitals/Exam sections remain the same input fields
597
- # ... (Copy input fields from previous full code version) ...
598
- st.subheader("Demographics")
599
- age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
600
- sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")
601
- st.subheader("History of Present Illness (HPI)")
602
- chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
603
- hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago...", key="hpi_input", height=150)
604
- symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough", "Severe Headache", "Syncope", "Hemoptysis"], default=["Nausea", "Diaphoresis"], key="sym_input")
605
- st.subheader("Past History")
606
- pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2), History of MI", key="pmh_input")
607
- psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")
608
- st.subheader("Medications & Allergies")
609
- current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
610
- allergies_str = st.text_area("Allergies (comma separated, specify reaction if known)", "Penicillin (rash), Sulfa (hives)", key="allergy_input")
611
- st.subheader("Social/Family History")
612
- social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
613
- family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
614
- st.subheader("Vitals & Exam Findings")
615
- col1, col2 = st.columns(2)
616
- with col1:
617
- temp_c = st.number_input("Temp (Β°C)", 35.0, 42.0, 36.8, format="%.1f", key="temp_input")
618
- hr_bpm = st.number_input("HR (bpm)", 30, 250, 95, key="hr_input")
619
- rr_rpm = st.number_input("RR (rpm)", 5, 50, 18, key="rr_input")
620
- with col2:
621
- bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
622
- spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
623
- pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
624
- exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3. Mild distress. Lungs clear bilaterally...", key="exam_input", height=100)
625
-
626
-
627
- # Compile Patient Data Dictionary (Refined Extraction for Tool Use)
628
- if st.button("Start/Update Consultation", key="start_button"):
629
- # Store full medication strings for display/context
630
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
631
- # Extract just the names (simplified) for the interaction check tool's state population
632
- current_med_names_only = []
633
- for med in current_meds_list:
634
- match = re.match(r"^\s*([a-zA-Z\-]+)", med)
635
- if match: current_med_names_only.append(match.group(1).lower())
636
-
637
- # Extract allergy names (simplified, before parenthesis)
638
  allergies_list = []
639
  for a in allergies_str.split(','):
640
- cleaned_allergy = a.strip()
641
- if cleaned_allergy:
642
- match = re.match(r"^\s*([a-zA-Z\-\s/]+)(?:\s*\(.*\))?", cleaned_allergy)
643
- name_part = match.group(1).strip().lower() if match else cleaned_allergy.lower()
644
- allergies_list.append(name_part)
645
-
646
- st.session_state.patient_data = {
647
- "demographics": {"age": age, "sex": sex},
648
- "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
649
- "pmh": {"conditions": pmh}, "psh": {"procedures": psh},
650
- # Store both full list and names_only list
651
- "medications": {"current": current_meds_list, "names_only": current_med_names_only},
652
- "allergies": allergies_list, # Store cleaned list
653
- "social_history": {"details": social_history}, "family_history": {"details": family_history},
654
- "vitals": { "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale},
655
- "exam_findings": {"notes": exam_notes}
656
- }
657
-
658
- # Initial Red Flag Check
659
- red_flags = check_red_flags(st.session_state.patient_data)
660
- st.sidebar.markdown("---")
661
- if red_flags:
662
- st.sidebar.warning("**Initial Red Flags Detected:**")
663
- for flag in red_flags: st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}")
664
- else: st.sidebar.success("No immediate red flags detected.")
665
-
666
- # Prepare initial message & reset history
667
- initial_prompt = "Initiate consultation for the patient described in the intake form. Review the data and begin analysis."
668
- st.session_state.messages = [HumanMessage(content=initial_prompt)]
669
- st.success("Patient data loaded/updated. Ready for analysis.")
670
-
671
-
672
- # --- Main Chat Interface Area --- (No change needed in display logic)
673
  st.header("πŸ’¬ Clinical Consultation")
674
-
675
- # Display chat messages from history
676
- # (Copy the message display loop from the previous full code version)
677
- for msg_index, msg in enumerate(st.session_state.messages):
678
- unique_key = f"msg_{msg_index}"
679
  if isinstance(msg, HumanMessage):
680
- with st.chat_message("user", key=f"{unique_key}_user"): st.markdown(msg.content)
681
  elif isinstance(msg, AIMessage):
682
- with st.chat_message("assistant", key=f"{unique_key}_ai"):
683
  ai_content = msg.content; structured_output = None
684
  try:
685
  json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
@@ -692,90 +477,58 @@ def main():
692
  structured_output = json.loads(ai_content); ai_content = ""
693
  else: st.markdown(ai_content)
694
  except Exception as e: st.markdown(ai_content); print(f"Error parsing/displaying AI JSON: {e}")
695
-
696
  if structured_output and isinstance(structured_output, dict):
697
- # (Copy the structured JSON display logic from previous full code)
698
- st.divider(); st.subheader("πŸ“Š AI Analysis & Recommendations")
699
  cols = st.columns(2)
700
- with cols[0]:
701
- st.markdown(f"**Assessment:**"); st.markdown(f"> {structured_output.get('assessment', 'N/A')}")
702
- st.markdown(f"**Differential Diagnosis:**")
703
- ddx = structured_output.get('differential_diagnosis', []);
704
- if ddx:
705
- for item in ddx:
706
- likelihood = item.get('likelihood', '?').capitalize(); icon = "πŸ₯‡" if likelihood=="High" else ("πŸ₯ˆ" if likelihood=="Medium" else "πŸ₯‰")
707
- with st.expander(f"{icon} {item.get('diagnosis', 'Unknown')} ({likelihood})"): st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
708
  else: st.info("No DDx provided.")
709
- st.markdown(f"**Risk Assessment:**"); risk = structured_output.get('risk_assessment', {})
710
- flags = risk.get('identified_red_flags', []); concerns = risk.get("immediate_concerns", []); comps = risk.get("potential_complications", [])
711
  if flags: st.warning(f"**Flags:** {', '.join(flags)}")
712
  if concerns: st.warning(f"**Concerns:** {', '.join(concerns)}")
713
  if comps: st.info(f"**Potential Complications:** {', '.join(comps)}")
714
  if not flags and not concerns: st.success("No major risks highlighted.")
715
- with cols[1]:
716
- st.markdown(f"**Recommended Plan:**"); plan = structured_output.get('recommended_plan', {})
717
- for section in ["investigations", "therapeutics", "consultations", "patient_education"]:
718
- st.markdown(f"_{section.replace('_',' ').capitalize()}:_"); items = plan.get(section)
719
- if items: [st.markdown(f"- {item}") for item in items] if isinstance(items, list) else st.markdown(f"- {items}")
720
- else: st.markdown("_None suggested._")
721
- st.markdown("") # Space
722
- st.markdown(f"**Rationale & Guideline Check:**"); st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
723
  interaction_summary = structured_output.get("interaction_check_summary", "")
724
- if interaction_summary: st.markdown(f"**Interaction Check Summary:**"); st.markdown(f"> {interaction_summary}")
725
  st.divider()
726
-
727
  if getattr(msg, 'tool_calls', None):
728
- with st.expander("πŸ› οΈ AI requested actions", expanded=False):
729
- for tc in msg.tool_calls:
730
- try: st.code(f"Action: {tc.get('name', 'Unknown')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
731
- except Exception as display_e: st.error(f"Could not display tool call: {display_e}"); st.code(str(tc))
732
-
733
  elif isinstance(msg, ToolMessage):
734
  tool_name_display = getattr(msg, 'name', 'tool_execution')
735
- with st.chat_message(tool_name_display, avatar="πŸ› οΈ", key=f"{unique_key}_tool"):
736
- # (Copy the ToolMessage display logic from previous full code)
737
- try:
738
- tool_data = json.loads(msg.content); status = tool_data.get("status", "info"); message = tool_data.get("message", msg.content)
739
- details = tool_data.get("details"); warnings = tool_data.get("warnings")
740
  if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="βœ…" if status != "flagged" else "🚨")
741
- elif status == "warning":
742
- st.warning(f"{message}", icon="⚠️")
743
- if warnings and isinstance(warnings, list):
744
- st.caption("Details:")
745
- for warn in warnings: st.caption(f"- {warn}") # Display warnings from the tool output JSON
746
  else: st.error(f"{message}", icon="❌")
747
  if details: st.caption(f"Details: {details}")
748
- except json.JSONDecodeError: st.info(f"{msg.content}") # Display raw if not JSON
749
  except Exception as e: st.error(f"Error displaying tool message: {e}", icon="❌"); st.caption(f"Raw content: {msg.content}")
750
 
751
-
752
- # --- Chat Input Logic --- (No change needed)
753
  if prompt := st.chat_input("Your message or follow-up query..."):
754
- if not st.session_state.patient_data:
755
- st.warning("Please load patient data using the sidebar first."); st.stop()
756
-
757
- user_message = HumanMessage(content=prompt)
758
- st.session_state.messages.append(user_message)
759
- with st.chat_message("user"): st.markdown(prompt)
760
-
761
  current_state = AgentState(messages=st.session_state.messages, patient_data=st.session_state.patient_data)
762
-
763
  with st.spinner("SynapseAI is thinking..."):
764
  try:
765
  final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
766
- st.session_state.messages = final_state['messages']
767
- except Exception as e:
768
- print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}"); traceback.print_exc()
769
- st.error(f"An error occurred during the conversation turn: {e}", icon="❌")
770
- # Optionally add error to history for user visibility
771
- # error_ai_msg = AIMessage(content=f"Sorry, a critical error occurred: {type(e).__name__}. Please check logs or try again.")
772
- # st.session_state.messages.append(error_ai_msg)
773
-
774
  st.rerun() # Refresh display
775
 
776
- # Disclaimer (No change needed)
777
- st.markdown("---")
778
- st.warning("""**Disclaimer:** SynapseAI is an AI assistant... (Verify all outputs)""")
779
 
780
  if __name__ == "__main__":
781
  main()
 
11
  from langchain_groq import ChatGroq
12
  from langchain_community.tools.tavily_search import TavilySearchResults
13
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
14
+ # from langchain_core.prompts import ChatPromptTemplate # Not explicitly used in this version
15
  from langchain_core.pydantic_v1 import BaseModel, Field
16
  from langchain_core.tools import tool
17
  from langgraph.prebuilt import ToolExecutor
 
33
  if not TAVILY_API_KEY: missing_keys.append("TAVILY_API_KEY")
34
 
35
  if missing_keys:
36
+ # Use st.error which stops execution in recent Streamlit versions
37
  st.error(f"Missing required API Key(s): {', '.join(missing_keys)}. Please set them in Hugging Face Space Secrets or your environment variables.")
38
+ # Ensure execution stops if st.error doesn't automatically do it in the environment
39
  st.stop()
40
 
41
+
42
  # --- Configuration & Constants ---
43
  class ClinicalAppSettings:
44
  APP_TITLE = "SynapseAI: Interactive Clinical Decision Support (UMLS/FDA Integrated)"
 
106
  response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
107
  response.raise_for_status()
108
  data = response.json()
 
109
  if data and "idGroup" in data and "rxnormId" in data["idGroup"]:
 
 
110
  rxcui = data["idGroup"]["rxnormId"][0]
111
  print(f" Found RxCUI: {rxcui} for '{drug_name}'")
112
  return rxcui
113
+ else: # Fallback search
114
+ params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
115
+ response.raise_for_status(); data = response.json()
 
 
 
116
  if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
117
  for group in data["drugGroup"]["conceptGroup"]:
 
118
  if group.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
119
  if "conceptProperties" in group and group["conceptProperties"]:
120
  rxcui = group["conceptProperties"][0].get("rxcui")
121
+ if rxcui: print(f" Found RxCUI (via /drugs): {rxcui} for '{drug_name}'"); return rxcui
 
 
122
  print(f" RxCUI not found for '{drug_name}'.")
123
  return None
124
+ except requests.exceptions.RequestException as e: print(f" Error fetching RxCUI for '{drug_name}': {e}"); return None
125
+ except json.JSONDecodeError as e: print(f" Error decoding RxNorm JSON response for '{drug_name}': {e}"); return None
126
+ except Exception as e: print(f" Unexpected error in get_rxcui for '{drug_name}': {e}"); return None
 
 
 
 
 
 
127
 
128
  @lru_cache(maxsize=128) # Cache OpenFDA lookups
129
  def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
130
  """Fetches drug label information from OpenFDA using RxCUI or drug name."""
131
  if not rxcui and not drug_name: return None
132
  print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}")
 
133
  search_terms = []
134
+ if rxcui: search_terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
135
+ if drug_name: search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
136
+ search_query = " OR ".join(search_terms); params = {"search": search_query, "limit": 1}
 
 
 
 
 
 
 
 
137
  try:
138
  response = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
139
+ response.raise_for_status(); data = response.json()
140
+ if data and "results" in data and data["results"]: print(f" Found OpenFDA label for query: {search_query}"); return data["results"][0]
141
+ print(f" No OpenFDA label found for query: {search_query}"); return None
142
+ except requests.exceptions.RequestException as e: print(f" Error fetching OpenFDA label: {e}"); return None
143
+ except json.JSONDecodeError as e: print(f" Error decoding OpenFDA JSON response: {e}"); return None
144
+ except Exception as e: print(f" Unexpected error in get_openfda_label: {e}"); return None
 
 
 
 
 
 
 
 
 
 
145
 
146
  def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
147
  """ Case-insensitive search for any search_term within a list of text strings. Returns snippets. """
148
  found_snippets = []
149
  if not text_list or not search_terms: return found_snippets
 
150
  search_terms_lower = [str(term).lower() for term in search_terms if term]
 
151
  for text_item in text_list:
152
+ if not isinstance(text_item, str): continue
153
  text_item_lower = text_item.lower()
154
  for term in search_terms_lower:
155
  if term in text_item_lower:
156
+ start_index = text_item_lower.find(term); snippet_start = max(0, start_index - 50)
157
+ snippet_end = min(len(text_item), start_index + len(term) + 100); snippet = text_item[snippet_start:snippet_end]
158
+ snippet = snippet.replace(term, f"**{term}**", 1); found_snippets.append(f"...{snippet}...")
159
+ break
 
 
 
 
 
 
160
  return found_snippets
161
 
162
  # --- Other Helper Functions ---
 
169
 
170
  def check_red_flags(patient_data: dict) -> List[str]:
171
  """Checks patient data against predefined red flags."""
 
172
  flags = []
173
  if not patient_data: return flags
174
  symptoms = patient_data.get("hpi", {}).get("symptoms", [])
 
203
  if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
204
  if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
205
 
206
+ return list(set(flags))
 
207
 
208
  def format_patient_data_for_prompt(data: dict) -> str:
209
  """Formats the patient dictionary into a readable string for the LLM."""
 
210
  if not data: return "No patient data provided."
211
  prompt_str = ""
212
  for key, value in data.items():
 
226
 
227
  # --- Tool Definitions ---
228
 
229
+ # Pydantic models
230
  class LabOrderInput(BaseModel):
231
  test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis', 'D-dimer').")
232
  reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS', 'Assess for PE').")
 
240
  duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Ongoing', 'Until follow-up').")
241
  reason: str = Field(..., description="Clinical indication for the prescription.")
242
 
 
 
243
  class InteractionCheckInput(BaseModel):
244
  potential_prescription: str = Field(..., description="The name of the NEW medication being considered for prescribing.")
245
  current_medications: Optional[List[str]] = Field(None, description="List of patient's current medication names (populated from state).")
 
260
  def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
261
  """Prescribes a medication with detailed instructions and clinical indication. IMPORTANT: Requires prior interaction check."""
262
  print(f"Executing prescribe_medication: {medication_name} {dosage}...")
 
263
  return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
264
 
 
265
  @tool("check_drug_interactions", args_schema=InteractionCheckInput)
266
  def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
267
  """
 
272
  print(f"Checking potential prescription: '{potential_prescription}'")
273
  warnings = []
274
  potential_med_lower = potential_prescription.lower().strip()
275
+ current_meds_list = current_medications or []; allergies_list = allergies or []
 
 
 
 
276
  current_med_names_lower = []
277
  for med in current_meds_list:
278
+ match = re.match(r"^\s*([a-zA-Z\-]+)", str(med));
279
  if match: current_med_names_lower.append(match.group(1).lower())
 
280
  allergies_lower = [str(a).lower().strip() for a in allergies_list if a]
281
+ print(f" Against Current Meds (names): {current_med_names_lower}"); print(f" Against Allergies: {allergies_lower}")
282
 
283
+ print(f" Step 1: Normalizing '{potential_prescription}'..."); potential_rxcui = get_rxcui(potential_prescription)
 
 
 
 
 
284
  potential_label = get_openfda_label(rxcui=potential_rxcui, drug_name=potential_prescription)
285
+ if not potential_rxcui and not potential_label: warnings.append(f"INFO: Could not reliably identify '{potential_prescription}'. Checks may be incomplete.")
 
 
286
 
287
+ print(" Step 2: Performing Allergy Check...");
 
 
288
  for allergy in allergies_lower:
289
+ if allergy == potential_med_lower: warnings.append(f"CRITICAL ALLERGY (Name Match): Patient allergic to '{allergy}'. Potential prescription is '{potential_prescription}'.")
290
+ elif allergy in ["penicillin", "pcns"] and potential_med_lower in ["amoxicillin", "ampicillin", "augmentin", "piperacillin"]: warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Penicillin. High risk with '{potential_prescription}'.")
291
+ elif allergy == "sulfa" and potential_med_lower in ["sulfamethoxazole", "bactrim", "sulfasalazine"]: warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Sulfa. High risk with '{potential_prescription}'.")
292
+ elif allergy in ["nsaids", "aspirin"] and potential_med_lower in ["ibuprofen", "naproxen", "ketorolac", "diclofenac"]: warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to NSAIDs/Aspirin. Risk with '{potential_prescription}'.")
 
 
 
 
 
 
 
293
  if potential_label:
294
+ contraindications = potential_label.get("contraindications"); warnings_section = potential_label.get("warnings_and_cautions") or potential_label.get("warnings")
 
 
295
  if contraindications:
296
  allergy_mentions_ci = search_text_list(contraindications, allergies_lower)
297
+ if allergy_mentions_ci: warnings.append(f"ALLERGY RISK (Contraindication Found): Label for '{potential_prescription}' mentions contraindication potentially related to patient allergies: {'; '.join(allergy_mentions_ci)}")
 
 
298
  if warnings_section:
299
  allergy_mentions_warn = search_text_list(warnings_section, allergies_lower)
300
+ if allergy_mentions_warn: warnings.append(f"ALLERGY RISK (Warning Found): Label for '{potential_prescription}' mentions warnings potentially related to patient allergies: {'; '.join(allergy_mentions_warn)}")
 
301
 
 
302
  print(" Step 3: Performing Drug-Drug Interaction Check...")
303
+ if potential_rxcui or potential_label:
304
  for current_med_name in current_med_names_lower:
305
+ if not current_med_name or current_med_name == potential_med_lower: continue
 
306
  print(f" Checking interaction between '{potential_prescription}' and '{current_med_name}'...")
307
+ current_rxcui = get_rxcui(current_med_name); current_label = get_openfda_label(rxcui=current_rxcui, drug_name=current_med_name)
308
+ search_terms_for_current = [current_med_name];
309
+ if current_rxcui: search_terms_for_current.append(current_rxcui)
310
+ search_terms_for_potential = [potential_med_lower];
311
+ if potential_rxcui: search_terms_for_potential.append(potential_rxcui)
 
 
 
 
 
312
  interaction_found_flag = False
 
313
  if potential_label and potential_label.get("drug_interactions"):
314
  interaction_mentions = search_text_list(potential_label.get("drug_interactions"), search_terms_for_current)
315
+ if interaction_mentions: warnings.append(f"Potential Interaction ({potential_prescription.capitalize()} Label): Mentions '{current_med_name.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}"); interaction_found_flag = True
316
+ if current_label and current_label.get("drug_interactions") and not interaction_found_flag:
 
 
 
 
317
  interaction_mentions = search_text_list(current_label.get("drug_interactions"), search_terms_for_potential)
318
+ if interaction_mentions: warnings.append(f"Potential Interaction ({current_med_name.capitalize()} Label): Mentions '{potential_prescription.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")
319
+ else: warnings.append(f"INFO: Drug-drug interaction check skipped for '{potential_prescription}' as it could not be identified via RxNorm/OpenFDA.")
 
 
 
 
 
 
 
 
 
320
 
321
+ final_warnings = list(set(warnings)); status = "warning" if any("CRITICAL" in w or "Interaction" in w or "RISK" in w for w in final_warnings) else "clear"
322
+ if not final_warnings: status = "clear"
323
  message = f"Interaction/Allergy check for '{potential_prescription}': {len(final_warnings)} potential issue(s) identified using RxNorm/OpenFDA." if final_warnings else f"No major interactions or allergy issues identified for '{potential_prescription}' based on RxNorm/OpenFDA lookup."
324
  print(f"--- Interaction Check Complete for '{potential_prescription}' ---")
 
325
  return json.dumps({"status": status, "message": message, "warnings": final_warnings})
 
326
 
327
  @tool("flag_risk", args_schema=FlagRiskInput)
328
  def flag_risk(risk_description: str, urgency: str) -> str:
 
332
  return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
333
 
334
  # Initialize Search Tool
335
+ search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS, name="tavily_search_results")
 
 
 
336
 
337
  # --- LangGraph Setup ---
 
 
338
  class AgentState(TypedDict):
339
+ messages: Annotated[list[Any], operator.add]; patient_data: Optional[dict]
340
+ tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
 
 
 
 
 
 
 
 
 
341
  tool_executor = ToolExecutor(tools)
342
+ model = ChatGroq(temperature=ClinicalAppSettings.TEMPERATURE, model=ClinicalAppSettings.MODEL_NAME)
 
 
 
 
 
343
  model_with_tools = model.bind_tools(tools)
344
 
345
+ # --- Graph Nodes ---
 
 
346
  def agent_node(state: AgentState):
 
347
  print("\n---AGENT NODE---")
348
  current_messages = state['messages']
349
  if not current_messages or not isinstance(current_messages[0], SystemMessage):
350
+ print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
 
351
  print(f"Invoking LLM with {len(current_messages)} messages.")
352
  try:
353
  response = model_with_tools.invoke(current_messages)
 
355
  if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}")
356
  else: print("Agent Response: No tool calls.")
357
  except Exception as e:
358
+ print(f"ERROR in agent_node during LLM invocation: {type(e).__name__} - {e}"); traceback.print_exc()
 
359
  error_message = AIMessage(content=f"Sorry, an internal error occurred while processing the request: {type(e).__name__}")
360
  return {"messages": [error_message]}
361
  return {"messages": [response]}
362
 
 
363
  def tool_node(state: AgentState):
 
364
  print("\n---TOOL NODE---")
365
+ tool_messages = []; last_message = state['messages'][-1]
 
 
366
  if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None):
367
+ print("Warning: Tool node called unexpectedly without tool calls."); return {"messages": []}
368
+ tool_calls = last_message.tool_calls; print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}")
369
+ prescriptions_requested = {}; interaction_checks_requested = {}
 
 
 
 
 
 
370
  for call in tool_calls:
371
  tool_name = call.get('name'); tool_args = call.get('args', {})
372
+ if tool_name == 'prescribe_medication': med_name = tool_args.get('medication_name', '').lower();
373
+ if med_name: prescriptions_requested[med_name] = call
374
+ elif tool_name == 'check_drug_interactions': potential_med = tool_args.get('potential_prescription', '').lower()
375
+ if potential_med: interaction_checks_requested[potential_med] = call
376
+ valid_tool_calls_for_execution = []; blocked_ids = set()
 
 
 
 
377
  for med_name, prescribe_call in prescriptions_requested.items():
378
  if med_name not in interaction_checks_requested:
379
  st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked.")
380
  error_msg = ToolMessage(content=json.dumps({"status": "error", "message": f"Interaction check for '{med_name}' must be requested *before or alongside* the prescription call."}), tool_call_id=prescribe_call['id'], name=prescribe_call['name'])
381
+ tool_messages.append(error_msg); blocked_ids.add(prescribe_call['id'])
 
 
382
  valid_tool_calls_for_execution = [call for call in tool_calls if call['id'] not in blocked_ids]
383
+ patient_data = state.get("patient_data", {}); patient_meds_full = patient_data.get("medications", {}).get("current", []); patient_allergies = patient_data.get("allergies", [])
 
 
 
 
 
384
  for call in valid_tool_calls_for_execution:
385
  if call['name'] == 'check_drug_interactions':
386
  if 'args' not in call: call['args'] = {}
387
+ call['args']['current_medications'] = patient_meds_full; call['args']['allergies'] = patient_allergies; print(f"Augmented interaction check args for call ID {call['id']}")
 
 
 
 
 
 
388
  if valid_tool_calls_for_execution:
389
  print(f"Attempting to execute {len(valid_tool_calls_for_execution)} tools: {[c['name'] for c in valid_tool_calls_for_execution]}")
390
  try:
 
392
  for call, resp in zip(valid_tool_calls_for_execution, responses):
393
  tool_call_id = call['id']; tool_name = call['name']
394
  if isinstance(resp, Exception):
395
+ error_type = type(resp).__name__; error_str = str(resp); print(f"ERROR executing tool '{tool_name}' (ID: {tool_call_id}): {error_type} - {error_str}"); traceback.print_exc()
396
+ st.error(f"Error executing action '{tool_name}': {error_type}"); error_content = json.dumps({"status": "error", "message": f"Failed to execute '{tool_name}': {error_type} - {error_str}"})
 
 
 
397
  tool_messages.append(ToolMessage(content=error_content, tool_call_id=tool_call_id, name=tool_name))
398
+ if isinstance(resp, AttributeError) and "'dict' object has no attribute 'tool'" in error_str: print("\n *** DETECTED SPECIFIC ATTRIBUTE ERROR ('dict' object has no attribute 'tool') *** \n")
 
399
  else:
400
+ print(f"Tool '{tool_name}' (ID: {tool_call_id}) executed successfully."); content_str = str(resp); tool_messages.append(ToolMessage(content=content_str, tool_call_id=tool_call_id, name=tool_name))
 
 
401
  except Exception as e:
402
+ print(f"CRITICAL UNEXPECTED ERROR within tool_node logic: {type(e).__name__} - {e}"); traceback.print_exc(); st.error(f"Critical internal error processing actions: {e}")
403
+ error_content = json.dumps({"status": "error", "message": f"Internal error processing tools: {e}"}); processed_ids = {msg.tool_call_id for msg in tool_messages}
 
 
404
  for call in valid_tool_calls_for_execution:
405
  if call['id'] not in processed_ids: tool_messages.append(ToolMessage(content=error_content, tool_call_id=call['id'], name=call['name']))
406
+ print(f"Returning {len(tool_messages)} tool messages."); return {"messages": tool_messages}
407
 
408
+ # --- Graph Edges (Routing Logic) ---
 
 
 
 
409
  def should_continue(state: AgentState) -> str:
410
+ print("\n---ROUTING DECISION---"); last_message = state['messages'][-1] if state['messages'] else None
 
 
411
  if not isinstance(last_message, AIMessage): return "end_conversation_turn"
412
  if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn"
413
  if getattr(last_message, 'tool_calls', None): return "continue_tools"
414
  else: return "end_conversation_turn"
415
 
416
+ # --- Graph Definition & Compilation ---
417
+ workflow = StateGraph(AgentState); workflow.add_node("agent", agent_node); workflow.add_node("tools", tool_node)
418
+ workflow.set_entry_point("agent"); workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
419
+ workflow.add_edge("tools", "agent"); app = workflow.compile(); print("LangGraph compiled successfully.")
 
 
 
 
 
420
 
421
  # --- Streamlit UI ---
422
  def main():
423
  st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
424
  st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
425
  st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME}")
 
 
426
  if "messages" not in st.session_state: st.session_state.messages = []
427
  if "patient_data" not in st.session_state: st.session_state.patient_data = None
428
  if "graph_app" not in st.session_state: st.session_state.graph_app = app
429
 
430
+ # --- Patient Data Input Sidebar ---
431
  with st.sidebar:
432
  st.header("πŸ“„ Patient Intake Form")
433
+ # Input fields (Demographics, HPI, History, Meds/Allergies, Social/Family, Vitals/Exam)
434
+ st.subheader("Demographics"); age = st.number_input("Age", 0, 120, 55); sex = st.selectbox("Sex", ["Male", "Female", "Other"])
435
+ st.subheader("HPI"); chief_complaint = st.text_input("Chief Complaint", "Chest pain"); hpi_details = st.text_area("HPI Details", "55 y/o male...", height=150); symptoms = st.multiselect("Symptoms", ["Nausea", "Diaphoresis", "SOB", "Dizziness"], default=["Nausea", "Diaphoresis"])
436
+ st.subheader("History"); pmh = st.text_area("PMH", "HTN, HLD, DM2, MI"); psh = st.text_area("PSH", "Appendectomy")
437
+ st.subheader("Meds & Allergies"); current_meds_str = st.text_area("Current Meds", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily"); allergies_str = st.text_area("Allergies", "Penicillin (rash), Sulfa")
438
+ st.subheader("Social/Family"); social_history = st.text_area("SH", "Smoker"); family_history = st.text_area("FHx", "Father MI")
439
+ st.subheader("Vitals & Exam"); col1, col2 = st.columns(2);
440
+ with col1: temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f"); hr_bpm = st.number_input("HR", 30, 250, 95); rr_rpm = st.number_input("RR", 5, 50, 18)
441
+ with col2: bp_mmhg = st.text_input("BP", "155/90"); spo2_percent = st.number_input("SpO2", 70, 100, 96); pain_scale = st.slider("Pain", 0, 10, 8)
442
+ exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=100)
443
+
444
+ if st.button("Start/Update Consultation"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
446
+ current_med_names_only = [];
447
+ for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", med);
448
+ if match: current_med_names_only.append(match.group(1).lower())
 
 
 
 
449
  allergies_list = []
450
  for a in allergies_str.split(','):
451
+ cleaned_allergy = a.strip();
452
+ if cleaned_allergy: match = re.match(r"^\s*([a-zA-Z\-\s/]+)(?:\s*\(.*\))?", cleaned_allergy); name_part = match.group(1).strip().lower() if match else cleaned_allergy.lower(); allergies_list.append(name_part)
453
+ st.session_state.patient_data = { "demographics": {"age": age, "sex": sex}, "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms}, "pmh": {"conditions": pmh}, "psh": {"procedures": psh}, "medications": {"current": current_meds_list, "names_only": current_med_names_only}, "allergies": allergies_list, "social_history": {"details": social_history}, "family_history": {"details": family_history}, "vitals": { "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale}, "exam_findings": {"notes": exam_notes} }
454
+ red_flags = check_red_flags(st.session_state.patient_data); st.sidebar.markdown("---")
455
+ if red_flags: st.sidebar.warning("**Initial Red Flags:**"); [st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}") for flag in red_flags]
456
+ else: st.sidebar.success("No immediate red flags.")
457
+ initial_prompt = "Initiate consultation for the patient described in the intake form. Review data and begin analysis."
458
+ st.session_state.messages = [HumanMessage(content=initial_prompt)]; st.success("Patient data loaded/updated.")
459
+
460
+ # --- Main Chat Interface Area ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  st.header("πŸ’¬ Clinical Consultation")
462
+ # Display loop - REMOVED key= ARGUMENT
463
+ for msg in st.session_state.messages:
 
 
 
464
  if isinstance(msg, HumanMessage):
465
+ with st.chat_message("user"): st.markdown(msg.content) # No key
466
  elif isinstance(msg, AIMessage):
467
+ with st.chat_message("assistant"): # No key
468
  ai_content = msg.content; structured_output = None
469
  try:
470
  json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
 
477
  structured_output = json.loads(ai_content); ai_content = ""
478
  else: st.markdown(ai_content)
479
  except Exception as e: st.markdown(ai_content); print(f"Error parsing/displaying AI JSON: {e}")
 
480
  if structured_output and isinstance(structured_output, dict):
481
+ st.divider(); st.subheader("πŸ“Š AI Analysis & Recommendations") # Display logic for JSON...
 
482
  cols = st.columns(2)
483
+ with cols[0]: # Assessment, DDx, Risk
484
+ st.markdown("**Assessment:**"); st.markdown(f"> {structured_output.get('assessment', 'N/A')}")
485
+ st.markdown("**Differential Diagnosis:**"); ddx = structured_output.get('differential_diagnosis', []);
486
+ if ddx: [st.expander(f"{'πŸ₯‡πŸ₯ˆπŸ₯‰'[('High','Medium','Low').index(item.get('likelihood','Low')[0])] if item.get('likelihood','?')[0] in 'HML' else '?'} {item.get('diagnosis', 'Unknown')} ({item.get('likelihood','?')})").write(f"**Rationale:** {item.get('rationale', 'N/A')}") for item in ddx]
 
 
 
 
487
  else: st.info("No DDx provided.")
488
+ st.markdown("**Risk Assessment:**"); risk = structured_output.get('risk_assessment', {}); flags=risk.get('identified_red_flags',[]); concerns=risk.get("immediate_concerns",[]); comps=risk.get("potential_complications",[])
 
489
  if flags: st.warning(f"**Flags:** {', '.join(flags)}")
490
  if concerns: st.warning(f"**Concerns:** {', '.join(concerns)}")
491
  if comps: st.info(f"**Potential Complications:** {', '.join(comps)}")
492
  if not flags and not concerns: st.success("No major risks highlighted.")
493
+ with cols[1]: # Plan
494
+ st.markdown("**Recommended Plan:**"); plan = structured_output.get('recommended_plan', {})
495
+ for section in ["investigations","therapeutics","consultations","patient_education"]: st.markdown(f"_{section.replace('_',' ').capitalize()}:_"); items = plan.get(section); [st.markdown(f"- {item}") for item in items] if items and isinstance(items, list) else (st.markdown(f"- {items}") if items else st.markdown("_None_")); st.markdown("")
496
+ st.markdown("**Rationale & Guideline Check:**"); st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
 
 
 
 
497
  interaction_summary = structured_output.get("interaction_check_summary", "")
498
+ if interaction_summary: st.markdown("**Interaction Check Summary:**"); st.markdown(f"> {interaction_summary}")
499
  st.divider()
 
500
  if getattr(msg, 'tool_calls', None):
501
+ with st.expander("πŸ› οΈ AI requested actions", expanded=False): # Tool call display logic...
502
+ for tc in msg.tool_calls: try: st.code(f"Action: {tc.get('name', 'Unknown')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
503
+ except Exception as display_e: st.error(f"Could not display tool call: {display_e}"); st.code(str(tc))
 
 
504
  elif isinstance(msg, ToolMessage):
505
  tool_name_display = getattr(msg, 'name', 'tool_execution')
506
+ with st.chat_message(tool_name_display, avatar="πŸ› οΈ"): # No key
507
+ try: # Tool message display logic...
508
+ tool_data = json.loads(msg.content); status = tool_data.get("status", "info"); message = tool_data.get("message", msg.content); details = tool_data.get("details"); warnings = tool_data.get("warnings")
 
 
509
  if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="βœ…" if status != "flagged" else "🚨")
510
+ elif status == "warning": st.warning(f"{message}", icon="⚠️");
511
+ if warnings and isinstance(warnings, list): st.caption("Details:"); [st.caption(f"- {warn}") for warn in warnings]
 
 
 
512
  else: st.error(f"{message}", icon="❌")
513
  if details: st.caption(f"Details: {details}")
514
+ except json.JSONDecodeError: st.info(f"{msg.content}")
515
  except Exception as e: st.error(f"Error displaying tool message: {e}", icon="❌"); st.caption(f"Raw content: {msg.content}")
516
 
517
+ # --- Chat Input Logic ---
 
518
  if prompt := st.chat_input("Your message or follow-up query..."):
519
+ if not st.session_state.patient_data: st.warning("Please load patient data first."); st.stop()
520
+ user_message = HumanMessage(content=prompt); st.session_state.messages.append(user_message)
521
+ with st.chat_message("user"): st.markdown(prompt) # Display user msg immediately
 
 
 
 
522
  current_state = AgentState(messages=st.session_state.messages, patient_data=st.session_state.patient_data)
 
523
  with st.spinner("SynapseAI is thinking..."):
524
  try:
525
  final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
526
+ st.session_state.messages = final_state['messages'] # Update state with results
527
+ except Exception as e: print(f"CRITICAL ERROR: {e}"); traceback.print_exc(); st.error(f"Error: {e}")
 
 
 
 
 
 
528
  st.rerun() # Refresh display
529
 
530
+ # Disclaimer
531
+ st.markdown("---"); st.warning("**Disclaimer:** SynapseAI is for demonstration...")
 
532
 
533
  if __name__ == "__main__":
534
  main()