"""SynapseAI clinical assistant.

Provides drug-lookup helpers (RxNorm, OpenFDA), rule-based clinical red-flag
checks, the LangChain tools exposed to the model, and a LangGraph workflow:
agent -> tools -> (safety reflection when interaction warnings exist) -> agent.
"""

import json
import operator
import os
import re
import traceback
from functools import lru_cache
from typing import Annotated, Any, Dict, List, Optional, TypedDict

import requests
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor

# --- Environment Variables ---
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# --- Agent Configuration ---
AGENT_MODEL_NAME = "llama3-70b-8192"
AGENT_TEMPERATURE = 0.1
MAX_SEARCH_RESULTS = 3


# --- System Prompt Definition ---
class ClinicalPrompts:
    """Comprehensive system prompt defining SynapseAI behavior."""

    SYSTEM_PROMPT = """
You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation. Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.

**Core Directives for this Conversation:**
1. **Analyze Sequentially:** Process information turn-by-turn. Base your responses on the *entire* conversation history.
2. **Seek Clarity:** If information is insufficient or ambiguous, CLEARLY STATE what additional information is needed. Do NOT guess.
3. **Structured Assessment (When Ready):** When sufficient information is available, provide a comprehensive assessment using the specified JSON structure. Output this JSON as the primary content.
4. **Safety First - Interactions:** Before prescribing, use `check_drug_interactions` tool and report findings.
5. **Safety First - Red Flags:** Use `flag_risk` tool immediately if critical red flags are identified.
6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the flow.
7. **Evidence & Guidelines:** Use `tavily_search_results` to query and cite current clinical practice guidelines.
8. **Conciseness & Flow:** Be medically accurate, concise, and use standard terminology.
"""


# --- External API Endpoints ---
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"

# RxNorm term types accepted from the /drugs.json fallback search.
_RXNORM_TTYS = ("SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN")

# Failures treated as "lookup unavailable": transport/HTTP errors and
# malformed JSON (json decode errors subclass ValueError).
_LOOKUP_ERRORS = (requests.RequestException, ValueError)


# --- API Helper Functions ---
@lru_cache(maxsize=256)
def _lookup_rxcui(name: str) -> Optional[str]:
    """Query RxNorm for ``name``.

    Network/HTTP errors propagate instead of returning None. lru_cache does not
    cache raised exceptions, so a transient outage is not remembered as
    "drug not found".
    """
    res = requests.get(
        f"{RXNORM_API_BASE}/rxcui.json",
        params={"name": name, "search": 1},
        timeout=10,
    )
    res.raise_for_status()
    ids = (res.json().get("idGroup") or {}).get("rxnormId")
    if ids:
        return ids[0]

    # Fallback: /drugs search, returning the first concept of an accepted TTY.
    res = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": name}, timeout=10)
    res.raise_for_status()
    groups = (res.json().get("drugGroup") or {}).get("conceptGroup") or []
    for group in groups:
        if group.get("tty") in _RXNORM_TTYS:
            props = group.get("conceptProperties") or []
            if props:
                return props[0].get("rxcui")
    return None


def get_rxcui(drug_name: str) -> Optional[str]:
    """
    Retrieve the RxCUI for a drug name via the RxNorm API.

    Args:
        drug_name: Free-text drug name; surrounding whitespace is ignored.

    Returns:
        The first matching RxCUI. Returns None when the name is empty, when no
        concept matches, or when the API is unreachable. Successful lookups,
        including "not found", are cached; lookup failures are not.
    """
    if not drug_name or not isinstance(drug_name, str):
        return None
    name = drug_name.strip()
    if not name:
        return None
    try:
        return _lookup_rxcui(name)
    except _LOOKUP_ERRORS:
        return None


@lru_cache(maxsize=128)
def _fetch_openfda_label(rxcui: Optional[str], drug_name: Optional[str]) -> Optional[dict]:
    """Query OpenFDA for one label; raises on transport errors so they are not cached."""
    query_parts: List[str] = []
    if rxcui:
        query_parts.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
    if drug_name:
        # Strip double quotes so user text cannot break out of the quoted phrase.
        name_lower = drug_name.lower().replace('"', "")
        query_parts.append(
            f'(openfda.brand_name:"{name_lower}" OR openfda.generic_name:"{name_lower}")'
        )
    params = {"search": " OR ".join(query_parts), "limit": 1}
    res = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
    # OpenFDA answers "no matches" with HTTP 404; that is a valid, cacheable miss.
    if res.status_code == 404:
        return None
    res.raise_for_status()
    results = res.json().get("results") or []
    return results[0] if results else None


