|
import os |
|
import re |
|
import json |
|
import requests |
|
import traceback |
|
import operator |
|
from functools import lru_cache |
|
from typing import Any, Dict, List, Optional, TypedDict, Annotated |
|
|
|
from langchain_groq import ChatGroq |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage |
|
from langchain_core.pydantic_v1 import BaseModel, Field |
|
from langchain_core.tools import tool |
|
from langgraph.prebuilt import ToolExecutor |
|
from langgraph.graph import StateGraph, END |
|
|
|
|
|
|
|
# --- Credentials --------------------------------------------------------
# Each key is read from the process environment; the value is None when the
# variable is unset.
UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY = (
    os.getenv(env_var) for env_var in ("UMLS_API_KEY", "GROQ_API_KEY", "TAVILY_API_KEY")
)

# --- Agent / LLM configuration -------------------------------------------
# NOTE(review): confirm this model ID is still served by the Groq API.
AGENT_MODEL_NAME = "llama3-70b-8192"
AGENT_TEMPERATURE = 0.1
MAX_SEARCH_RESULTS = 3  # passed as max_results to the web-search tool

# --- External service endpoints ------------------------------------------
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
|
|
|
|
|
class ClinicalPrompts:
    """Namespace for prompt text used by the clinical agent."""

    # System prompt placed at the start of the conversation sent to the LLM.
    # NOTE(review): instruction 3 asks for a JSON assessment "as specified",
    # but no JSON schema is given in this prompt — confirm where the expected
    # format is defined, or add it here.
    # NOTE(review): instruction 7 names "tavily_search_results"; this must stay
    # in sync with the name given to the search tool.
    SYSTEM_PROMPT = (
        """
You are SynapseAI, an expert AI clinical assistant in an interactive consultation.
Analyze patient data, provide differential diagnoses, suggest management plans,
and identify risks according to current standards of care.

1. Process information sequentially; use full conversation history.
2. Ask for clarification if data is insufficient; do not guess.
3. When ready, output a complete JSON assessment as specified.
4. Before prescribing, run drug-interaction checks and report results.
5. Flag urgent red flags immediately.
6. Use tools logically; await results when needed.
7. Query clinical guidelines via tavily_search_results and cite them.
8. Be concise, accurate, and use standard terminology.
"""
    )
|
|
|
|
|
|
|
# RxNorm term types accepted from the /drugs.json fallback lookup.
_RXNORM_FALLBACK_TTYS = frozenset({"SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"})


def _rxnav_get_json(path: str, params: Dict[str, Any]) -> Dict[str, Any]:
    """GET ``{RXNORM_API_BASE}/{path}`` and return the decoded JSON object.

    Raises on network, HTTP-status and JSON-decoding errors so the caller can
    avoid caching a failed lookup.
    """
    resp = requests.get(f"{RXNORM_API_BASE}/{path}", params=params, timeout=10)
    resp.raise_for_status()
    return resp.json() or {}


@lru_cache(maxsize=256)
def _lookup_rxcui(name: str) -> Optional[str]:
    """Resolve an already-stripped drug name to an RxCUI.

    Only successful answers (including a definitive "not found" -> None) are
    cached; any request failure propagates as an exception, which lru_cache
    does not memoize.
    """
    # 1) Direct lookup by name.
    data = _rxnav_get_json("rxcui.json", {"name": name, "search": 1})
    ids = (data.get("idGroup") or {}).get("rxnormId") or []
    if ids:
        return ids[0]

    # 2) Fallback: first concept from an accepted term-type group.
    data = _rxnav_get_json("drugs.json", {"name": name})
    for grp in (data.get("drugGroup") or {}).get("conceptGroup") or []:
        if grp.get("tty") in _RXNORM_FALLBACK_TTYS:
            props = grp.get("conceptProperties") or []
            if props:
                return props[0].get("rxcui")
    return None


def get_rxcui(drug_name: str) -> Optional[str]:
    """Return the RxNorm CUI for a given drug name.

    Args:
        drug_name: Free-text drug name; surrounding whitespace is ignored.

    Returns:
        The first matching RxCUI string, or None when the name is empty, no
        concept matches, or the RxNav request fails. Failures are printed and
        are NOT cached, so a later call retries the lookup.
    """
    if not drug_name:
        return None
    name = drug_name.strip()
    if not name:
        return None
    try:
        return _lookup_rxcui(name)
    except Exception:
        # Best-effort lookup: report and return None without caching.
        traceback.print_exc()
        return None


