"""SynapseAI clinical assistant agent.

Wires a Groq-hosted chat model (``ChatGroq``) into a small LangGraph workflow
with lab-ordering, prescribing, drug-interaction, risk-flagging and Tavily web
search tools. Drug look-ups use the public RxNav (RxNorm) and openFDA label APIs.
"""
import json
import logging
import operator
import os
import re
import traceback
from functools import lru_cache
from typing import Annotated, Any, Dict, List, Optional, TypedDict

import requests
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation

# ── Logging Configuration ──────────────────────────────────────────────
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# ── Environment Variables ──────────────────────────────────────────────
# Importing this module fails fast (RuntimeError) if any key is unset.
# ChatGroq / TavilySearchResults are built below without explicit keys and
# presumably read GROQ_API_KEY / TAVILY_API_KEY from the environment themselves.
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
    logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
    raise RuntimeError("Missing required API keys")

# ── Agent Configuration ──────────────────────────────────────────────
AGENT_MODEL_NAME = "llama3-70b-8192"
AGENT_TEMPERATURE = 0.1
MAX_SEARCH_RESULTS = 3


class ClinicalPrompts:
    """Holds the prompt text used by the agent."""

    SYSTEM_PROMPT = """
You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation...
[SYSTEM PROMPT CONTENT HERE]
"""


# ── Message Normalization Helpers ─────────────────────────────────────────
def wrap_message(msg: Any) -> AIMessage:
    """
    Ensures the given message is an AIMessage.

    If it is a dict, extracts the 'content' field (or serializes the dict).
    Otherwise, converts the message to a string.
    """
    if isinstance(msg, AIMessage):
        return msg
    if isinstance(msg, dict):
        return AIMessage(content=msg.get("content", json.dumps(msg)))
    return AIMessage(content=str(msg))


def _coerce_message(msg: Any) -> BaseMessage:
    """Return ``msg`` as a LangChain message, preserving its role.

    Existing ``BaseMessage`` instances (system/human/AI/tool) are returned
    unchanged. Dicts tagged with ``role`` (or ``type``) of system/user/human
    become System/Human messages; tool dicts carrying a ``tool_call_id`` become
    ToolMessages. Anything else falls back to :func:`wrap_message`.
    """
    if isinstance(msg, BaseMessage):
        return msg
    if isinstance(msg, dict):
        role = str(msg.get("role") or msg.get("type") or "").lower()
        content = msg.get("content", json.dumps(msg))
        if role == "system":
            return SystemMessage(content=content)
        if role in ("user", "human"):
            return HumanMessage(content=content)
        if role == "tool" and msg.get("tool_call_id"):
            return ToolMessage(content=content, tool_call_id=str(msg["tool_call_id"]))
    return wrap_message(msg)


