import os
import re
import json
import logging
import operator
from functools import lru_cache
from typing import Annotated, List, Dict, Any, Optional, TypedDict

import requests
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from langgraph.graph import StateGraph, END


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


UMLS_API_KEY = os.getenv("UMLS_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
    logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
    raise RuntimeError("Missing required API keys")


AGENT_MODEL_NAME = "llama3-70b-8192"
AGENT_TEMPERATURE = 0.1
MAX_SEARCH_RESULTS = 3


class ClinicalPrompts:
    SYSTEM_PROMPT = """
    You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation...
    [SYSTEM PROMPT CONTENT HERE]
    """


UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
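
# External drug-data sources used by the lookup helpers below:
#   - RxNorm (RXNORM_API_BASE) maps free-text drug names to RxCUI identifiers.
#   - OpenFDA (OPENFDA_API_BASE) supplies structured product-label text
#     (contraindications, warnings, drug interactions).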


@lru_cache(maxsize=256)
def get_rxcui(drug_name: str) -> Optional[str]:
    """Lookup RxNorm CUI for a given drug name."""
    drug_name = (drug_name or "").strip()
    if not drug_name:
        return None
    logger.info(f"Looking up RxCUI for '{drug_name}'")
    try:
        params = {"name": drug_name, "search": 1}
        r = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
        r.raise_for_status()
        ids = r.json().get("idGroup", {}).get("rxnormId")
        if ids:
            logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
            return ids[0]

        r = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
        r.raise_for_status()
        for grp in r.json().get("drugGroup", {}).get("conceptGroup", []):
            props = grp.get("conceptProperties")
            if props:
                logger.info(f"Found RxCUI {props[0]['rxcui']} via /drugs for '{drug_name}'")
                return props[0]["rxcui"]
    except Exception:
        logger.exception(f"Error fetching RxCUI for '{drug_name}'")
    return None
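
# Illustrative call (the exact identifier depends on the live RxNorm service):
#   get_rxcui("lisinopril")  ->  an RxCUI string such as "29046", or None if no match is found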


@lru_cache(maxsize=128)
def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Fetch the OpenFDA label for a drug by RxCUI or name."""
    if not (rxcui or drug_name):
        return None
    terms = []
    if rxcui:
        terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
    if drug_name:
        dn = drug_name.lower()
        terms.append(f'(openfda.brand_name:"{dn}" OR openfda.generic_name:"{dn}")')
    query = " OR ".join(terms)
    logger.info(f"Looking up OpenFDA label with query: {query}")
    try:
        r = requests.get(OPENFDA_API_BASE, params={"search": query, "limit": 1}, timeout=15)
        r.raise_for_status()
        results = r.json().get("results", [])
        if results:
            return results[0]
    except Exception:
        logger.exception("Error fetching OpenFDA label")
    return None
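
# Note: lru_cache memoizes the result per (rxcui, drug_name) pair, including a None
# returned after a transient API error, for the lifetime of the process.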


def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
    """Return highlighted snippets from a list of texts containing any of the search terms."""
    snippets = []
    lowers = [t.lower() for t in terms if t]
    for text in texts or []:
        tl = text.lower()
        for term in lowers:
            if term in tl:
                i = tl.find(term)
                start = max(0, i - 50)
                end = min(len(text), i + len(term) + 100)
                snippet = text[start:end]
                snippet = re.sub(f"({re.escape(term)})", r"**\1**", snippet, flags=re.IGNORECASE)
                snippets.append(f"...{snippet}...")
                break
    return snippets


def parse_bp(bp: str) -> Optional[tuple[int, int]]:
    """Parse 'SYS/DIA' blood pressure string into a (sys, dia) tuple."""
    if m := re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", (bp or "").strip()):
        return int(m.group(1)), int(m.group(2))
    return None
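
# Examples: parse_bp("120/80") -> (120, 80); parse_bp("") -> None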


def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
    """Identify immediate red flags from patient_data."""
    flags: List[str] = []
    hpi = patient_data.get("hpi", {})
    vitals = patient_data.get("vitals", {})
    syms = [s.lower() for s in hpi.get("symptoms", []) if isinstance(s, str)]

    mapping = {
        "chest pain": "Chest pain reported",
        "shortness of breath": "Shortness of breath reported",
        "severe headache": "Severe headache reported",
        "syncope": "Syncope reported",
        "hemoptysis": "Hemoptysis reported"
    }
    for term, desc in mapping.items():
        if term in syms:
            flags.append(f"Red Flag: {desc}.")

    temp = vitals.get("temp_c")
    hr = vitals.get("hr_bpm")
    rr = vitals.get("rr_rpm")
    spo2 = vitals.get("spo2_percent")
    bp = parse_bp(vitals.get("bp_mmhg", ""))

    if temp is not None and temp >= 38.5:
        flags.append(f"Red Flag: Fever ({temp}°C).")
    if hr is not None:
        if hr >= 120:
            flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
        if hr <= 50:
            flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
    if rr is not None and rr >= 24:
        flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
    if spo2 is not None and spo2 <= 92:
        flags.append(f"Red Flag: Hypoxia ({spo2}%).")
    if bp:
        sys, dia = bp
        if sys >= 180 or dia >= 110:
            flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
        if sys <= 90 or dia <= 60:
            flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")

    return list(dict.fromkeys(flags))
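
# Illustrative call:
#   check_red_flags({"hpi": {"symptoms": ["Chest pain"]}, "vitals": {"hr_bpm": 130}})
#   -> ["Red Flag: Chest pain reported.", "Red Flag: Tachycardia (130 bpm)."]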


def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
    """Format patient_data dict into a markdown-like prompt section."""
    if not data:
        return "No patient data provided."
    lines: List[str] = []
    for section, value in data.items():
        title = section.replace("_", " ").title()
        if isinstance(value, dict) and any(value.values()):
            lines.append(f"**{title}:**")
            for k, v in value.items():
                if v:
                    lines.append(f"- {k.replace('_', ' ').title()}: {v}")
        elif isinstance(value, list) and value:
            lines.append(f"**{title}:** {', '.join(map(str, value))}")
        elif value:
            lines.append(f"**{title}:** {value}")
    return "\n".join(lines)


class LabOrderInput(BaseModel):
    test_name: str = Field(...)
    reason: str = Field(...)
    priority: str = Field("Routine")


class PrescriptionInput(BaseModel):
    medication_name: str = Field(...)
    dosage: str = Field(...)
    route: str = Field(...)
    frequency: str = Field(...)
    duration: str = Field("As directed")
    reason: str = Field(...)


class InteractionCheckInput(BaseModel):
    potential_prescription: str
    current_medications: Optional[List[str]] = Field(None)
    allergies: Optional[List[str]] = Field(None)


class FlagRiskInput(BaseModel):
    risk_description: str = Field(...)
    urgency: str = Field("High")


@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """Place an order for a laboratory test."""
    logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
    return json.dumps({
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}"
    })


@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str,
    reason: str
) -> str:
    """Prepare a medication prescription."""
    logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
    return json.dumps({
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}"
    })


@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(
    potential_prescription: str,
    current_medications: Optional[List[str]] = None,
    allergies: Optional[List[str]] = None
) -> str:
    """Check for drug-drug interactions and allergy risks."""
    logger.info(f"Checking interactions for: {potential_prescription}")
    warnings: List[str] = []
    pm = [m.lower().strip() for m in (current_medications or []) if m]
    al = [a.lower().strip() for a in (allergies or []) if a]

    if potential_prescription.lower().strip() in al:
        warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")

    rxcui = get_rxcui(potential_prescription)
    label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
    if not (rxcui or label):
        warnings.append(f"INFO: Could not identify '{potential_prescription}'. Checks may be incomplete.")

    for section in ("contraindications", "warnings_and_cautions", "warnings"):
        items = label.get(section) if label else None
        if isinstance(items, list):
            snippets = search_text_list(items, al)
            if snippets:
                warnings.append(f"ALLERGY RISK ({section}): {'; '.join(snippets)}")

    for med in pm:
        mrxcui = get_rxcui(med)
        mlabel = get_openfda_label(rxcui=mrxcui, drug_name=med)
        for sec in ("drug_interactions",):
            for src_label, src_name in ((label, potential_prescription), (mlabel, med)):
                items = src_label.get(sec) if src_label else None
                if isinstance(items, list):
                    snippets = search_text_list(items, [med if src_name == potential_prescription else potential_prescription])
                    if snippets:
                        warnings.append(f"Interaction ({src_name} label): {'; '.join(snippets)}")

    status = "warning" if warnings else "clear"
    message = (
        f"{len(warnings)} issue(s) found for '{potential_prescription}'."
        if warnings else
        f"No major interactions or allergy issues identified for '{potential_prescription}'."
    )
    return json.dumps({"status": status, "message": message, "warnings": warnings})


@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str = "High") -> str:
    """Flag a clinical risk with given urgency."""
    logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
    return json.dumps({
        "status": "flagged",
        "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
    })


search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
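
# Note: order_lab_test, prescribe_medication, and flag_risk are stubs that log the
# request and return JSON confirmations; only check_drug_interactions and
# tavily_search_results call external services.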


llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
model_with_tools = llm.bind_tools(all_tools)
tool_executor = ToolExecutor(all_tools)


class AgentState(TypedDict):
    # The operator.add reducer appends each node's returned messages to the running
    # history instead of overwriting it.
    messages: Annotated[List[Any], operator.add]
    patient_data: Optional[Dict[str, Any]]
    summary: Optional[str]
    interaction_warnings: Optional[List[str]]
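
# State flow: `messages` accumulates the conversation, `patient_data` seeds the
# interaction checks in tool_node, and `interaction_warnings` is set by tool_node
# and consumed (then cleared) by reflection_node.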


def agent_node(state: AgentState) -> Dict[str, Any]:
    msgs = state["messages"]
    if not msgs or not isinstance(msgs[0], SystemMessage):
        msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
    logger.info(f"Invoking LLM with {len(msgs)} messages")
    try:
        response = model_with_tools.invoke(msgs)
        return {"messages": [response]}
    except Exception as e:
        logger.exception("Error in agent_node")
        return {"messages": [AIMessage(content=f"Error: {e}")]}


def tool_node(state: AgentState) -> Dict[str, Any]:
    last = state["messages"][-1]
    if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
        logger.warning("tool_node invoked without pending tool_calls")
        return {"messages": [], "interaction_warnings": None}

    calls = last.tool_calls
    blocked_ids = set()
    # Safety gate: block prescribe_medication unless the same batch also checks
    # interactions for that medication.
    for call in calls:
        if call["name"] == "prescribe_medication":
            med = call["args"].get("medication_name", "").lower()
            if not any(
                c["name"] == "check_drug_interactions" and
                c["args"].get("potential_prescription", "").lower() == med
                for c in calls
            ):
                logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
                blocked_ids.add(call["id"])

    to_execute = [c for c in calls if c["id"] not in blocked_ids]
    pd = state.get("patient_data") or {}
    for call in to_execute:
        if call["name"] == "check_drug_interactions":
            call["args"].setdefault("current_medications", pd.get("medications", {}).get("current", []))
            call["args"].setdefault("allergies", pd.get("allergies", []))

    messages: List[ToolMessage] = []
    warnings: List[str] = []
    try:
        # ToolExecutor (langgraph.prebuilt) expects ToolInvocation objects rather than
        # raw tool-call dicts.
        invocations = [ToolInvocation(tool=c["name"], tool_input=c["args"]) for c in to_execute]
        responses = tool_executor.batch(invocations, return_exceptions=True)
        for call, resp in zip(to_execute, responses):
            if isinstance(resp, Exception):
                logger.error(f"Error executing tool {call['name']}: {resp}")
                content = json.dumps({"status": "error", "message": str(resp)})
            else:
                content = str(resp)
                if call["name"] == "check_drug_interactions":
                    data = json.loads(content)
                    if data.get("status") == "warning":
                        warnings.extend(data.get("warnings", []))
            messages.append(ToolMessage(content=content, tool_call_id=call["id"], name=call["name"]))
    except Exception as e:
        logger.exception("Critical error in tool_node")
        for call in to_execute:
            messages.append(ToolMessage(
                content=json.dumps({"status": "error", "message": str(e)}),
                tool_call_id=call["id"],
                name=call["name"]
            ))
    return {"messages": messages, "interaction_warnings": warnings or None}


def reflection_node(state: AgentState) -> Dict[str, Any]:
    warns = state.get("interaction_warnings")
    if not warns:
        logger.warning("reflection_node called without warnings")
        return {"messages": [], "interaction_warnings": None}

    triggering = None
    for msg in reversed(state["messages"]):
        if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
            triggering = msg
            break

    if not triggering:
        return {"messages": [AIMessage(content="Internal Error: reflection context missing.")], "interaction_warnings": None}

    prompt = (
        "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
        f"{triggering.content}\n\n"
        "Highlight any issues based on these warnings:\n" +
        "\n".join(f"- {w}" for w in warns)
    )
    try:
        resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
        return {"messages": [AIMessage(content=resp.content)], "interaction_warnings": None}
    except Exception as e:
        logger.exception("Error during reflection")
        return {"messages": [AIMessage(content=f"Error during reflection: {e}")], "interaction_warnings": None}


def should_continue(state: AgentState) -> str:
    last = state["messages"][-1] if state["messages"] else None
    if not isinstance(last, AIMessage):
        return "end_conversation_turn"
    if getattr(last, "tool_calls", None):
        return "continue_tools"
    return "end_conversation_turn"


def after_tools_router(state: AgentState) -> str:
    return "reflection" if state.get("interaction_warnings") else "agent"


class ClinicalAgent:
    def __init__(self):
        logger.info("Building ClinicalAgent workflow")
        wf = StateGraph(AgentState)
        wf.add_node("agent", agent_node)
        wf.add_node("tools", tool_node)
        wf.add_node("reflection", reflection_node)
        wf.set_entry_point("agent")
        wf.add_conditional_edges("agent", should_continue, {
            "continue_tools": "tools",
            "end_conversation_turn": END
        })
        wf.add_conditional_edges("tools", after_tools_router, {
            "reflection": "reflection",
            "agent": "agent"
        })
        wf.add_edge("reflection", "agent")
        self.graph_app = wf.compile()
        logger.info("ClinicalAgent ready")

    def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
        try:
            result = self.graph_app.invoke(state, {"recursion_limit": 15})
            result.setdefault("summary", state.get("summary"))
            result.setdefault("interaction_warnings", None)
            return result
        except Exception as e:
            logger.exception("Error during graph invocation")
            return {
                "messages": state.get("messages", []) + [AIMessage(content=f"Error: {e}")],
                "patient_data": state.get("patient_data"),
                "summary": state.get("summary"),
                "interaction_warnings": None
            }
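

# Minimal usage sketch (illustrative only): runs a single turn through the compiled
# graph. Assumes the required API keys are exported in the environment; the
# patient_data keys mirror those read by check_red_flags and tool_node.
if __name__ == "__main__":
    agent = ClinicalAgent()
    initial_state = {
        "messages": [HumanMessage(content="55-year-old with chest pain and shortness of breath.")],
        "patient_data": {
            "hpi": {"symptoms": ["chest pain", "shortness of breath"]},
            "vitals": {"hr_bpm": 122, "bp_mmhg": "165/95", "spo2_percent": 91},
            "medications": {"current": ["warfarin"]},
            "allergies": ["penicillin"],
        },
        "summary": None,
        "interaction_warnings": None,
    }
    result = agent.invoke_turn(initial_state)
    print(result["messages"][-1].content)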