def get_openfda_label(
    rxcui: Optional[str] = None, drug_name: Optional[str] = None
) -> Optional[dict]:
    """
    Fetch drug label info from OpenFDA using an RxCUI and/or a drug name.

    Returns:
        The first matching label record, or None when nothing matches or the
        API is unreachable. The dict is shared through a cache; treat it as
        read-only.
    """
    if not (rxcui or drug_name):
        return None
    try:
        return _fetch_openfda_label(rxcui, drug_name)
    except _LOOKUP_ERRORS:
        return None


def search_text_list(
    text_list: Optional[List[str]], search_terms: List[str]
) -> List[str]:
    """
    Case-insensitive search for terms in text_list; returns highlighted snippets.

    For each string in ``text_list``, only the first matching term yields a
    snippet. The snippet spans 50 chars before to 100 chars after the match,
    with the matched term wrapped in ``**``. Non-string entries and empty
    terms are skipped.
    """
    snippets: List[str] = []
    if not text_list or not search_terms:
        return snippets
    lower_terms = [t.lower() for t in search_terms if t]
    for text in text_list:
        if not isinstance(text, str):
            continue
        text_lower = text.lower()
        for term in lower_terms:
            idx = text_lower.find(term)
            if idx != -1:
                start = max(0, idx - 50)
                end = min(len(text), idx + len(term) + 100)
                snippet = re.sub(
                    f"({re.escape(term)})", r"**\1**", text[start:end], flags=re.IGNORECASE
                )
                snippets.append(f"...{snippet}...")
                break
    return snippets