def normalize_messages(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalizes all messages in the state to LangChain message objects.

    Messages that already are LangChain messages keep their type, so system
    prompts, user turns and tool results are not disguised as AI output (which
    would break the tool-call protocol and the system-prompt check in
    ``agent_node``). Mutates and returns ``state``.
    """
    state["messages"] = [_coerce_message(m) for m in state.get("messages", [])]
    return state


# ── Helper Functions ─────────────────────────────────────────────────────
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"


@lru_cache(maxsize=256)
def get_rxcui(drug_name: str) -> Optional[str]:
    """Lookup RxNorm CUI for a given drug name.

    Tries RxNav ``/rxcui.json`` first, then falls back to the first concept in
    ``/drugs.json``. Returns ``None`` for blank input, no match, or any error
    (errors are logged).

    NOTE(review): ``lru_cache`` also caches the ``None`` produced by a transient
    network failure, so that name stays unresolved until the cache is cleared.
    """
    drug_name = (drug_name or "").strip()
    if not drug_name:
        return None
    logger.info(f"Looking up RxCUI for '{drug_name}'")
    try:
        params = {"name": drug_name, "search": 1}
        r = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
        r.raise_for_status()
        ids = r.json().get("idGroup", {}).get("rxnormId")
        if ids:
            logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
            return ids[0]
        r = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
        r.raise_for_status()
        for grp in r.json().get("drugGroup", {}).get("conceptGroup", []):
            props = grp.get("conceptProperties")
            if props:
                logger.info(f"Found RxCUI {props[0]['rxcui']} via /drugs for '{drug_name}'")
                return props[0]["rxcui"]
    except Exception:
        logger.exception(f"Error fetching RxCUI for '{drug_name}'")
    return None


@lru_cache(maxsize=128)
def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Fetch the OpenFDA label for a drug by RxCUI or name.

    Builds one OR-ed openFDA search over the RxCUI and/or the lower-cased
    brand/generic name and returns the first result dict, or ``None`` when
    neither key is given, nothing matches, or the request fails (logged).
    """
    if not (rxcui or drug_name):
        return None
    terms = []
    if rxcui:
        terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
    if drug_name:
        dn = drug_name.lower()
        terms.append(f'(openfda.brand_name:"{dn}" OR openfda.generic_name:"{dn}")')
    query = " OR ".join(terms)
    logger.info(f"Looking up OpenFDA label with query: {query}")
    try:
        r = requests.get(OPENFDA_API_BASE, params={"search": query, "limit": 1}, timeout=15)
        r.raise_for_status()
        results = r.json().get("results", [])
        if results:
            return results[0]
    except Exception:
        logger.exception("Error fetching OpenFDA label")
    return None


def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
    """Return highlighted snippets from a list of texts containing any of the search terms.

    Matching is case-insensitive. Each text contributes at most one snippet
    (for the first term found): ~50 chars before and ~100 after the match,
    with the term wrapped in ``**`` and the snippet wrapped in ``...``.
    """
    snippets = []
    lowers = [t.lower() for t in terms if t]
    for text in texts or []:
        tl = text.lower()
        for term in lowers:
            if term in tl:
                i = tl.find(term)
                start = max(0, i - 50)
                end = min(len(text), i + len(term) + 100)
                snippet = text[start:end]
                snippet = re.sub(f"({re.escape(term)})", r"**\1**", snippet, flags=re.IGNORECASE)
                snippets.append(f"...{snippet}...")
                break
    return snippets


def parse_bp(bp: str) -> Optional[tuple[int, int]]:
    """Parse 'SYS/DIA' blood pressure string into a (sys, dia) tuple.

    Returns ``None`` when the (stripped) string does not start with the
    ``digits/digits`` pattern (1-3 digits each, spaces around ``/`` allowed).
    """
    if m := re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", (bp or "").strip()):
        return int(m.group(1)), int(m.group(2))
    return None


def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
    """Identify immediate red flags from patient_data.

    Looks at ``patient_data["hpi"]["symptoms"]`` (exact, case-insensitive
    symptom names) and ``patient_data["vitals"]`` (temperature, HR, RR, SpO2,
    BP string). Returns de-duplicated ``"Red Flag: ..."`` strings in order.
    Missing or ``None`` sections are treated as empty.
    """
    flags: List[str] = []
    hpi = patient_data.get("hpi") or {}
    vitals = patient_data.get("vitals") or {}
    syms = [s.lower() for s in hpi.get("symptoms", []) if isinstance(s, str)]
    mapping = {
        "chest pain": "Chest pain reported",
        "shortness of breath": "Shortness of breath reported",
        "severe headache": "Severe headache reported",
        "syncope": "Syncope reported",
        "hemoptysis": "Hemoptysis reported",
    }
    for term, desc in mapping.items():
        if term in syms:
            flags.append(f"Red Flag: {desc}.")

    temp = vitals.get("temp_c")
    hr = vitals.get("hr_bpm")
    rr = vitals.get("rr_rpm")
    spo2 = vitals.get("spo2_percent")
    bp = parse_bp(vitals.get("bp_mmhg", ""))

    if temp is not None and temp >= 38.5:
        flags.append(f"Red Flag: Fever ({temp}°C).")
    if hr is not None:
        if hr >= 120:
            flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
        if hr <= 50:
            flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
    if rr is not None and rr >= 24:
        flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
    if spo2 is not None and spo2 <= 92:
        flags.append(f"Red Flag: Hypoxia ({spo2}%).")
    if bp:
        systolic, diastolic = bp
        if systolic >= 180 or diastolic >= 110:
            flags.append(f"Red Flag: Hypertensive urgency/emergency ({systolic}/{diastolic} mmHg).")
        if systolic <= 90 or diastolic <= 60:
            flags.append(f"Red Flag: Hypotension ({systolic}/{diastolic} mmHg).")

    # dict.fromkeys keeps first-seen order while dropping duplicates.
    return list(dict.fromkeys(flags))


def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
    """Format patient_data dict into a markdown-like prompt section.

    Dict sections become a bold title plus ``- Key: value`` bullets (falsy
    values skipped); non-empty lists are comma-joined; other truthy values are
    printed inline. Keys are title-cased with underscores turned into spaces.
    """
    if not data:
        return "No patient data provided."
    lines: List[str] = []
    for section, value in data.items():
        title = section.replace("_", " ").title()
        if isinstance(value, dict) and any(value.values()):
            lines.append(f"**{title}:**")
            for k, v in value.items():
                if v:
                    lines.append(f"- {k.replace('_',' ').title()}: {v}")
        elif isinstance(value, list) and value:
            lines.append(f"**{title}:** {', '.join(map(str, value))}")
        elif value:
            lines.append(f"**{title}:** {value}")
    return "\n".join(lines)


# ── Tool Input Schemas ─────────────────────────────────────────────────────
class LabOrderInput(BaseModel):
    test_name: str = Field(...)
    reason: str = Field(...)
    priority: str = Field("Routine")


class PrescriptionInput(BaseModel):
    medication_name: str = Field(...)
    dosage: str = Field(...)
    route: str = Field(...)
    frequency: str = Field(...)
    duration: str = Field("As directed")
    reason: str = Field(...)


class InteractionCheckInput(BaseModel):
    potential_prescription: str
    current_medications: Optional[List[str]] = Field(None)
    allergies: Optional[List[str]] = Field(None)


class FlagRiskInput(BaseModel):
    risk_description: str = Field(...)
    urgency: str = Field("High")


# ── Tool Implementations ───────────────────────────────────────────────────
# NOTE: the docstrings of @tool functions are sent to the LLM as the tool
# descriptions, so implementation notes are kept in comments instead.
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """
    Place an order for a laboratory test.
    """
    # Simulated: only logs and returns a JSON confirmation; no external order is placed here.
    logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
    return json.dumps({
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}",
    })


