# SynapseAI / agent.py
# mgbam's picture
# Update agent.py
# 5723e66 verified
# raw
# history blame
# 21.1 kB
import json
import operator
import os
import re
import traceback
from functools import lru_cache
from typing import Annotated, Any, Dict, List, Optional, TypedDict

import requests
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation
# --- Environment Variables ---
# Each key is read once at import time and is None when the variable is unset.
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# --- Agent Configuration ---
# Model identifier and sampling temperature handed to ChatGroq below.
AGENT_MODEL_NAME = "llama3-70b-8192"
AGENT_TEMPERATURE = 0.1
# Upper bound on hits requested from the Tavily search tool.
MAX_SEARCH_RESULTS = 3
# --- System Prompt Definition ---
class ClinicalPrompts:
    """
    Holder for the prompt text that defines SynapseAI's behaviour.

    ``SYSTEM_PROMPT`` is the system message that ``agent_node`` prepends when
    the conversation does not already begin with a ``SystemMessage``.
    """

    # NOTE: the prompt body is deliberately flush-left. Everything between the
    # triple quotes (including the leading and trailing newline) is part of the
    # string sent to the model, so re-indenting it would change the prompt.
    SYSTEM_PROMPT = """
You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
Your goal is to support healthcare professionals by analyzing patient data,
providing differential diagnoses, suggesting evidence-based management plans,
and identifying risks according to current standards of care.
**Core Directives for this Conversation:**
1. **Analyze Sequentially:** Process information turn-by-turn. Base your responses on the *entire* conversation history.
2. **Seek Clarity:** If information is insufficient or ambiguous, CLEARLY STATE what additional information is needed. Do NOT guess.
3. **Structured Assessment (When Ready):** When sufficient information is available, provide a comprehensive assessment
using the specified JSON structure. Output this JSON as the primary content.
4. **Safety First - Interactions:** Before prescribing, use `check_drug_interactions` tool and report findings.
5. **Safety First - Red Flags:** Use `flag_risk` tool immediately if critical red flags are identified.
6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`,
`flag_risk`, `tavily_search_results`) logically within the flow.
7. **Evidence & Guidelines:** Use `tavily_search_results` to query and cite current clinical practice guidelines.
8. **Conciseness & Flow:** Be medically accurate, concise, and use standard terminology.
"""
# --- External API Endpoints ---
# Base URL for RxNorm REST calls; get_rxcui appends "/rxcui.json" and "/drugs.json" to it.
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
# Full URL of the OpenFDA drug-label endpoint; get_openfda_label queries it directly.
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
# --- API Helper Functions ---
@lru_cache(maxsize=256)
def _fetch_rxcui(name: str) -> Optional[str]:
    """
    Query RxNorm for an already-normalised drug name.

    Returns the first RxCUI found, or None when RxNorm answers but has no
    match. Transport and parsing errors are allowed to propagate: lru_cache
    does not memoise a call that raises, so a transient outage is retried on
    the next lookup instead of being remembered as "unknown drug".
    """
    # Direct lookup
    res = requests.get(
        f"{RXNORM_API_BASE}/rxcui.json",
        params={"name": name, "search": 1},
        timeout=10,
    )
    res.raise_for_status()
    ids = res.json().get("idGroup", {}).get("rxnormId")
    if ids:
        return ids[0]

    # Fallback to /drugs search
    res = requests.get(
        f"{RXNORM_API_BASE}/drugs.json",
        params={"name": name},
        timeout=10,
    )
    res.raise_for_status()
    accepted_types = ("SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN")
    for group in res.json().get("drugGroup", {}).get("conceptGroup") or []:
        if group.get("tty") in accepted_types:
            props = group.get("conceptProperties") or []
            if props:
                return props[0].get("rxcui")
    return None


def get_rxcui(drug_name: str) -> Optional[str]:
    """
    Retrieve RxCUI for a given drug name via RxNorm API.

    Args:
        drug_name: Free-text drug name; surrounding whitespace is ignored.

    Returns:
        The RxCUI string, or None when the name is empty, not a string, not
        found, or the lookup fails. This function is best-effort and never
        raises for lookup failures.
    """
    # Validate before touching the cache: lru_cache hashes its arguments, so
    # an unhashable value (e.g. a list) would raise before any guard ran.
    if not isinstance(drug_name, str):
        return None
    name = drug_name.strip()
    if not name:
        return None
    try:
        return _fetch_rxcui(name)
    except (requests.RequestException, ValueError, AttributeError, TypeError, LookupError):
        # Network/HTTP failure, non-JSON body, or an unexpected payload shape.
        return None