# Keep the cache-inspection API that the lru_cache-decorated version exposed.
get_rxcui.cache_info = _lookup_rxcui.cache_info  # type: ignore[attr-defined]
get_rxcui.cache_clear = _lookup_rxcui.cache_clear  # type: ignore[attr-defined]
|
|
|
|
|
def _openfda_phrase(value: str) -> str:
    """Make *value* safe to embed inside a double-quoted openFDA search phrase."""
    # A stray double quote would terminate the phrase and corrupt the query.
    return value.replace('"', " ").strip()


@lru_cache(maxsize=128)
def _fetch_openfda_label(rxcui: Optional[str], drug_name: Optional[str]) -> Optional[dict]:
    """Query openFDA for one label; raises on transport/HTTP errors (not cached)."""
    terms = []
    if rxcui:
        code = _openfda_phrase(str(rxcui))
        if code:
            terms.append(f'spl_rxnorm_code:"{code}" OR openfda.rxcui:"{code}"')
    if drug_name:
        name = _openfda_phrase(drug_name.lower())
        if name:
            terms.append(f'(openfda.brand_name:"{name}" OR openfda.generic_name:"{name}")')
    if not terms:
        return None

    params = {"search": " OR ".join(terms), "limit": 1}
    resp = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
    if resp.status_code == 404:
        # openFDA reports "no matches" with HTTP 404; that is a definitive
        # answer, so return (and cache) None instead of treating it as an error.
        return None
    resp.raise_for_status()
    results = (resp.json() or {}).get("results") or []
    return results[0] if results else None


def get_openfda_label(
    rxcui: Optional[str] = None,
    drug_name: Optional[str] = None
) -> Optional[dict]:
    """Fetch an OpenFDA drug label by RxCUI and/or drug name.

    Args:
        rxcui: RxNorm concept id to match against the label's RxNorm fields.
        drug_name: Brand or generic name (matched lower-cased).

    Returns:
        The first matching label record, or None when neither argument is
        given, nothing matches, or the request fails. Request failures are
        printed and NOT cached, so a later call retries.
    """
    if not (rxcui or drug_name):
        return None
    try:
        return _fetch_openfda_label(rxcui, drug_name)
    except Exception:
        # Best-effort lookup: report and return None without caching.
        traceback.print_exc()
        return None


# Keep the cache-inspection API that the lru_cache-decorated version exposed.
get_openfda_label.cache_info = _fetch_openfda_label.cache_info  # type: ignore[attr-defined]
get_openfda_label.cache_clear = _fetch_openfda_label.cache_clear  # type: ignore[attr-defined]
|
|
|
|
|
def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
    """Return context snippets where any of *terms* appears in *texts*.

    For each string in *texts* (non-strings are skipped), the first term in
    *terms* order that occurs case-insensitively produces one snippet: up to
    50 characters before the match and 100 after, with the matched text
    wrapped in ``**`` and the snippet wrapped in ``...``. Texts with no match
    contribute nothing.
    """
    # Compile once per term; match on the original text so that offsets stay
    # valid even when lower-casing would change the string's length.
    patterns = [re.compile(re.escape(t.lower()), re.IGNORECASE) for t in (terms or []) if t]
    snippets: List[str] = []
    for txt in texts or []:
        if not isinstance(txt, str):
            continue
        for pattern in patterns:
            match = pattern.search(txt)
            if match is None:
                continue
            start = max(0, match.start() - 50)
            end = min(len(txt), match.end() + 100)
            snippets.append(
                f"...{txt[start:match.start()]}**{match.group(0)}**{txt[match.end():end]}..."
            )
            break
    return snippets
|
|
|
|
|
# "systolic/diastolic" at the start of the string, e.g. "120/80" or "140 / 90 mmHg".
_BP_READING = re.compile(r"(\d{1,3})\s*/\s*(\d{1,3})")


def parse_bp(bp_str: str) -> Optional[tuple[int, int]]:
    """Parse a blood-pressure string of the form 'systolic/diastolic'.

    Returns ``(systolic, diastolic)`` as ints, or None when the string (or
    None) does not start with such a reading.
    """
    found = _BP_READING.match(bp_str or "")
    if found is None:
        return None
    systolic, diastolic = found.groups()
    return int(systolic), int(diastolic)