@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str,
    reason: str,
) -> str:
    """
    Prepare a medication prescription.
    """
    # Simulated: only logs and returns a JSON confirmation. tool_node refuses to
    # run this unless a matching check_drug_interactions call is in the same batch.
    logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
    return json.dumps({
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}",
    })


@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(
    potential_prescription: str,
    current_medications: Optional[List[str]] = None,
    allergies: Optional[List[str]] = None,
) -> str:
    """
    Check for drug–drug interactions and allergy risks.
    """
    # Returns JSON {"status": "warning"|"clear", "message": str, "warnings": [str]}.
    # Checks: exact allergy-name match; allergy terms in the candidate label's
    # contraindication/warning sections; each current med vs the candidate in
    # both drugs' openFDA "drug_interactions" sections.
    logger.info(f"Checking interactions for: {potential_prescription}")
    warnings: List[str] = []
    pm = [m.lower().strip() for m in (current_medications or []) if m]
    al = [a.lower().strip() for a in (allergies or []) if a]

    if potential_prescription.lower().strip() in al:
        warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")

    rxcui = get_rxcui(potential_prescription)
    label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
    if not (rxcui or label):
        warnings.append(f"INFO: Could not identify '{potential_prescription}'. Checks may be incomplete.")

    for section in ("contraindications", "warnings_and_cautions", "warnings"):
        items = label.get(section) if label else None
        if isinstance(items, list):
            snippets = search_text_list(items, al)
            if snippets:
                warnings.append(f"ALLERGY RISK ({section}): {'; '.join(snippets)}")

    section = "drug_interactions"
    for med in pm:
        mrxcui = get_rxcui(med)
        mlabel = get_openfda_label(rxcui=mrxcui, drug_name=med)
        # Search each drug's interaction section for the *other* drug's name.
        for src_label, src_name, other in (
            (label, potential_prescription, med),
            (mlabel, med, potential_prescription),
        ):
            items = src_label.get(section) if src_label else None
            if isinstance(items, list):
                snippets = search_text_list(items, [other])
                if snippets:
                    warnings.append(f"Interaction ({src_name} label): {'; '.join(snippets)}")

    status = "warning" if warnings else "clear"
    message = (
        f"{len(warnings)} issue(s) found for '{potential_prescription}'."
        if warnings
        else f"No major interactions or allergy issues identified for '{potential_prescription}'."
    )
    return json.dumps({"status": status, "message": message, "warnings": warnings})