# Keep the cache-management hooks callers had when get_rxcui itself was the
# lru_cache-wrapped function.
get_rxcui.cache_clear = _fetch_rxcui.cache_clear
get_rxcui.cache_info = _fetch_rxcui.cache_info
@lru_cache(maxsize=128)
def _fetch_openfda_label(rxcui: Optional[str], drug_name: Optional[str]) -> Optional[dict]:
    """
    Run one OpenFDA label search for pre-cleaned identifiers.

    Returns the first label record, or None when the service reports no
    match. Other failures propagate so that lru_cache (which does not store
    raising calls) never memoises a transient outage as "no label".
    """
    query_parts: List[str] = []
    if rxcui:
        query_parts.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
    if drug_name:
        query_parts.append(
            f'(openfda.brand_name:"{drug_name}" OR openfda.generic_name:"{drug_name}")'
        )
    params = {"search": " OR ".join(query_parts), "limit": 1}
    res = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
    # NOTE(review): OpenFDA is expected to answer a search without hits with
    # HTTP 404; treat that as a definitive (cacheable) "no label" -- confirm
    # against the API documentation.
    if res.status_code == 404:
        return None
    res.raise_for_status()
    results = res.json().get("results") or []
    return results[0] if results else None


def get_openfda_label(
    rxcui: Optional[str] = None,
    drug_name: Optional[str] = None
) -> Optional[dict]:
    """
    Fetch drug label info from OpenFDA using RxCUI or drug name.

    Args:
        rxcui: RxNorm concept id to search for (optional).
        drug_name: Brand or generic name to search for (optional); matched
            lower-cased, with surrounding whitespace removed.

    Returns:
        The first matching label record as a dict, or None when neither
        identifier is usable, nothing matches, or the request fails. The
        returned dict may be shared with the cache; treat it as read-only.
    """

    def clean(value: object) -> str:
        # Only str/int identifiers are meaningful. Double quotes are dropped
        # because each term is embedded in a quoted phrase of the query.
        if not value or not isinstance(value, (str, int)):
            return ""
        return str(value).replace('"', " ").strip()

    rxcui_term = clean(rxcui)
    name_term = clean(drug_name).lower()
    if not (rxcui_term or name_term):
        return None
    try:
        return _fetch_openfda_label(rxcui_term or None, name_term or None)
    except (requests.RequestException, ValueError, AttributeError, TypeError, LookupError):
        # Network/HTTP failure, non-JSON body, or an unexpected payload shape.
        return None


# Preserve the cache-management hooks of the previously lru_cache-wrapped name.
get_openfda_label.cache_clear = _fetch_openfda_label.cache_clear
get_openfda_label.cache_info = _fetch_openfda_label.cache_info
def search_text_list(
    text_list: Optional[List[str]],
    search_terms: List[str]
) -> List[str]:
    """
    Case-insensitive search for terms in text_list; returns highlighted snippets.

    For each text, the first term (in ``search_terms`` order) that occurs in
    it produces one snippet: up to 50 characters before and 100 characters
    after the first occurrence, with every occurrence of that term inside the
    snippet wrapped in ``**``, and the whole snippet wrapped in ``...``.

    Args:
        text_list: Texts to scan; non-string entries are skipped.
        search_terms: Terms to look for; empty or non-string entries are
            skipped instead of raising.

    Returns:
        At most one snippet per text, in text order.
    """
    snippets: List[str] = []
    if not text_list or not search_terms:
        return snippets

    # Compile once per call. Matching on the original text (rather than on a
    # lower-cased copy) keeps the slice offsets valid even for characters
    # whose lower-case form has a different length.
    patterns = [
        re.compile(f"({re.escape(term.lower())})", re.IGNORECASE)
        for term in search_terms
        if isinstance(term, str) and term
    ]

    for text in text_list:
        if not isinstance(text, str):
            continue
        for pattern in patterns:
            found = pattern.search(text)
            if found is None:
                continue
            start = max(0, found.start() - 50)
            end = min(len(text), found.end() + 100)
            highlighted = pattern.sub(r"**\1**", text[start:end])
            snippets.append(f"...{highlighted}...")
            break  # one snippet per text: the first matching term wins
    return snippets