|
|
|
|
|
# Exact (lower-cased) symptom entries that are red flags on their own.
_SYMPTOM_RED_FLAGS: Dict[str, str] = {
    "chest pain": "Chest Pain reported.",
    "shortness of breath": "Shortness of Breath reported.",
    "severe headache": "Severe Headache reported.",
    "sudden vision loss": "Sudden Vision Loss reported.",
    "weakness on one side": "Unilateral Weakness reported (potential stroke).",
    "hemoptysis": "Hemoptysis (coughing up blood).",
    "syncope": "Syncope (fainting).",
}


def _mentions_phrase(text: str, phrase: str) -> bool:
    """Return True if *phrase* occurs in *text* and is not part of a longer word.

    Prevents e.g. "history of mi" from matching "history of migraine".
    """
    return re.search(rf"(?<!\w){re.escape(phrase)}(?!\w)", text) is not None


def check_red_flags(patient_data: Dict) -> List[str]:
    """Identify critical red flags from patient data.

    Reads ``hpi.symptoms`` (list entries matched exactly after lower-casing),
    ``vitals`` (``temp_c``, ``hr_bpm``, ``rr_rpm``, ``spo2_percent``,
    ``bp_mmhg``) and ``pmh.conditions`` (free text). Missing or None
    sections are treated as empty.

    Returns:
        De-duplicated ``"Red Flag: ..."`` strings in a stable order:
        symptom flags, then vital-sign flags, then history/symptom combinations.
    """
    flags: List[str] = []
    if not patient_data:
        return flags

    hpi = patient_data.get("hpi") or {}
    symptoms = [s.lower() for s in (hpi.get("symptoms") or []) if isinstance(s, str)]
    vitals = patient_data.get("vitals") or {}
    history = ((patient_data.get("pmh") or {}).get("conditions") or "").lower()

    flags.extend(
        f"Red Flag: {desc}" for term, desc in _SYMPTOM_RED_FLAGS.items() if term in symptoms
    )

    # Falsy vitals (None or 0) are treated as not recorded.
    temp = vitals.get("temp_c")
    hr = vitals.get("hr_bpm")
    rr = vitals.get("rr_rpm")
    spo2 = vitals.get("spo2_percent")
    bp = parse_bp(vitals.get("bp_mmhg") or "")

    if temp and temp >= 38.5:
        flags.append(f"Red Flag: Fever ({temp}°C).")
    if hr:
        if hr >= 120:
            flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
        if hr <= 50:
            flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
    if rr and rr >= 24:
        flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
    if spo2 and spo2 <= 92:
        flags.append(f"Red Flag: Hypoxia ({spo2}%).")
    if bp:
        systolic, diastolic = bp
        if systolic >= 180 or diastolic >= 110:
            flags.append(
                f"Red Flag: Hypertensive Urgency/Emergency (BP: {systolic}/{diastolic} mmHg)."
            )
        if systolic <= 90 or diastolic <= 60:
            flags.append(f"Red Flag: Hypotension (BP: {systolic}/{diastolic} mmHg).")

    # Whole-phrase matching so "history of mi" does not fire on "history of migraine".
    if _mentions_phrase(history, "history of mi") and "chest pain" in symptoms:
        flags.append("Red Flag: History of MI with current Chest Pain.")
    if _mentions_phrase(history, "history of dvt/pe") and "shortness of breath" in symptoms:
        flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")

    # dict.fromkeys de-duplicates while keeping a deterministic insertion order.
    return list(dict.fromkeys(flags))
|
|
|
|
|
def _prompt_label(key: str) -> str:
    """Turn a snake_case key into a Title Case label ("hr_bpm" -> "Hr Bpm")."""
    return key.replace("_", " ").title()


def _prompt_value(value: Any) -> str:
    """Render a nested value; lists become comma-separated text, not reprs."""
    if isinstance(value, list):
        return ", ".join(map(str, value))
    return str(value)