@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str = "High") -> str:
    """
    Flag a clinical risk with given urgency.
    """
    # Simulated: only logs and returns a JSON confirmation.
    logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
    return json.dumps({
        "status": "flagged",
        "message": f"Risk '{risk_description}' flagged with {urgency} urgency.",
    })


# ── Include Tavily search tool ─────────────────────────────────────────────
search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")

all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]

# ── LLM & Tool Executor ───────────────────────────────────────────────────
llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
model_with_tools = llm.bind_tools(all_tools)
tool_executor = ToolExecutor(all_tools)


# ── State Definition ─────────────────────────────────────────────────────
class AgentState(TypedDict):
    # Append-only history: each node's {"messages": [...]} is concatenated onto
    # the existing list via operator.add instead of replacing it.
    messages: Annotated[List[Any], operator.add]
    patient_data: Optional[Dict[str, Any]]
    summary: Optional[str]
    interaction_warnings: Optional[List[str]]
    done: Optional[bool]
    iterations: Optional[int]


# Helper to propagate state fields between nodes
def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
    """Copy non-message bookkeeping keys from ``old`` into ``new`` when ``new`` lacks them.

    Mutates and returns ``new``.
    """
    for key in ["iterations", "done", "patient_data", "summary", "interaction_warnings"]:
        if key in old and key not in new:
            new[key] = old[key]
    return new


# ── Graph Nodes ─────────────────────────────────────────────────────────
def agent_node(state: AgentState) -> Dict[str, Any]:
    """Call the tool-bound LLM on the conversation and append its reply.

    The system prompt is prepended for the call only (it is not written back
    into ``messages``). LLM failures become an ``AIMessage`` with the error text.
    """
    state = normalize_messages(state)
    if state.get("done", False):
        # Empty update: returning ``state`` itself would re-append the whole
        # history through the operator.add reducer on ``messages``.
        return propagate_state({"messages": []}, state)

    msgs = state.get("messages", [])
    if not msgs or not isinstance(msgs[0], SystemMessage):
        msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
    logger.info(f"Invoking LLM with {len(msgs)} messages")
    try:
        response = wrap_message(model_with_tools.invoke(msgs))
    except Exception as e:
        logger.exception("Error in agent_node")
        response = AIMessage(content=f"Error: {e}")
    return propagate_state({"messages": [response]}, state)