# --- Clinical Helper Functions ---
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
    """
    Split a reading such as '120/80' into its two integer components.

    Whitespace around the string and around the slash is tolerated. Only the
    start of the string has to match, so trailing text (e.g. a unit) is
    ignored. Returns (systolic, diastolic), or None for non-string input or
    when no reading is found.
    """
    if isinstance(bp_string, str):
        reading = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
        if reading is not None:
            systolic, diastolic = reading.groups()
            return int(systolic), int(diastolic)
    return None
def check_red_flags(patient_data: dict) -> List[str]:
    """
    Evaluate patient_data for predefined red flags; return unique list.

    Reads ``hpi.symptoms`` (list of strings), ``vitals`` (``temp_c``,
    ``hr_bpm``, ``rr_rpm``, ``spo2_percent``, ``bp_mmhg``) and
    ``pmh.conditions``. Missing or None sections are treated as empty.

    Returns:
        Flag messages in a stable order (symptoms, then vitals, then
        history-based combinations), without duplicates.
    """
    flags: List[str] = []
    if not patient_data:
        return flags

    # `or {}` / `or []` so a section that is present but None does not crash.
    hpi = patient_data.get("hpi") or {}
    vitals = patient_data.get("vitals") or {}
    pmh = patient_data.get("pmh") or {}

    symptoms = {s.lower() for s in (hpi.get("symptoms") or []) if isinstance(s, str)}
    conditions = pmh.get("conditions") or ""
    if isinstance(conditions, (list, tuple, set)):
        conditions = ", ".join(map(str, conditions))
    history = str(conditions).lower()

    # Symptom-based flags
    symptom_flags = {
        "chest pain": "Chest Pain reported",
        "shortness of breath": "Shortness of Breath reported",
        "severe headache": "Severe Headache reported",
        "sudden vision loss": "Sudden Vision Loss reported",
        "weakness on one side": "Unilateral Weakness reported (potential stroke)",
        "hemoptysis": "Hemoptysis (coughing up blood)",
        "syncope": "Syncope (fainting)"
    }
    flags.extend(
        f"Red Flag: {desc}." for key, desc in symptom_flags.items() if key in symptoms
    )

    # Vital sign flags
    temp = vitals.get("temp_c")
    hr = vitals.get("hr_bpm")
    rr = vitals.get("rr_rpm")
    spo2 = vitals.get("spo2_percent")
    bp_str = vitals.get("bp_mmhg")
    if temp is not None and temp >= 38.5:
        flags.append(f"Red Flag: Fever ({temp}°C).")
    if hr is not None:
        if hr >= 120:
            flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
        if hr <= 50:
            flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
    if rr is not None and rr >= 24:
        flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
    if spo2 is not None and spo2 <= 92:
        flags.append(f"Red Flag: Hypoxia ({spo2}%).")
    if bp_str:
        parsed = parse_bp(bp_str)
        if parsed:
            systolic, diastolic = parsed
            if systolic >= 180 or diastolic >= 110:
                flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
            if systolic <= 90 or diastolic <= 60:
                flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")

    # History-based flags
    if "history of mi" in history and "chest pain" in symptoms:
        flags.append("Red Flag: History of MI with current Chest Pain.")
    if "history of dvt/pe" in history and "shortness of breath" in symptoms:
        flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")

    # dict.fromkeys de-duplicates while keeping insertion order; a set would
    # hand the flags back in a different order from run to run.
    return list(dict.fromkeys(flags))
def format_patient_data_for_prompt(data: dict) -> str:
    """
    Convert patient data dict into a formatted string for LLM prompts.

    Each top-level key becomes a bold title (underscores turned into spaces,
    title-cased). Dict sections are rendered as one " - Key: value" line per
    non-empty entry, list sections as a comma-separated line, and any other
    truthy value inline. Empty sections are omitted entirely.

    Returns:
        The formatted text, or "No patient data provided." when data is empty.
    """
    if not data:
        return "No patient data provided."

    lines: List[str] = []
    for section, content in data.items():
        title = str(section).replace('_', ' ').title()
        if isinstance(content, dict):
            # Dicts are handled only in this branch, so one whose values are
            # all empty is skipped instead of being dumped as a raw dict repr.
            entries = [
                f" - {str(key).replace('_', ' ').title()}: {val}"
                for key, val in content.items()
                if val
            ]
            if entries:
                lines.append(f"**{title}:**")
                lines.extend(entries)
        elif isinstance(content, list):
            if content:
                lines.append(f"**{title}:** {', '.join(map(str, content))}")
        elif content:
            lines.append(f"**{title}:** {content}")
    return "\n".join(lines)