# --- Clinical Helper Functions ---
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
    """
    Parse a blood pressure string like '120/80' (or 'BP: 120/80 mmHg') into
    (systolic, diastolic). Returns None when no 'N/N' pattern is present.
    """
    if not isinstance(bp_string, str):
        return None
    match = re.search(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
    if match:
        return int(match.group(1)), int(match.group(2))
    return None


def _as_number(value: Any) -> Optional[float]:
    """Return ``value`` as a float, or None when missing or non-numeric."""
    if value is None or isinstance(value, bool):
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def check_red_flags(patient_data: dict) -> List[str]:
    """
    Evaluate patient_data for predefined red flags.

    Reads ``hpi.symptoms`` (list), ``vitals`` (temp_c, hr_bpm, rr_rpm,
    spo2_percent, bp_mmhg) and ``pmh.conditions`` (string or list).

    Returns:
        A de-duplicated list of flag strings, in the order detected.
    """
    flags: List[str] = []
    if not patient_data:
        return flags

    hpi = patient_data.get("hpi") or {}
    symptoms = [str(s).lower() for s in (hpi.get("symptoms") or [])]
    vitals = patient_data.get("vitals") or {}
    conditions = (patient_data.get("pmh") or {}).get("conditions") or ""
    if isinstance(conditions, (list, tuple)):
        conditions = ", ".join(map(str, conditions))
    history = str(conditions).lower()

    def has_symptom(key: str) -> bool:
        # Substring match so "chest pain radiating to arm" still flags "chest pain";
        # over-flagging is preferred to missing a red flag.
        return any(key in s for s in symptoms)

    # Symptom-based flags
    symptom_flags = {
        "chest pain": "Chest Pain reported",
        "shortness of breath": "Shortness of Breath reported",
        "severe headache": "Severe Headache reported",
        "sudden vision loss": "Sudden Vision Loss reported",
        "weakness on one side": "Unilateral Weakness reported (potential stroke)",
        "hemoptysis": "Hemoptysis (coughing up blood)",
        "syncope": "Syncope (fainting)",
    }
    for key, desc in symptom_flags.items():
        if has_symptom(key):
            flags.append(f"Red Flag: {desc}.")

    # Vital sign flags: compare numerically, but display the value as supplied.
    temp_raw = vitals.get("temp_c")
    hr_raw = vitals.get("hr_bpm")
    rr_raw = vitals.get("rr_rpm")
    spo2_raw = vitals.get("spo2_percent")
    temp, hr, rr, spo2 = (_as_number(v) for v in (temp_raw, hr_raw, rr_raw, spo2_raw))

    if temp is not None and temp >= 38.5:
        flags.append(f"Red Flag: Fever ({temp_raw}°C).")
    if hr is not None:
        if hr >= 120:
            flags.append(f"Red Flag: Tachycardia ({hr_raw} bpm).")
        if hr <= 50:
            flags.append(f"Red Flag: Bradycardia ({hr_raw} bpm).")
    if rr is not None and rr >= 24:
        flags.append(f"Red Flag: Tachypnea ({rr_raw} rpm).")
    if spo2 is not None and spo2 <= 92:
        flags.append(f"Red Flag: Hypoxia ({spo2_raw}%).")

    bp_str = vitals.get("bp_mmhg")
    if bp_str:
        parsed = parse_bp(bp_str)
        if parsed:
            systolic, diastolic = parsed
            if systolic >= 180 or diastolic >= 110:
                flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
            if systolic <= 90 or diastolic <= 60:
                flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")

    # History-based flags
    if "history of mi" in history and has_symptom("chest pain"):
        flags.append("Red Flag: History of MI with current Chest Pain.")
    if "history of dvt/pe" in history and has_symptom("shortness of breath"):
        flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")

    # dict.fromkeys de-duplicates while keeping a deterministic order.
    return list(dict.fromkeys(flags))


def format_patient_data_for_prompt(data: dict) -> str:
    """
    Convert a patient data dict into a markdown-ish string for LLM prompts.

    Dict sections become a bold title plus ' - Key: value' lines (falsy values
    skipped). List sections become comma-joined values; other truthy values
    are printed inline.
    """
    if not data:
        return "No patient data provided."
    lines: List[str] = []
    for section, content in data.items():
        title = section.replace("_", " ").title()
        if isinstance(content, dict) and any(content.values()):
            lines.append(f"**{title}:**")
            for key, val in content.items():
                if val:
                    key_title = key.replace("_", " ").title()
                    lines.append(f" - {key_title}: {val}")
        elif isinstance(content, list) and content:
            lines.append(f"**{title}:** {', '.join(map(str, content))}")
        elif content:
            lines.append(f"**{title}:** {content}")
    return "\n".join(lines)


# --- Tool Input Schemas ---
class LabOrderInput(BaseModel):
    test_name: str = Field(...)
    reason: str = Field(...)
    priority: str = Field("Routine")


class PrescriptionInput(BaseModel):
    medication_name: str = Field(...)
    dosage: str = Field(...)
    route: str = Field(...)
    frequency: str = Field(...)
    duration: str = Field("As directed")
    reason: str = Field(...)


class InteractionCheckInput(BaseModel):
    potential_prescription: str = Field(...)
    current_medications: Optional[List[str]] = Field(None)
    allergies: Optional[List[str]] = Field(None)


class FlagRiskInput(BaseModel):
    risk_description: str = Field(...)
    urgency: str = Field("High")


# Leading alphabetic token of a medication string, e.g. "Lisinopril 10mg" -> "Lisinopril".
_DRUG_TOKEN_RE = re.compile(r"^\s*([a-zA-Z\-]+)")


def _drug_key(name: Any) -> Optional[str]:
    """Return the lower-cased leading drug token of ``name``, or None if absent."""
    if not isinstance(name, str):
        return None
    match = _DRUG_TOKEN_RE.match(name)
    return match.group(1).lower() if match else None


# --- Tool Definitions ---
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """
    Place a lab order with given test_name, reason, and priority.
    """
    return json.dumps({
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}",
    })


@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str,
    reason: str,
) -> str:
    """
    Prepare a prescription with dosage, route, frequency, and duration.
    """
    return json.dumps({
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}",
    })