def format_patient_data_for_prompt(data: Dict) -> str:
    """Convert a patient-data dict into a human-readable prompt section.

    - dict values become a bold heading followed by "- Label: value" lines for
      their truthy entries (the section is omitted if none are truthy);
    - non-empty lists become "**Label:** a, b, c";
    - other truthy values become "**Label:** value".

    Sections are separated by blank lines. Returns "No patient data provided."
    when *data* is empty or yields no sections at all.
    """
    if not data:
        return "No patient data provided."

    sections: List[str] = []
    for key, val in data.items():
        title = _prompt_label(key)
        if isinstance(val, dict):
            # Checked first so an all-empty dict is skipped instead of being
            # dumped as a raw dict repr.
            lines = [f"- {_prompt_label(k)}: {_prompt_value(v)}" for k, v in val.items() if v]
            if lines:
                sections.append("\n".join([f"**{title}:**", *lines]))
        elif isinstance(val, list):
            if val:
                sections.append(f"**{title}:** {', '.join(map(str, val))}")
        elif val:
            sections.append(f"**{title}:** {val}")

    return "\n\n".join(sections) if sections else "No patient data provided."
|
|
|
|
|
|
|
class LabOrderInput(BaseModel):
    """Arguments for the order_lab_test tool."""

    # Field descriptions are included in the tool schema shown to the LLM.
    test_name: str = Field(..., description="Name of the laboratory test to order.")
    reason: str = Field(..., description="Clinical reason for ordering the test.")
    priority: str = Field("Routine", description='Order priority; defaults to "Routine".')
|
|
|
|
|
class PrescriptionInput(BaseModel):
    """Arguments for the prescribe_medication tool."""

    # Field descriptions are included in the tool schema shown to the LLM.
    medication_name: str = Field(..., description="Name of the medication to prescribe.")
    dosage: str = Field(..., description="Dose per administration, including units.")
    route: str = Field(..., description="Route of administration.")
    frequency: str = Field(..., description="How often the dose is given.")
    duration: str = Field("As directed", description='Length of therapy; defaults to "As directed".')
    reason: str = Field(..., description="Clinical indication for the prescription.")
|
|
|
|
|
class InteractionCheckInput(BaseModel):
    """Arguments for the check_drug_interactions tool."""

    # Field descriptions are included in the tool schema shown to the LLM.
    potential_prescription: str = Field(
        ..., description="Name of the drug being considered for prescription."
    )
    current_medications: Optional[List[str]] = Field(
        None, description="Names of the patient's current medications."
    )
    allergies: Optional[List[str]] = Field(
        None, description="The patient's known allergies."
    )
|
|
|
|
|
class FlagRiskInput(BaseModel):
    """Arguments for the flag_risk tool."""

    # Field descriptions are included in the tool schema shown to the LLM.
    risk_description: str = Field(..., description="Description of the clinical risk to flag.")
    urgency: str = Field("High", description='Urgency level; defaults to "High".')
|
|
|
|
|
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """Order a laboratory test for the patient.

    Args:
        test_name: Name of the laboratory test to order.
        reason: Clinical reason for ordering the test.
        priority: Order priority; defaults to "Routine".

    Returns:
        JSON string with "status", "message" and "details" fields.
    """
    # Only a confirmation payload is built; no external ordering system is called.
    result = {
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}"
    }
    return json.dumps(result)
|
|
|
|
|
@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str,
    reason: str
) -> str:
    """Prepare a prescription for the patient.

    Run check_drug_interactions for the same medication before (or together
    with) this call.

    Args:
        medication_name: Name of the medication.
        dosage: Dose per administration, including units.
        route: Route of administration.
        frequency: How often the dose is given.
        duration: Length of therapy.
        reason: Clinical indication.

    Returns:
        JSON string with "status", "message" and "details" fields.
    """
    # Only a confirmation payload is built; no external prescribing system is called.
    result = {
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}"
    }
    return json.dumps(result)
|
|
|
|
|
def _openfda_identity_names(label: Optional[dict]) -> List[str]:
    """Lower-cased generic/brand/substance names listed in a label's ``openfda`` section."""
    if not label:
        return []
    meta = label.get("openfda") or {}
    names: List[str] = []
    for field_name in ("generic_name", "brand_name", "substance_name"):
        for value in meta.get(field_name) or []:
            if isinstance(value, str) and value.strip():
                names.append(value.lower().strip())
    return names


def _names_overlap(first: str, second: str) -> bool:
    """True if either name occurs inside the other as a whole word or phrase."""
    def contains(haystack: str, needle: str) -> bool:
        return re.search(rf"(?<!\w){re.escape(needle)}(?!\w)", haystack) is not None

    return contains(first, second) or contains(second, first)