# --- Tool Input Schemas ---
# Argument schema for the `order_lab_test` tool.
# (Documented with comments rather than a class docstring so the generated
# schema is unchanged.)
class LabOrderInput(BaseModel):
    # Required: which test to order and why.
    test_name: str
    reason: str
    # Optional: defaults to a routine order.
    priority: str = "Routine"
# Argument schema for the `prescribe_medication` tool.
# (Documented with comments rather than a class docstring so the generated
# schema is unchanged.)
class PrescriptionInput(BaseModel):
    # Required fields describing the prescription.
    medication_name: str
    dosage: str
    route: str
    frequency: str
    # Optional: falls back to "As directed" when omitted.
    duration: str = "As directed"
    # Required: clinical justification for the prescription.
    reason: str
# Argument schema for the `check_drug_interactions` tool.
# (Documented with comments rather than a class docstring so the generated
# schema is unchanged.)
class InteractionCheckInput(BaseModel):
    # Required: the medication being considered.
    potential_prescription: str
    # Optional context; tool_node overwrites both from the stored patient data.
    current_medications: Optional[List[str]] = None
    allergies: Optional[List[str]] = None
# Argument schema for the `flag_risk` tool.
# (Documented with comments rather than a class docstring so the generated
# schema is unchanged.)
class FlagRiskInput(BaseModel):
    # Required: what the risk is.
    risk_description: str
    # Optional: defaults to "High".
    urgency: str = "High"
# --- Tool Definitions ---
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """
    Place a lab order with given test_name, reason, and priority.
    """
    # The docstring above doubles as the tool description, so it is kept
    # verbatim. No external system is contacted: the function only builds
    # and returns a JSON confirmation string.
    confirmation = {
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}",
    }
    return json.dumps(confirmation)
@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str = "As directed",
    reason: Optional[str] = None
) -> str:
    """
    Prepare a prescription with dosage, route, frequency, and duration.
    """
    # The docstring above doubles as the tool description, so it is kept
    # verbatim.
    #
    # `duration` mirrors the default declared on PrescriptionInput.
    # NOTE(review): LangChain's structured-tool input parsing is understood
    # to forward only the keys the caller actually supplied, so a default
    # that exists only on the schema never reaches this function -- confirm
    # against the pinned langchain-core version.
    #
    # `reason` stays mandatory; it carries a None default only because
    # Python requires every parameter after a defaulted one to have a
    # default. Omitting it raises TypeError, as it did before.
    if reason is None:
        raise TypeError("prescribe_medication() missing required argument: 'reason'")
    return json.dumps({
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}"
    })
def _leading_drug_name(entry: object) -> Optional[str]:
    """Return the lower-cased leading alphabetic token of a medication entry, if any."""
    if not isinstance(entry, str):
        return None
    found = re.match(r"^\s*([a-zA-Z\-]+)", entry)
    return found.group(1).lower() if found else None


