import os
import re
import json
import operator
import traceback
import requests
from functools import lru_cache
from typing import Any, Dict, List, Optional, TypedDict, Annotated

from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from langgraph.graph import StateGraph, END


# API keys are read from the environment; nothing is hard-coded.
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")


# Agent configuration.
AGENT_MODEL_NAME = "llama3-70b-8192"
AGENT_TEMPERATURE = 0.1
MAX_SEARCH_RESULTS = 3


class ClinicalPrompts:
    """Comprehensive system prompt defining SynapseAI behavior."""

    SYSTEM_PROMPT = (
        """
        You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
        Your goal is to support healthcare professionals by analyzing patient data,
        providing differential diagnoses, suggesting evidence-based management plans,
        and identifying risks according to current standards of care.

        **Core Directives for this Conversation:**
        1. **Analyze Sequentially:** Process information turn-by-turn. Base your responses on the *entire* conversation history.
        2. **Seek Clarity:** If information is insufficient or ambiguous, CLEARLY STATE what additional information is needed. Do NOT guess.
        3. **Structured Assessment (When Ready):** When sufficient information is available, provide a comprehensive assessment using the specified JSON structure. Output this JSON as the primary content.
        4. **Safety First - Interactions:** Before prescribing, use the `check_drug_interactions` tool and report findings.
        5. **Safety First - Red Flags:** Use the `flag_risk` tool immediately if critical red flags are identified.
        6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the flow.
        7. **Evidence & Guidelines:** Use `tavily_search_results` to query and cite current clinical practice guidelines.
        8. **Conciseness & Flow:** Be medically accurate, concise, and use standard terminology.
        """
    )


RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"


@lru_cache(maxsize=256)
def get_rxcui(drug_name: str) -> Optional[str]:
    """Retrieve the RxCUI for a given drug name via the RxNorm API."""
    if not drug_name or not isinstance(drug_name, str):
        return None

    name = drug_name.strip()
    if not name:
        return None

    try:
        # First attempt: normalized name lookup.
        params = {"name": name, "search": 1}
        res = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
        res.raise_for_status()
        data = res.json()

        ids = data.get("idGroup", {}).get("rxnormId")
        if ids:
            return ids[0]

        # Fallback: broader /drugs search, preferring common term types.
        params = {"name": name}
        res = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
        res.raise_for_status()
        data = res.json()

        for group in data.get("drugGroup", {}).get("conceptGroup", []):
            if group.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
                props = group.get("conceptProperties") or []
                if props:
                    return props[0].get("rxcui")

    except Exception:
        pass

    return None


@lru_cache(maxsize=128)
def get_openfda_label(
    rxcui: Optional[str] = None,
    drug_name: Optional[str] = None
) -> Optional[dict]:
    """Fetch drug label info from OpenFDA using an RxCUI or a drug name."""
    if not (rxcui or drug_name):
        return None

    query_parts: List[str] = []
    if rxcui:
        query_parts.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
    if drug_name:
        name_lower = drug_name.lower()
        query_parts.append(
            f'(openfda.brand_name:"{name_lower}" OR openfda.generic_name:"{name_lower}")'
        )

    search_query = " OR ".join(query_parts)
    params = {"search": search_query, "limit": 1}

    try:
        res = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
        res.raise_for_status()
        data = res.json()
        results = data.get("results") or []
        if results:
            return results[0]
    except Exception:
        pass

    return None
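

# Example of chaining the two lookups above (hypothetical drug name; requires
# network access to the public RxNorm/OpenFDA endpoints and returns None on
# any failure):
#
#     rxcui = get_rxcui("lisinopril")
#     label = get_openfda_label(rxcui=rxcui, drug_name="lisinopril")
#     contraindications = (label or {}).get("contraindications")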


def search_text_list(
    text_list: Optional[List[str]],
    search_terms: List[str]
) -> List[str]:
    """Case-insensitive search for terms in text_list; returns highlighted snippets."""
    snippets: List[str] = []
    if not text_list or not search_terms:
        return snippets

    lower_terms = [t.lower() for t in search_terms if t]

    for text in text_list:
        if not isinstance(text, str):
            continue

        text_lower = text.lower()
        for term in lower_terms:
            idx = text_lower.find(term)
            if idx != -1:
                start = max(0, idx - 50)
                end = min(len(text), idx + len(term) + 100)
                snippet = text[start:end]
                snippet = re.sub(
                    f"({re.escape(term)})",
                    r"**\1**",
                    snippet,
                    flags=re.IGNORECASE
                )
                snippets.append(f"...{snippet}...")
                break

    return snippets
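

# Example (illustrative values): searching label text for an allergy term.
#
#     search_text_list(["Do not use in patients with penicillin allergy."], ["penicillin"])
#     # -> ['...Do not use in patients with **penicillin** allergy....']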


def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
    """Parse a blood pressure string like '120/80' into (systolic, diastolic)."""
    if not isinstance(bp_string, str):
        return None

    match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
    if match:
        return int(match.group(1)), int(match.group(2))

    return None
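

# For example, parse_bp("182/110") returns (182, 110); malformed input returns None.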


def check_red_flags(patient_data: dict) -> List[str]:
    """Evaluate patient_data for predefined red flags; return a unique list."""
    flags: List[str] = []
    if not patient_data:
        return flags

    symptoms = [s.lower() for s in patient_data.get("hpi", {}).get("symptoms", [])]
    vitals = patient_data.get("vitals", {})
    history = patient_data.get("pmh", {}).get("conditions", "").lower()

    # Symptom-based red flags.
    symptom_flags = {
        "chest pain": "Chest Pain reported",
        "shortness of breath": "Shortness of Breath reported",
        "severe headache": "Severe Headache reported",
        "sudden vision loss": "Sudden Vision Loss reported",
        "weakness on one side": "Unilateral Weakness reported (potential stroke)",
        "hemoptysis": "Hemoptysis (coughing up blood)",
        "syncope": "Syncope (fainting)"
    }
    for key, desc in symptom_flags.items():
        if key in symptoms:
            flags.append(f"Red Flag: {desc}.")

    # Vital-sign thresholds.
    temp = vitals.get("temp_c")
    hr = vitals.get("hr_bpm")
    rr = vitals.get("rr_rpm")
    spo2 = vitals.get("spo2_percent")
    bp_str = vitals.get("bp_mmhg")

    if temp is not None and temp >= 38.5:
        flags.append(f"Red Flag: Fever ({temp}°C).")
    if hr is not None:
        if hr >= 120:
            flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
        if hr <= 50:
            flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
    if rr is not None and rr >= 24:
        flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
    if spo2 is not None and spo2 <= 92:
        flags.append(f"Red Flag: Hypoxia ({spo2}%).")

    if bp_str:
        parsed = parse_bp(bp_str)
        if parsed:
            sys_bp, dia_bp = parsed
            if sys_bp >= 180 or dia_bp >= 110:
                flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
            if sys_bp <= 90 or dia_bp <= 60:
                flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")

    # History plus current-symptom combinations.
    if "history of mi" in history and "chest pain" in symptoms:
        flags.append("Red Flag: History of MI with current Chest Pain.")
    if "history of dvt/pe" in history and "shortness of breath" in symptoms:
        flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")

    return list(set(flags))
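

# Example (illustrative patient_data shape, matching the keys read above):
#
#     check_red_flags({
#         "hpi": {"symptoms": ["Chest Pain", "nausea"]},
#         "vitals": {"bp_mmhg": "185/112", "spo2_percent": 91},
#         "pmh": {"conditions": "History of MI, hyperlipidemia"},
#     })
#     # -> flags for chest pain, hypertensive urgency/emergency, hypoxia,
#     #    and history of MI with current chest pain (order not guaranteed).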


def format_patient_data_for_prompt(data: dict) -> str:
    """Convert a patient data dict into a formatted string for LLM prompts."""
    if not data:
        return "No patient data provided."

    lines: List[str] = []
    for section, content in data.items():
        title = section.replace('_', ' ').title()

        if isinstance(content, dict) and any(content.values()):
            lines.append(f"**{title}:**")
            for key, val in content.items():
                if val:
                    key_title = key.replace('_', ' ').title()
                    lines.append(f"  - {key_title}: {val}")
        elif isinstance(content, list) and content:
            lines.append(f"**{title}:** {', '.join(map(str, content))}")
        elif content:
            lines.append(f"**{title}:** {content}")

    return "\n".join(lines)


class LabOrderInput(BaseModel):
    test_name: str = Field(...)
    reason: str = Field(...)
    priority: str = Field("Routine")


class PrescriptionInput(BaseModel):
    medication_name: str = Field(...)
    dosage: str = Field(...)
    route: str = Field(...)
    frequency: str = Field(...)
    duration: str = Field("As directed")
    reason: str = Field(...)


class InteractionCheckInput(BaseModel):
    potential_prescription: str = Field(...)
    current_medications: Optional[List[str]] = Field(None)
    allergies: Optional[List[str]] = Field(None)


class FlagRiskInput(BaseModel):
    risk_description: str = Field(...)
    urgency: str = Field("High")


@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """Place a lab order with the given test_name, reason, and priority."""
    return json.dumps({
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}"
    })


@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str,
    reason: str
) -> str:
    """Prepare a prescription with dosage, route, frequency, and duration."""
    return json.dumps({
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}"
    })


@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(
    potential_prescription: str,
    current_medications: Optional[List[str]] = None,
    allergies: Optional[List[str]] = None
) -> str:
    """Check for allergy and drug-drug interactions using RxNorm and OpenFDA."""
    warnings: List[str] = []
    med_lower = potential_prescription.lower().strip()

    # Normalize current medications to their leading drug-name token.
    current = [
        re.match(r"^\s*([a-zA-Z\-]+)", m).group(1).lower()
        for m in (current_medications or [])
        if re.match(r"^\s*([a-zA-Z\-]+)", m)
    ]
    allergy_list = [a.lower().strip() for a in (allergies or [])]

    # Identify the candidate drug.
    rxcui = get_rxcui(potential_prescription)
    label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
    if not (rxcui or label):
        warnings.append(f"INFO: Could not identify '{potential_prescription}'.")

    # Direct allergy match against the candidate drug name.
    for alg in allergy_list:
        if alg == med_lower:
            warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{alg}'.")

    # Allergy terms mentioned in the candidate drug's label.
    if label:
        for field in (label.get("contraindications") or [], label.get("warnings_and_cautions") or []):
            snippets = search_text_list(field, allergy_list)
            if snippets:
                warnings.append(f"Label Allergy Risk: {', '.join(snippets)}")

    # Heuristic drug-drug check: look for each drug's name in the other's
    # "drug_interactions" label section (a text search, not a full interaction engine).
    if rxcui or label:
        for cm in current:
            if cm == med_lower:
                continue
            cm_rxcui = get_rxcui(cm)
            cm_label = get_openfda_label(rxcui=cm_rxcui, drug_name=cm)

            snippets = search_text_list((label or {}).get("drug_interactions"), [cm])
            snippets += search_text_list((cm_label or {}).get("drug_interactions"), [med_lower])
            if snippets:
                warnings.append(f"Potential Interaction ({cm} + {med_lower}): {', '.join(snippets)}")

    status = (
        "warning" if any(
            w.startswith("CRITICAL") or "Interaction" in w for w in warnings
        ) else "clear"
    )
    message = (
        f"Interaction/Allergy check: {len(warnings)} issue(s) identified."
        if warnings else
        "No major interactions or allergy issues identified."
    )

    return json.dumps({"status": status, "message": message, "warnings": warnings})


@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str) -> str:
    """Flag a critical risk with the given description and urgency."""
    return json.dumps({
        "status": "flagged",
        "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
    })


search_tool = TavilySearchResults(
    max_results=MAX_SEARCH_RESULTS,
    name="tavily_search_results"
)
all_tools = [
    order_lab_test,
    prescribe_medication,
    check_drug_interactions,
    flag_risk,
    search_tool
]
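

# The structured tools above can also be exercised directly, which is handy for
# quick manual testing (illustrative arguments; each tool returns a JSON string):
#
#     print(order_lab_test.invoke({
#         "test_name": "Troponin I",
#         "reason": "Rule out ACS",
#         "priority": "STAT",
#     }))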


class AgentState(TypedDict):
    # operator.add makes `messages` an append-only channel, so each node can
    # return only its new messages instead of the full history.
    messages: Annotated[List[Any], operator.add]
    patient_data: Optional[dict]
    summary: Optional[str]
    interaction_warnings: Optional[List[str]]


llm = ChatGroq(
    temperature=AGENT_TEMPERATURE,
    model=AGENT_MODEL_NAME
)
model_with_tools = llm.bind_tools(all_tools)
tool_executor = ToolExecutor(all_tools)


def agent_node(state: AgentState) -> Dict[str, Any]:
    """Primary agent node: sends messages to the LLM and returns its response."""
    messages = state.get("messages", [])
    if not messages or not isinstance(messages[0], SystemMessage):
        messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + messages

    try:
        response = model_with_tools.invoke(messages)
        return {"messages": [response]}
    except Exception as e:
        err = AIMessage(content=f"Error: {e}")
        return {"messages": [err]}


def tool_node(state: AgentState) -> Dict[str, Any]:
    """Executes any pending tool calls from the last AIMessage."""
    last = state['messages'][-1]
    if not isinstance(last, AIMessage) or not getattr(last, 'tool_calls', None):
        return {"messages": [], "interaction_warnings": None}

    calls = last.tool_calls

    # Block prescriptions that are not paired with an interaction check in this batch.
    blocked = set()
    blocked_messages: List[ToolMessage] = []
    checked_meds = {
        c['args'].get('potential_prescription', '').lower()
        for c in calls if c['name'] == 'check_drug_interactions'
    }
    for call in calls:
        if call['name'] == 'prescribe_medication':
            med = call['args'].get('medication_name', '').lower()
            if med not in checked_meds:
                blocked.add(call['id'])
                blocked_messages.append(ToolMessage(
                    content=json.dumps({
                        "status": "error",
                        "message": f"Interaction check needed for '{med}'."
                    }),
                    tool_call_id=call['id'],
                    name=call['name']
                ))

    # Inject patient context into interaction checks.
    patient = state.get('patient_data') or {}
    for call in calls:
        if call['name'] == 'check_drug_interactions':
            call['args']['current_medications'] = patient.get('medications', {}).get('current', [])
            call['args']['allergies'] = patient.get('allergies', [])

    to_execute = [c for c in calls if c['id'] not in blocked]
    results: List[ToolMessage] = list(blocked_messages)
    warnings: List[str] = []

    try:
        invocations = [ToolInvocation(tool=c['name'], tool_input=c['args']) for c in to_execute]
        responses = tool_executor.batch(invocations, return_exceptions=True)
        for call, resp in zip(to_execute, responses):
            if isinstance(resp, Exception):
                results.append(ToolMessage(
                    content=json.dumps({"status": "error", "message": str(resp)}),
                    tool_call_id=call['id'],
                    name=call['name']
                ))
            else:
                results.append(ToolMessage(
                    content=str(resp),
                    tool_call_id=call['id'],
                    name=call['name']
                ))
                if call['name'] == 'check_drug_interactions':
                    data = json.loads(str(resp))
                    if data.get('warnings'):
                        warnings.extend(data['warnings'])
    except Exception as e:
        results.append(ToolMessage(
            content=json.dumps({"status": "error", "message": str(e)}),
            tool_call_id="tool_executor_error",
            name="tool_executor"
        ))

    return {"messages": results, "interaction_warnings": warnings or None}
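

# For reference, each entry in `AIMessage.tool_calls` consumed above is a dict
# of the form {"name": "...", "args": {...}, "id": "..."} (illustrative values).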


def reflection_node(state: AgentState) -> Dict[str, Any]:
    """Safety reflection: reviews interaction warnings and revises the plan."""
    warnings = state.get('interaction_warnings')
    if not warnings:
        return {"messages": [], "interaction_warnings": None}

    # Locate the interaction-check ToolMessage that triggered this reflection.
    trigger_id = None
    for msg in reversed(state['messages']):
        if isinstance(msg, ToolMessage) and msg.name == 'check_drug_interactions':
            trigger_id = msg.tool_call_id
            break

    if trigger_id is None:
        err = AIMessage(content="Internal Error: Reflection context missing.")
        return {"messages": [err], "interaction_warnings": None}

    prompt = (
        "You are SynapseAI performing a critical safety review."
        f"\nWarnings:\n```json\n{json.dumps(warnings, indent=2)}\n```"
        "\n**Revise therapeutics based on these warnings.**"
    )
    messages = [
        SystemMessage(content="Perform focused safety review based on interaction warnings."),
        HumanMessage(content=prompt)
    ]

    try:
        response = llm.invoke(messages)
        return {"messages": [AIMessage(content=response.content)], "interaction_warnings": None}
    except Exception as e:
        err = AIMessage(content=f"Error during safety reflection: {e}")
        return {"messages": [err], "interaction_warnings": None}


def should_continue(state: AgentState) -> str:
    """Route from the agent node: run tools if the LLM requested them, else end the turn."""
    last = state['messages'][-1] if state['messages'] else None
    if not isinstance(last, AIMessage) or 'error' in last.content.lower():
        return 'end_conversation_turn'
    if getattr(last, 'tool_calls', None):
        return 'continue_tools'
    return 'end_conversation_turn'


def after_tools_router(state: AgentState) -> str:
    """Route from the tool node: reflect first if interaction warnings were raised."""
    if state.get('interaction_warnings'):
        return 'reflect_on_warnings'
    return 'continue_to_agent'


class ClinicalAgent:
    """Wires the agent, tool, and reflection nodes into a LangGraph state machine."""

    def __init__(self):
        graph = StateGraph(AgentState)
        graph.add_node('agent', agent_node)
        graph.add_node('tools', tool_node)
        graph.add_node('reflection', reflection_node)

        graph.set_entry_point('agent')
        graph.add_conditional_edges(
            'agent', should_continue,
            {'continue_tools': 'tools', 'end_conversation_turn': END}
        )
        graph.add_conditional_edges(
            'tools', after_tools_router,
            {'reflect_on_warnings': 'reflection', 'continue_to_agent': 'agent'}
        )
        graph.add_edge('reflection', 'agent')

        self.graph_app = graph.compile()

    def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run one conversation turn through the compiled graph."""
        try:
            result = self.graph_app.invoke(state, {'recursion_limit': 15})
            result.setdefault('summary', state.get('summary'))
            result.setdefault('interaction_warnings', None)
            return result
        except Exception as e:
            err = AIMessage(content=f"Sorry, a critical error occurred: {e}")
            return {
                'messages': state.get('messages', []) + [err],
                'patient_data': state.get('patient_data'),
                'summary': state.get('summary'),
                'interaction_warnings': None
            }
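

# Minimal smoke test, kept out of normal imports. This is a sketch: it assumes
# GROQ_API_KEY (and TAVILY_API_KEY for searches) are set in the environment,
# makes a live LLM call, and uses an illustrative patient_data dict whose keys
# match the ones read by check_red_flags() and tool_node() above.
if __name__ == "__main__":
    demo_patient = {
        "hpi": {"symptoms": ["chest pain", "shortness of breath"]},
        "vitals": {"bp_mmhg": "150/95", "hr_bpm": 110, "spo2_percent": 95},
        "pmh": {"conditions": "History of MI"},
        "medications": {"current": ["Aspirin 81mg daily"]},
        "allergies": ["penicillin"],
    }
    initial_state = {
        "messages": [HumanMessage(content=format_patient_data_for_prompt(demo_patient))],
        "patient_data": demo_patient,
        "summary": None,
        "interaction_warnings": None,
    }
    print("Pre-screen red flags:", check_red_flags(demo_patient))
    agent = ClinicalAgent()
    final_state = agent.invoke_turn(initial_state)
    print(final_state["messages"][-1].content)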
|