@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(
    potential_prescription: str,
    current_medications: Optional[List[str]] = None,
    allergies: Optional[List[str]] = None
) -> str:
    """Check a proposed medication against the patient's current medications and allergies.

    The drug is looked up in RxNorm and its openFDA label. An allergy is
    reported when an allergy entry matches the proposed drug's name or the
    generic/brand/substance names on its label. An interaction is reported
    when a current medication is mentioned in the label's drug-interaction
    section.

    Returns:
        JSON string with "status" ("warning" or "clear"), "message" and a
        "warnings" list. CRITICAL entries are allergies; INFO entries mean the
        check could not be fully performed.
    """
    warnings: List[str] = []
    presc_lower = potential_prescription.lower().strip()
    current = [m.lower().strip() for m in (current_medications or []) if isinstance(m, str)]
    allergy_list = [a.lower().strip() for a in (allergies or []) if isinstance(a, str)]

    rxcui = get_rxcui(potential_prescription)
    label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
    if not rxcui and not label:
        warnings.append(f"INFO: Could not identify '{potential_prescription}'.")

    # Whole-word matching catches e.g. "amoxicillin 500 mg" vs allergy
    # "amoxicillin", which plain string equality missed.
    drug_names = [n for n in [presc_lower, *_openfda_identity_names(label)] if n]
    for alg in allergy_list:
        if alg and any(_names_overlap(alg, name) for name in drug_names):
            warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{alg}'.")

    if rxcui or label:
        interaction_texts = (label or {}).get("drug_interactions") or []
        others = [med for med in current if med and med != presc_lower]
        if others and not interaction_texts:
            # Without label data nothing was actually checked; do not report "clear".
            warnings.append(
                f"INFO: No FDA label interaction data for '{potential_prescription}'; "
                "interactions with current medications were not verified."
            )
        for med in others:
            snippets = search_text_list(interaction_texts, [med])
            if snippets:
                warnings.append(
                    f"Potential Interaction: '{potential_prescription}' & '{med}'. Snippets: {'; '.join(snippets)}"
                )
    else:
        warnings.append(f"INFO: Skipped interaction check for '{potential_prescription}'.")

    status = "warning" if warnings else "clear"
    message = (
        f"Interaction/Allergy check for '{potential_prescription}': {len(warnings)} issue(s)."
        if warnings else
        f"No major issues for '{potential_prescription}'."
    )
    return json.dumps({"status": status, "message": message, "warnings": warnings})
|
|
|
|
|
@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str = "High") -> str:
    """Flag a clinical risk for attention.

    Args:
        risk_description: Description of the risk.
        urgency: Urgency level; defaults to "High" (matching FlagRiskInput).

    Returns:
        JSON string with "status" ("flagged") and "message" fields.
    """
    return json.dumps({
        "status": "flagged",
        "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
    })
|
|
|
|
|
|
|
# Web-search tool exposed to the LLM as "tavily_search_results" (the name the
# system prompt refers to), limited to MAX_SEARCH_RESULTS results per query.
# NOTE(review): TAVILY_API_KEY is not passed explicitly; presumably the tool
# reads it from the environment — confirm.
search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")

# Every tool the agent may call; bound to the LLM below.
all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """State schema for the StateGraph built in ClinicalAgent."""

    # Conversation history. The operator.add reducer makes LangGraph
    # concatenate each node's returned "messages" list onto the existing one.
    messages: Annotated[List[Any], operator.add]
    # Structured patient record; tool_node reads "medications" -> "current"
    # and "allergies" from it.
    patient_data: Optional[Dict]
    # Optional summary text. No node here writes it; invoke_turn only
    # preserves the caller's value.
    summary: Optional[str]
    # Warnings returned by check_drug_interactions during the latest tool
    # step. A truthy value routes the graph to the reflection node.
    interaction_warnings: Optional[List[str]]
|
|
|
|
|
# Base chat model (used directly for the safety-reflection step).
# NOTE(review): no api_key is passed; presumably ChatGroq reads GROQ_API_KEY
# from the environment — confirm.
llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)

# Same model with all tool schemas bound, so its replies can carry tool calls.
model_with_tools = llm.bind_tools(all_tools)