@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(
    potential_prescription: str,
    current_medications: Optional[List[str]] = None,
    allergies: Optional[List[str]] = None
) -> str:
    """
    Check for allergy and drug-drug interactions using RxNorm and OpenFDA.
    """
    # The docstring above doubles as the tool description, so it is kept
    # verbatim.
    #
    # Returns a JSON string with keys "status" ("warning" | "clear"),
    # "message" and "warnings" (list of strings). Every check is a text
    # heuristic over OpenFDA label sections: a hit means "review this", and a
    # clean result is not proof of safety.
    warnings: List[str] = []
    med_lower = potential_prescription.lower().strip()

    # Normalize current meds (leading name token, de-duplicated, order kept)
    # and allergies (blank / non-string entries dropped).
    current = list(dict.fromkeys(
        name for name in map(_leading_drug_name, current_medications or []) if name
    ))
    allergy_list = [
        a.lower().strip()
        for a in (allergies or [])
        if isinstance(a, str) and a.strip()
    ]
    # Terms identifying the candidate drug in label text: the full string as
    # given plus its leading name token (e.g. "warfarin" from "warfarin 5mg").
    med_terms = [
        term
        for term in dict.fromkeys((med_lower, _leading_drug_name(med_lower)))
        if term
    ]

    # Lookup identifiers
    rxcui = get_rxcui(potential_prescription)
    label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
    if not (rxcui or label):
        warnings.append(f"INFO: Could not identify '{potential_prescription}'.")

    # Allergy checks: whole-word match so "amoxicillin" is caught inside
    # "amoxicillin 500mg" without matching arbitrary substrings.
    for alg in allergy_list:
        if re.search(rf"\b{re.escape(alg)}\b", med_lower):
            warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{alg}'.")

    # Contraindications and warnings from label
    if label:
        for section in ("contraindications", "warnings_and_cautions"):
            snippets = search_text_list(label.get(section) or [], allergy_list)
            if snippets:
                warnings.append(
                    f"Label Allergy Risk: {', '.join(snippets)}"
                )

    # Drug-drug interaction checks: look for each current medication in the
    # candidate's "drug_interactions" label section, and for the candidate in
    # each current medication's section.
    if rxcui or label:
        own_section = (label or {}).get("drug_interactions") or []
        for cm in current:
            if cm in med_terms:
                continue  # the candidate drug is already on the medication list
            hits = search_text_list(own_section, [cm])
            if hits:
                joined = "; ".join(hits)
                warnings.append(
                    f"Potential Interaction: '{potential_prescription}' label mentions '{cm}': {joined}"
                )
            cm_rxcui = get_rxcui(cm)
            cm_label = get_openfda_label(rxcui=cm_rxcui, drug_name=cm)
            other_section = (cm_label or {}).get("drug_interactions") or []
            hits = search_text_list(other_section, med_terms)
            if hits:
                joined = "; ".join(hits)
                warnings.append(
                    f"Potential Interaction: '{cm}' label mentions '{potential_prescription}': {joined}"
                )

    # Anything other than a purely informational note is a warning, so label
    # allergy risks no longer report "clear".
    status = (
        "warning" if any(not w.startswith("INFO") for w in warnings) else "clear"
    )
    message = (
        f"Interaction/Allergy check: {len(warnings)} issue(s) identified."
        if warnings else
        "No major interactions or allergy issues identified."
    )
    return json.dumps({"status": status, "message": message, "warnings": warnings})
@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str = "High") -> str:
    """
    Flag a critical risk with given description and urgency.
    """
    # The docstring above doubles as the tool description, so it is kept
    # verbatim.
    #
    # `urgency` mirrors the "High" default declared on FlagRiskInput so a
    # call that omits it does not fail with a missing-argument TypeError.
    # NOTE(review): schema-only defaults are understood not to be forwarded
    # to the function by LangChain -- confirm against the pinned version.
    return json.dumps({
        "status": "flagged",
        "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
    })
# Tavily search tool instance
search_tool = TavilySearchResults(
    name="tavily_search_results",
    max_results=MAX_SEARCH_RESULTS,
)

# Every tool the agent can call. The same list is bound to the LLM and handed
# to the tool executor further down.
all_tools = [
    order_lab_test, prescribe_medication, check_drug_interactions, flag_risk,
    search_tool,
]
# --- LangGraph Setup ---
class AgentState(TypedDict):
    """Shared state passed between the graph nodes."""

    # Conversation history. The second Annotated argument is the reducer
    # LangGraph uses to merge a node's return value into the existing value:
    # operator.add appends the returned list. Every node in this file returns
    # only its *new* messages (or [] for a no-op), so without a reducer each
    # step would overwrite the history instead of extending it.
    messages: Annotated[List[Any], operator.add]
    # Structured patient record; tool_node reads medications/allergies from it.
    patient_data: Optional[dict]
    # Carried through unchanged by invoke_turn.
    summary: Optional[str]
    # Warnings collected by tool_node; consumed and reset by reflection_node.
    interaction_warnings: Optional[List[str]]
# Initialize LLM and bind tools
llm = ChatGroq(model=AGENT_MODEL_NAME, temperature=AGENT_TEMPERATURE)
# Tool-aware variant used by agent_node; reflection_node calls the bare `llm`.
model_with_tools = llm.bind_tools(all_tools)
# Runs the tool calls collected in tool_node.
tool_executor = ToolExecutor(all_tools)
# --- Node Definitions ---
def agent_node(state: AgentState) -> Dict[str, Any]:
    """
    Main reasoning step: ask the tool-bound model for the next message.

    The system prompt is prepended for this call only when the history does
    not already start with a SystemMessage; it is not written back to the
    state. A failure of the model call is converted into an AIMessage whose
    content starts with "Error: ".

    Returns:
        ``{"messages": [reply]}`` containing just the new message.
    """
    history = state.get("messages", [])
    starts_with_system = bool(history) and isinstance(history[0], SystemMessage)
    if not starts_with_system:
        history = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + history

    try:
        reply = model_with_tools.invoke(history)
    except Exception as exc:
        reply = AIMessage(content=f"Error: {exc}")
    return {"messages": [reply]}