@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(
    potential_prescription: str,
    current_medications: Optional[List[str]] = None,
    allergies: Optional[List[str]] = None,
) -> str:
    """
    Check for allergy and drug-drug interactions using RxNorm and OpenFDA.

    Returns a JSON string with ``status`` ("warning" when any non-INFO issue
    is found, otherwise "clear"), ``message`` and the ``warnings`` list.
    """
    warnings: List[str] = []
    med_lower = potential_prescription.lower().strip()
    med_key = _drug_key(potential_prescription) or med_lower

    # Normalize current meds to their leading drug token; drop unparseable entries.
    current = [key for key in (_drug_key(m) for m in (current_medications or [])) if key]
    allergy_list = [
        a.lower().strip() for a in (allergies or []) if isinstance(a, str) and a.strip()
    ]

    # Lookup identifiers
    rxcui = get_rxcui(potential_prescription)
    label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
    if not (rxcui or label):
        warnings.append(f"INFO: Could not identify '{potential_prescription}'.")

    # Allergy checks: whole-word match so "amoxicillin 500mg" still hits an
    # "amoxicillin" allergy. Cross-allergy (drug-class) logic is not implemented.
    for alg in allergy_list:
        if alg in (med_lower, med_key) or re.search(rf"\b{re.escape(alg)}\b", med_lower):
            warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{alg}'.")

    # Allergy terms mentioned in label contraindications/warnings
    # ("warnings" is the field used by older labels).
    if label:
        for field_name in ("contraindications", "warnings_and_cautions", "warnings"):
            snippets = search_text_list(label.get(field_name), allergy_list)
            if snippets:
                warnings.append(f"Label Allergy Risk: {', '.join(snippets)}")

    # Drug-drug interactions: look for each current med in this drug's label
    # "drug_interactions" section, then for this drug in the other med's label.
    if rxcui or label:
        own_interactions = label.get("drug_interactions") if label else None
        for cm in current:
            if cm in (med_lower, med_key):
                continue
            snippets = search_text_list(own_interactions, [cm])
            if not snippets:
                cm_label = get_openfda_label(rxcui=get_rxcui(cm), drug_name=cm)
                if cm_label:
                    snippets = search_text_list(cm_label.get("drug_interactions"), [med_key])
            if snippets:
                warnings.append(f"Potential Interaction with '{cm}': {' | '.join(snippets)}")

    # Any non-informational finding (allergy, label risk, interaction) is a warning.
    status = "warning" if any(not w.startswith("INFO") for w in warnings) else "clear"
    message = (
        f"Interaction/Allergy check: {len(warnings)} issue(s) identified."
        if warnings
        else "No major interactions or allergy issues identified."
    )
    return json.dumps({"status": status, "message": message, "warnings": warnings})


@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str = "High") -> str:
    """
    Flag a critical risk with given description and urgency.
    """
    return json.dumps({
        "status": "flagged",
        "message": f"Risk '{risk_description}' flagged with {urgency} urgency.",
    })


# Tavily search tool instance
search_tool = TavilySearchResults(
    max_results=MAX_SEARCH_RESULTS,
    name="tavily_search_results",
)

all_tools = [
    order_lab_test,
    prescribe_medication,
    check_drug_interactions,
    flag_risk,
    search_tool,
]

# Name -> tool lookup used by tool_node to dispatch model tool calls.
TOOLS_BY_NAME: Dict[str, Any] = {t.name: t for t in all_tools}


# --- LangGraph Setup ---
class AgentState(TypedDict):
    # operator.add is the channel reducer: node updates are appended to the
    # history. Without a callable reducer LangGraph would replace the list.
    messages: Annotated[List[Any], operator.add]
    patient_data: Optional[dict]
    summary: Optional[str]
    interaction_warnings: Optional[List[str]]


# Initialize LLM and bind tools
llm = ChatGroq(
    temperature=AGENT_TEMPERATURE,
    model=AGENT_MODEL_NAME,
)
model_with_tools = llm.bind_tools(all_tools)
# Retained for backward compatibility; tool_node invokes tools directly via TOOLS_BY_NAME.
tool_executor = ToolExecutor(all_tools)


# --- Node Definitions ---
def agent_node(state: AgentState) -> Dict[str, Any]:
    """
    Primary agent node: sends the history (prefixed with the system prompt if
    missing) to the tool-bound LLM and returns its response, or an error
    AIMessage if the call fails.
    """
    messages = list(state.get("messages") or [])
    if not messages or not isinstance(messages[0], SystemMessage):
        messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + messages
    try:
        response = model_with_tools.invoke(messages)
        return {"messages": [response]}
    except Exception as e:
        err = AIMessage(content=f"Error: {e}")
        return {"messages": [err]}