# NOTE(review): ToolExecutor is deprecated in newer langgraph releases and
# expects ToolInvocation-like objects (.tool / .tool_input), not raw
# tool-call dicts — confirm against the pinned langgraph version.
tool_executor = ToolExecutor(all_tools)
|
|
|
|
|
def agent_node(state: AgentState) -> Dict:
    """Run one tool-enabled LLM step over the conversation.

    The clinical system prompt is prepended unless the history already starts
    with a SystemMessage; the stored history itself is not modified.

    Returns:
        ``{"messages": [reply]}`` where reply is the model's message, or an
        AIMessage "Error: ..." if the model call raised.
    """
    history = list(state['messages'])
    starts_with_system = bool(history) and isinstance(history[0], SystemMessage)
    if not starts_with_system:
        history = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT), *history]

    try:
        reply = model_with_tools.invoke(history)
    except Exception as exc:
        traceback.print_exc()
        reply = AIMessage(content=f"Error: {exc}")
    return {"messages": [reply]}
|
|
|
|
|
def _blocked_prescription_ids(calls: List[Dict]) -> set:
    """Ids of prescribe_medication calls with no same-batch interaction check for that drug."""
    def normalize(value: Any) -> str:
        return str(value or "").strip().lower()

    checked = {
        normalize((c.get("args") or {}).get("potential_prescription"))
        for c in calls
        if c.get("name") == "check_drug_interactions"
    }
    return {
        c["id"]
        for c in calls
        if c.get("name") == "prescribe_medication"
        and normalize((c.get("args") or {}).get("medication_name")) not in checked
    }


def _with_patient_context(call: Dict, patient_data: Optional[Dict]) -> Dict:
    """Return *call* with the patient's medications/allergies injected for interaction checks.

    A new dict is returned so that the tool call stored on the AIMessage in
    state is not mutated.
    """
    if call.get("name") != "check_drug_interactions":
        return call
    data = patient_data or {}
    args = dict(call.get("args") or {})
    args["current_medications"] = (data.get("medications") or {}).get("current", [])
    args["allergies"] = data.get("allergies", [])
    return {**call, "args": args}


def _tool_output_to_text(output: Any) -> str:
    """Render a tool result as ToolMessage content (JSON for non-string results)."""
    if isinstance(output, str):
        return output
    try:
        return json.dumps(output)
    except (TypeError, ValueError):
        return str(output)


def _warnings_from_check_result(content: str) -> List[str]:
    """Extract the "warnings" list from a check_drug_interactions result, if any."""
    try:
        data = json.loads(content)
    except ValueError:
        return []
    if isinstance(data, dict) and data.get("warnings"):
        return list(data["warnings"])
    return []


def tool_node(state: AgentState) -> Dict:
    """Execute any pending tool calls from the last AI message.

    - A prescribe_medication call is blocked unless the same batch also calls
      check_drug_interactions for that medication. A blocked call still gets
      a ToolMessage (status "blocked"), because every tool_call_id needs a
      response before the next LLM call.
    - check_drug_interactions calls receive the patient's current medications
      and allergies from ``state["patient_data"]``.
    - Each call is run on its own, so one failing tool yields an "error"
      ToolMessage without affecting the others.

    Returns:
        ``{"messages": [ToolMessage, ...], "interaction_warnings": [...] or None}``.
    """
    last = state['messages'][-1]
    if not isinstance(last, AIMessage) or not getattr(last, 'tool_calls', None):
        return {"messages": [], "interaction_warnings": None}

    calls = last.tool_calls
    blocked_ids = _blocked_prescription_ids(calls)
    # Tools are invoked directly by name; tool-call dicts lack the .tool /
    # .tool_input attributes ToolExecutor reads.
    tools_by_name = {t.name: t for t in all_tools}
    patient_data = state.get('patient_data')

    results: List[ToolMessage] = []
    warnings: List[str] = []
    for call in calls:
        name = call['name']
        if call['id'] in blocked_ids:
            med = (call.get('args') or {}).get('medication_name', '')
            content = json.dumps({
                "status": "blocked",
                "message": (
                    f"Prescription for '{med}' was not executed: call "
                    f"check_drug_interactions for '{med}' in the same tool-call batch."
                ),
            })
            results.append(ToolMessage(content=content, tool_call_id=call['id'], name=name))
            continue

        call = _with_patient_context(call, patient_data)
        tool_obj = tools_by_name.get(name)
        try:
            if tool_obj is None:
                raise ValueError(f"Unknown tool: {name}")
            content = _tool_output_to_text(tool_obj.invoke(call.get('args') or {}))
        except Exception as exc:
            # Tool failures are reported back to the LLM rather than aborting the turn.
            traceback.print_exc()
            content = json.dumps({"status": "error", "message": str(exc)})

        if name == 'check_drug_interactions':
            warnings.extend(_warnings_from_check_result(content))
        results.append(ToolMessage(content=content, tool_call_id=call['id'], name=name))

    return {"messages": results, "interaction_warnings": warnings or None}