def tool_node(state: AgentState) -> Dict[str, Any]:
    """Execute the tool calls requested by the last AI message.

    Safety gate: a ``prescribe_medication`` call is refused unless the same
    batch also contains ``check_drug_interactions`` for that medication name
    (case-insensitive). Interaction checks are auto-filled with the patient's
    current medications and allergies when the LLM did not supply them.

    Every requested tool_call_id receives exactly one ToolMessage (including
    blocked calls), as tool-calling chat APIs expect. Warnings returned by
    interaction checks are exposed as ``interaction_warnings`` (``None`` if none).
    """
    state = normalize_messages(state)
    if state.get("done", False):
        return propagate_state({"messages": []}, state)

    messages_list = state.get("messages", [])
    if not messages_list:
        logger.warning("tool_node invoked with no messages")
        return propagate_state({"messages": []}, state)

    last = messages_list[-1]
    calls = getattr(last, "tool_calls", None)
    if not (isinstance(last, AIMessage) and calls):
        logger.warning("tool_node invoked without pending tool_calls")
        return propagate_state({"messages": []}, state)

    checked_meds = {
        (c["args"].get("potential_prescription") or "").lower()
        for c in calls
        if c["name"] == "check_drug_interactions"
    }
    contents: Dict[str, str] = {}
    for call in calls:
        if call["name"] != "prescribe_medication":
            continue
        med = (call["args"].get("medication_name") or "").lower()
        if med not in checked_meds:
            logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
            contents[call["id"]] = json.dumps({
                "status": "blocked",
                "message": (
                    f"Prescription for '{med}' blocked: call check_drug_interactions "
                    f"for it in the same step before prescribing."
                ),
            })

    # Patient context used to auto-fill interaction checks; tolerate None sections.
    patient = state.get("patient_data") or {}
    meds = patient.get("medications") or {}
    current_meds = meds.get("current", []) if isinstance(meds, dict) else meds
    allergies = patient.get("allergies", [])

    to_execute = []
    for call in calls:
        if call["id"] in contents:
            continue
        args = dict(call["args"])  # copy: don't mutate the tool_calls stored in state
        if call["name"] == "check_drug_interactions":
            if args.get("current_medications") is None:
                args["current_medications"] = current_meds
            if args.get("allergies") is None:
                args["allergies"] = allergies
        to_execute.append((call, args))

    # ToolExecutor expects ToolInvocation objects (.tool / .tool_input), not raw dicts.
    invocations = [ToolInvocation(tool=call["name"], tool_input=args) for call, args in to_execute]
    try:
        responses = tool_executor.batch(invocations, return_exceptions=True) if invocations else []
    except Exception as e:
        logger.exception("Critical error in tool_node")
        responses = [e] * len(to_execute)

    warnings: List[str] = []
    for (call, _), resp in zip(to_execute, responses):
        if isinstance(resp, Exception):
            logger.error(f"Error executing tool {call['name']}", exc_info=resp)
            content = json.dumps({"status": "error", "message": str(resp)})
        else:
            content = str(resp)
            if call["name"] == "check_drug_interactions":
                try:
                    data = json.loads(content)
                except json.JSONDecodeError:
                    data = None
                if isinstance(data, dict) and data.get("status") == "warning":
                    warnings.extend(data.get("warnings", []))
        contents[call["id"]] = content

    # Emit results in the order the calls were requested.
    messages: List[ToolMessage] = [
        ToolMessage(content=contents[call["id"]], tool_call_id=call["id"], name=call["name"])
        for call in calls
        if call["id"] in contents
    ]
    new_state = {"messages": messages, "interaction_warnings": warnings or None}
    return propagate_state(new_state, state)