def _normalize_med_name(value: Any) -> str:
    """Lower-case and collapse whitespace in a medication name ('' for non-strings)."""
    return " ".join(value.lower().split()) if isinstance(value, str) else ""


def _checked_prescriptions(messages: List[Any]) -> set:
    """Collect normalized drug names requested via check_drug_interactions in any AIMessage."""
    checked = set()
    for msg in messages:
        if not isinstance(msg, AIMessage):
            continue
        for call in getattr(msg, "tool_calls", None) or []:
            if call.get("name") == "check_drug_interactions":
                name = _normalize_med_name((call.get("args") or {}).get("potential_prescription"))
                if name:
                    checked.add(name)
    return checked


def _is_interaction_checked(med: str, checked: set) -> bool:
    """True if ``med`` was checked, allowing the check text to add a dose ('amoxicillin 500mg')."""
    return bool(med) and any(c == med or c.startswith(med + " ") for c in checked)


def _patient_interaction_context(patient: Any) -> tuple:
    """Return (current_medications, allergies) lists from patient data, or None where absent."""
    if not isinstance(patient, dict):
        return None, None
    meds = patient.get("medications")
    current = meds.get("current") if isinstance(meds, dict) else meds
    allergies = patient.get("allergies")
    return (
        list(current) if isinstance(current, (list, tuple)) else None,
        list(allergies) if isinstance(allergies, (list, tuple)) else None,
    )


def _tool_error(call_id: str, name: Any, message: str) -> ToolMessage:
    """Build the JSON error ToolMessage reported back to the model for one call."""
    return ToolMessage(
        content=json.dumps({"status": "error", "message": message}),
        tool_call_id=call_id,
        name=name,
    )


def _tool_output_to_text(output: Any) -> str:
    """Serialize tool output: strings pass through, other values become JSON (repr as last resort)."""
    if isinstance(output, str):
        return output
    try:
        return json.dumps(output)
    except (TypeError, ValueError):
        return str(output)


def _extract_warnings(content: str) -> List[str]:
    """Pull the 'warnings' list out of a check_drug_interactions JSON result."""
    try:
        data = json.loads(content)
    except ValueError:
        return []
    found = data.get("warnings") if isinstance(data, dict) else None
    return [str(w) for w in found] if isinstance(found, list) else []


def tool_node(state: AgentState) -> Dict[str, Any]:
    """
    Execute the pending tool calls from the last AIMessage.

    * A prescribe_medication call is blocked with an error ToolMessage unless
      check_drug_interactions was requested for that drug, either earlier in
      the conversation or in the same batch.
    * check_drug_interactions calls are filled with the patient's current
      medications and allergies when patient_data provides them.
    * Every call yields exactly one ToolMessage. Tool failures are reported
      to the model as JSON errors and are not raised.

    Returns the new ToolMessages plus any interaction warnings (or None).
    """
    messages = state.get("messages") or []
    last = messages[-1] if messages else None
    if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
        return {"messages": [], "interaction_warnings": None}

    checked = _checked_prescriptions(messages)
    patient_meds, patient_allergies = _patient_interaction_context(state.get("patient_data"))

    results: List[ToolMessage] = []
    warnings: List[str] = []
    for call in last.tool_calls:
        name = call.get("name")
        call_id = call.get("id") or ""
        # Copy args so the AIMessage stored in state is never mutated.
        args = dict(call.get("args") or {})

        if name == "prescribe_medication":
            med = _normalize_med_name(args.get("medication_name"))
            if not _is_interaction_checked(med, checked):
                results.append(_tool_error(call_id, name, f"Interaction check needed for '{med}'."))
                continue

        if name == "check_drug_interactions":
            if patient_meds:
                args["current_medications"] = patient_meds
            if patient_allergies:
                args["allergies"] = patient_allergies

        tool_obj = TOOLS_BY_NAME.get(name)
        if tool_obj is None:
            results.append(_tool_error(call_id, name, f"Unknown tool '{name}'."))
            continue
        try:
            output = tool_obj.invoke(args)
        except Exception as exc:  # boundary: surface any tool failure to the model
            results.append(_tool_error(call_id, name, str(exc)))
            continue

        content = _tool_output_to_text(output)
        results.append(ToolMessage(content=content, tool_call_id=call_id, name=name))
        if name == "check_drug_interactions":
            warnings.extend(_extract_warnings(content))

    return {"messages": results, "interaction_warnings": warnings or None}