|
|
|
|
|
def reflection_node(state: AgentState) -> Dict:
    """Ask the LLM for a revised therapeutics plan when interaction warnings exist.

    Returns:
        A partial state update that always sets ``interaction_warnings`` to
        None (clearing them). With no warnings, no message is added.
        Otherwise ``messages`` holds one AIMessage with the model's revised
        plan, or "Reflection error: ..." if the model call raised.
    """
    warnings = state.get('interaction_warnings')
    if not warnings:
        return {"messages": [], "interaction_warnings": None}

    prompt = (
        f"Interaction warnings:\n{json.dumps(warnings, indent=2)}\n"
        "Provide a revised therapeutics plan addressing these issues."
    )
    msgs = [
        SystemMessage(content="Safety reflection on drug interactions."),
        HumanMessage(content=prompt)
    ]

    try:
        resp = llm.invoke(msgs)
        revised = AIMessage(content=resp.content)
    except Exception as e:
        traceback.print_exc()
        revised = AIMessage(content=f"Reflection error: {e}")
    return {"messages": [revised], "interaction_warnings": None}
|
|
|
|
|
def should_continue(state: AgentState) -> str:
    """Route after the agent step.

    Returns 'continue_tools' when the newest message is an AIMessage carrying
    tool calls, otherwise 'end_conversation_turn'.
    """
    history = state['messages']
    newest = history[-1] if history else None
    wants_tools = isinstance(newest, AIMessage) and bool(getattr(newest, 'tool_calls', None))
    return 'continue_tools' if wants_tools else 'end_conversation_turn'
|
|
|
|
|
def after_tools_router(state: AgentState) -> str:
    """Route after the tool step: reflect when warnings were recorded, else back to the agent."""
    has_warnings = bool(state.get('interaction_warnings'))
    return 'reflect_on_warnings' if has_warnings else 'continue_to_agent'
|
|
|
|
|
class ClinicalAgent:
    """Compiled LangGraph workflow: agent -> tools -> (reflection) -> agent.

    A turn ends when the agent replies without tool calls.
    """

    def __init__(self):
        self.app = self._build_graph().compile()

    @staticmethod
    def _build_graph() -> StateGraph:
        """Assemble the (uncompiled) state graph with its nodes and routing."""
        workflow = StateGraph(AgentState)
        for node_name, node_fn in (
            ('agent', agent_node),
            ('tools', tool_node),
            ('reflection', reflection_node),
        ):
            workflow.add_node(node_name, node_fn)
        workflow.set_entry_point('agent')
        workflow.add_conditional_edges(
            'agent', should_continue,
            {'continue_tools': 'tools', 'end_conversation_turn': END}
        )
        workflow.add_conditional_edges(
            'tools', after_tools_router,
            {'reflect_on_warnings': 'reflection', 'continue_to_agent': 'agent'}
        )
        workflow.add_edge('reflection', 'agent')
        return workflow

    def invoke_turn(self, state: Dict) -> Dict:
        """Run one conversation turn (recursion limit 15) and return the final state.

        The result always contains 'summary' and 'interaction_warnings' keys.
        On any exception, the input messages are returned with an appended
        "Critical error: ..." AIMessage instead of raising.
        """
        try:
            outcome = self.app.invoke(state, {'recursion_limit': 15})
            outcome.setdefault('summary', state.get('summary'))
            outcome.setdefault('interaction_warnings', None)
        except Exception as exc:
            traceback.print_exc()
            failure = AIMessage(content=f"Critical error: {exc}")
            return {
                'messages': state.get('messages', []) + [failure],
                'patient_data': state.get('patient_data'),
                'summary': state.get('summary'),
                'interaction_warnings': None
            }
        return outcome
|
|