def tool_node(state: AgentState) -> Dict[str, Any]:
    """
    Executes any pending tool calls from the last AIMessage.

    Behaviour:
      * A ``prescribe_medication`` call is refused (answered with an error
        ToolMessage, not executed) unless the same batch also contains a
        ``check_drug_interactions`` call for that medication.
      * ``check_drug_interactions`` calls always receive the current
        medications and allergies from ``state['patient_data']``, replacing
        whatever the model supplied.
      * Every tool call gets exactly one ToolMessage, in call order, whether
        it ran, failed or was refused.

    Returns:
        ``{"messages": [...ToolMessage], "interaction_warnings": list | None}``
        where the warnings are those reported by the interaction checks.
    """
    last = state['messages'][-1]
    if not isinstance(last, AIMessage) or not getattr(last, 'tool_calls', None):
        return {"messages": [], "interaction_warnings": None}

    # Work on a copy: the tool-call list belongs to the AIMessage in history
    # and must not be modified.
    calls = list(last.tool_calls)

    def error_reply(call: Dict[str, Any], text: str) -> ToolMessage:
        # Error payload in the same JSON shape the tools themselves return.
        return ToolMessage(
            content=json.dumps({"status": "error", "message": text}),
            tool_call_id=call['id'],
            name=call['name'],
        )

    # Medications covered by an interaction check in this same batch.
    checked_meds = {
        str((c.get('args') or {}).get('potential_prescription', '')).strip().lower()
        for c in calls
        if c['name'] == 'check_drug_interactions'
    }
    patient = state.get('patient_data') or {}

    replies: List[Optional[ToolMessage]] = [None] * len(calls)
    pending: List[Any] = []  # (position, call, invocation) for calls allowed to run
    for position, call in enumerate(calls):
        args = dict(call.get('args') or {})  # copy; never edit the message's args
        if call['name'] == 'prescribe_medication':
            # Enforce safety: prescriptions require an interaction check
            med = str(args.get('medication_name', '')).strip().lower()
            if med not in checked_meds:
                replies[position] = error_reply(call, f"Interaction check needed for '{med}'.")
                continue
        elif call['name'] == 'check_drug_interactions':
            # Augment interaction checks with patient data
            args['current_medications'] = (patient.get('medications') or {}).get('current') or []
            args['allergies'] = patient.get('allergies') or []
        # ToolExecutor reads `.tool` / `.tool_input`, so wrap the raw call dict.
        pending.append((position, call, ToolInvocation(tool=call['name'], tool_input=args)))

    # Execute allowed calls
    responses: List[Any] = []
    if pending:
        try:
            responses = tool_executor.batch(
                [invocation for _, _, invocation in pending],
                return_exceptions=True,
            )
        except Exception as exc:
            # The batch failed as a whole: report it on every pending call so
            # each one still gets an answer carrying its own id.
            responses = [exc] * len(pending)

    warnings: List[str] = []
    for (position, call, _), resp in zip(pending, responses):
        if isinstance(resp, Exception):
            replies[position] = error_reply(call, str(resp))
            continue
        replies[position] = ToolMessage(
            content=str(resp),
            tool_call_id=call['id'],
            name=call['name'],
        )
        if call['name'] == 'check_drug_interactions':
            try:
                data = json.loads(str(resp))
            except ValueError:
                data = None  # non-JSON output carries no warnings
            if isinstance(data, dict) and data.get('warnings'):
                warnings.extend(data['warnings'])

    results = [reply for reply in replies if reply is not None]
    return {"messages": results, "interaction_warnings": warnings or None}