def reflection_node(state: AgentState) -> Dict[str, Any]:
    """Ask the base LLM (no tools) for a safety review of the tool-calling plan.

    Uses the most recent AI message that issued tool calls as the plan and the
    current ``interaction_warnings`` as review points; appends the LLM reply
    (or an error message) to the conversation.
    """
    state = normalize_messages(state)
    if state.get("done", False):
        return propagate_state({"messages": []}, state)

    warns = state.get("interaction_warnings")
    if not warns:
        logger.warning("reflection_node called without warnings")
        return propagate_state({"messages": []}, state)

    triggering = next(
        (
            msg
            for msg in reversed(state.get("messages", []))
            if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None)
        ),
        None,
    )
    if triggering is None:
        new_state = {"messages": [AIMessage(content="Internal Error: reflection context missing.")]}
        return propagate_state(new_state, state)

    prompt = (
        "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
        f"{triggering.content}\n\n"
        "Highlight any issues based on these warnings:\n" + "\n".join(f"- {w}" for w in warns)
    )
    try:
        resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
        new_state = {"messages": [wrap_message(resp)]}
    except Exception as e:
        logger.exception("Error during reflection")
        new_state = {"messages": [AIMessage(content=f"Error during reflection: {e}")]}
    return propagate_state(new_state, state)


# ── Routing Functions ────────────────────────────────────────────────────
def should_continue(state: AgentState) -> str:
    """Route after the agent node.

    Returns ``"continue_tools"`` when the last message is an AI message with
    tool calls, otherwise ``"end_conversation_turn"``.

    NOTE(review): this writes ``iterations``/``done``/``messages`` into
    ``state``, but LangGraph applies only node return values to the graph
    state, so these writes are presumably discarded and the 4-iteration cap
    never accumulates across calls — confirm; the counter likely belongs in
    ``agent_node``.
    """
    state = normalize_messages(state)
    # `or 0` also covers an explicit iterations=None coming from the caller.
    state["iterations"] = (state.get("iterations") or 0) + 1
    logger.info(f"Iteration count: {state['iterations']}")

    if state["iterations"] >= 4:
        state.setdefault("messages", []).append(AIMessage(content="Final output: consultation complete."))
        state["done"] = True
        return "end_conversation_turn"

    messages = state.get("messages")
    if not messages:
        state["done"] = True
        return "end_conversation_turn"

    last = messages[-1]
    if isinstance(last, AIMessage) and getattr(last, "tool_calls", None):
        return "continue_tools"

    # No tool calls (plain answer, "consultation complete", or a non-AI
    # message): terminate the turn instead of looping.
    state["done"] = True
    return "end_conversation_turn"


def after_tools_router(state: AgentState) -> str:
    """Send the turn to safety reflection if the tools produced interaction warnings."""
    if state.get("interaction_warnings"):
        return "reflection"
    return "end_conversation_turn"


# ── ClinicalAgent ─────────────────────────────────────────────────────────
class ClinicalAgent:
    """One consultation turn: agent -> [tools -> [reflection]] -> END."""

    def __init__(self):
        logger.info("Building ClinicalAgent workflow")
        wf = StateGraph(AgentState)
        wf.add_node("start", agent_node)
        wf.add_node("tools", tool_node)
        wf.add_node("reflection", reflection_node)
        wf.set_entry_point("start")
        wf.add_conditional_edges("start", should_continue, {
            "continue_tools": "tools",
            "end_conversation_turn": END,
        })
        wf.add_conditional_edges("tools", after_tools_router, {
            "reflection": "reflection",
            "end_conversation_turn": END,
        })
        # Explicit exit so the reflection node is not a dead end in the graph.
        wf.add_edge("reflection", END)
        self.graph_app = wf.compile()
        logger.info("ClinicalAgent ready")

    def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run one graph turn on ``state`` and return the resulting state.

        On any graph error, returns the input messages plus an error
        ``AIMessage``, with ``patient_data``/``summary`` carried over and
        ``interaction_warnings`` cleared.
        """
        try:
            result = self.graph_app.invoke(state, {"recursion_limit": 100})
            result.setdefault("summary", state.get("summary"))
            result.setdefault("interaction_warnings", None)
            return result
        except Exception as e:
            logger.exception("Error during graph invocation")
            return {
                "messages": state.get("messages", []) + [AIMessage(content=f"Error: {e}")],
                "patient_data": state.get("patient_data"),
                "summary": state.get("summary"),
                "interaction_warnings": None,
            }