def reflection_node(state: AgentState) -> Dict[str, Any]:
    """
    Safety reflection: asks the LLM to revise therapeutics given the
    interaction warnings, the patient data and the plan that triggered the
    check. Always clears interaction_warnings.
    """
    warnings = state.get("interaction_warnings")
    if not warnings:
        return {"messages": [], "interaction_warnings": None}

    messages = state.get("messages") or []

    # Find the tool call that produced these warnings.
    trigger_id = None
    for msg in reversed(messages):
        if isinstance(msg, ToolMessage) and msg.name == "check_drug_interactions":
            trigger_id = msg.tool_call_id
            break
    if trigger_id is None:
        err = AIMessage(content="Internal Error: Reflection context missing.")
        return {"messages": [err], "interaction_warnings": None}

    # Recover the text of the AIMessage that requested the check (the plan under review).
    plan_text = ""
    for msg in reversed(messages):
        calls = getattr(msg, "tool_calls", None) or [] if isinstance(msg, AIMessage) else []
        if any(c.get("id") == trigger_id for c in calls):
            plan_text = msg.content if isinstance(msg.content, str) else ""
            break

    prompt = (
        "You are SynapseAI performing a critical safety review."
        f"\nWarnings:\n```json\n{json.dumps(warnings, indent=2)}\n```"
    )
    patient_data = state.get("patient_data")
    if patient_data:
        prompt += f"\n\nPatient data:\n{format_patient_data_for_prompt(patient_data)}"
    if plan_text.strip():
        prompt += f"\n\nProposed plan under review:\n{plan_text}"
    prompt += "\n**Revise therapeutics based on these warnings.**"

    review_messages = [
        SystemMessage(content="Perform focused safety review based on interaction warnings."),
        HumanMessage(content=prompt),
    ]
    try:
        response = llm.invoke(review_messages)
        return {"messages": [AIMessage(content=response.content)], "interaction_warnings": None}
    except Exception as e:
        err = AIMessage(content=f"Error during safety reflection: {e}")
        return {"messages": [err], "interaction_warnings": None}


# --- Routing Logic ---
def should_continue(state: AgentState) -> str:
    """Route to tools when the last AIMessage requests tool calls; otherwise end the turn."""
    messages = state.get("messages") or []
    last = messages[-1] if messages else None
    # Error AIMessages from agent_node carry no tool calls, so they end the turn here
    # without fragile text matching on the word "error".
    if isinstance(last, AIMessage) and getattr(last, "tool_calls", None):
        return "continue_tools"
    return "end_conversation_turn"


def after_tools_router(state: AgentState) -> str:
    """Send interaction warnings to safety reflection; otherwise return to the agent."""
    if state.get("interaction_warnings"):
        return "reflect_on_warnings"
    return "continue_to_agent"


# --- ClinicalAgent Implementation ---
class ClinicalAgent:
    """Compiled LangGraph workflow: agent -> tools -> (reflection) -> agent."""

    def __init__(self):
        graph = StateGraph(AgentState)
        graph.add_node("agent", agent_node)
        graph.add_node("tools", tool_node)
        graph.add_node("reflection", reflection_node)
        graph.set_entry_point("agent")
        graph.add_conditional_edges(
            "agent",
            should_continue,
            {"continue_tools": "tools", "end_conversation_turn": END},
        )
        graph.add_conditional_edges(
            "tools",
            after_tools_router,
            {"reflect_on_warnings": "reflection", "continue_to_agent": "agent"},
        )
        graph.add_edge("reflection", "agent")
        self.graph_app = graph.compile()

    def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run one conversational turn (recursion limit 15).

        On failure, returns the input state with an apology AIMessage appended.
        """
        try:
            result = self.graph_app.invoke(state, {"recursion_limit": 15})
            result.setdefault("summary", state.get("summary"))
            result.setdefault("interaction_warnings", None)
            return result
        except Exception as e:
            err = AIMessage(content=f"Sorry, a critical error occurred: {e}")
            return {
                "messages": state.get("messages", []) + [err],
                "patient_data": state.get("patient_data"),
                "summary": state.get("summary"),
                "interaction_warnings": None,
            }