def reflection_node(state: AgentState) -> Dict[str, Any]:
    """
    Safety pass run after an interaction check reported warnings.

    Asks the bare LLM (no tools bound) to revise the therapeutic plan in
    light of the warnings. ``interaction_warnings`` is reset to None on every
    exit path. The result holds the review as one AIMessage, an error
    AIMessage if the review failed or no ``check_drug_interactions``
    ToolMessage is in the history, or no message when there are no warnings.
    """
    pending_warnings = state.get('interaction_warnings')
    if not pending_warnings:
        return {"messages": [], "interaction_warnings": None}

    # Latest interaction-check result in the history; only its presence matters.
    trigger_id = next(
        (
            entry.tool_call_id
            for entry in reversed(state['messages'])
            if isinstance(entry, ToolMessage) and entry.name == 'check_drug_interactions'
        ),
        None,
    )
    if trigger_id is None:
        missing = AIMessage(content="Internal Error: Reflection context missing.")
        return {"messages": [missing], "interaction_warnings": None}

    # Build reflection prompt
    prompt = "".join([
        "You are SynapseAI performing a critical safety review.",
        f"\nWarnings:\n```json\n{json.dumps(pending_warnings, indent=2)}\n```",
        "\n**Revise therapeutics based on these warnings.**",
    ])
    review_request = [
        SystemMessage(content="Perform focused safety review based on interaction warnings."),
        HumanMessage(content=prompt),
    ]

    try:
        review = AIMessage(content=llm.invoke(review_request).content)
    except Exception as exc:
        review = AIMessage(content=f"Error during safety reflection: {exc}")
    return {"messages": [review], "interaction_warnings": None}
# --- Routing Logic ---
def should_continue(state: AgentState) -> str:
    """
    Route after the agent node.

    Returns 'continue_tools' when the last message is an AIMessage carrying
    tool calls, otherwise 'end_conversation_turn'. The failure message that
    agent_node produces (content starting with "Error:") always ends the turn.
    """
    messages = state.get('messages') or []
    last = messages[-1] if messages else None
    if not isinstance(last, AIMessage):
        return 'end_conversation_turn'

    # Only the "Error:" prefix marks a failed model call. A plain substring
    # test would also fire on ordinary clinical text such as "medication
    # error" and silently drop valid tool calls. Content that is not a plain
    # string is treated as non-error.
    content = last.content if isinstance(last.content, str) else ""
    if content.lstrip().lower().startswith("error:"):
        return 'end_conversation_turn'

    if getattr(last, 'tool_calls', None):
        return 'continue_tools'
    return 'end_conversation_turn'
def after_tools_router(state: AgentState) -> str:
    """
    Route after the tools node: go to the reflection step when the state
    holds interaction warnings, otherwise straight back to the agent.
    """
    has_warnings = bool(state.get('interaction_warnings'))
    return 'reflect_on_warnings' if has_warnings else 'continue_to_agent'
# --- ClinicalAgent Implementation ---
class ClinicalAgent:
    """
    Compiled LangGraph workflow wiring the agent, tools and reflection nodes.

    Flow: agent -> (tools | END); tools -> (reflection | agent);
    reflection -> agent.
    """

    def __init__(self):
        workflow = StateGraph(AgentState)
        node_table = (
            ('agent', agent_node),
            ('tools', tool_node),
            ('reflection', reflection_node),
        )
        for node_name, handler in node_table:
            workflow.add_node(node_name, handler)
        workflow.set_entry_point('agent')

        agent_routes = {'continue_tools': 'tools', 'end_conversation_turn': END}
        workflow.add_conditional_edges('agent', should_continue, agent_routes)

        tool_routes = {'reflect_on_warnings': 'reflection', 'continue_to_agent': 'agent'}
        workflow.add_conditional_edges('tools', after_tools_router, tool_routes)

        workflow.add_edge('reflection', 'agent')
        self.graph_app = workflow.compile()

    def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run the graph once for the given state.

        On success the graph result is returned with 'summary' and
        'interaction_warnings' filled in when absent. On any exception the
        incoming state is echoed back with an apology AIMessage appended to
        its messages and 'interaction_warnings' set to None.
        """
        try:
            outcome = self.graph_app.invoke(state, {'recursion_limit': 15})
            outcome.setdefault('summary', state.get('summary'))
            outcome.setdefault('interaction_warnings', None)
            return outcome
        except Exception as exc:
            apology = AIMessage(content=f"Sorry, a critical error occurred: {exc}")
            return {
                'messages': state.get('messages', []) + [apology],
                'patient_data': state.get('patient_data'),
                'summary': state.get('summary'),
                'interaction_warnings': None,
            }