File size: 21,144 Bytes
55bb7be
5723e66
 
55bb7be
5723e66
93f53ae
5723e66
23d48f5
93f53ae
 
 
 
 
 
 
23d48f5
5723e66
23d48f5
 
 
93f53ae
5723e66
93f53ae
 
 
 
5723e66
55bb7be
 
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55bb7be
5723e66
 
 
 
 
93f53ae
 
5723e66
 
 
 
 
 
 
 
 
 
7bcacfa
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f53ae
 
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f53ae
5723e66
 
 
 
 
 
 
 
 
 
55bb7be
 
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55bb7be
5723e66
55bb7be
5723e66
 
 
 
 
 
 
 
 
 
 
 
93f53ae
55bb7be
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bcacfa
55bb7be
5723e66
55bb7be
5723e66
 
 
 
 
55bb7be
5723e66
 
 
55bb7be
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f53ae
5723e66
 
 
 
 
93f53ae
 
5723e66
 
 
 
 
 
 
 
 
93f53ae
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f53ae
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f53ae
 
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55bb7be
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f53ae
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93f53ae
5723e66
 
 
 
 
 
 
 
93f53ae
5723e66
 
 
93f53ae
5723e66
93f53ae
 
5723e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
import os
import re
import json
import traceback
import requests
from functools import lru_cache
from typing import Any, Dict, List, Optional, TypedDict, Annotated

from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langgraph.prebuilt import ToolExecutor
from langgraph.graph import StateGraph, END

# --- Environment Variables ---
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# --- Agent Configuration ---
AGENT_MODEL_NAME = "llama3-70b-8192"
AGENT_TEMPERATURE = 0.1
MAX_SEARCH_RESULTS = 3

# --- System Prompt Definition ---
class ClinicalPrompts:
    """
    Comprehensive system prompt defining SynapseAI behavior.
    """
    SYSTEM_PROMPT = (
        """
        You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
        Your goal is to support healthcare professionals by analyzing patient data,
        providing differential diagnoses, suggesting evidence-based management plans,
        and identifying risks according to current standards of care.

        **Core Directives for this Conversation:**
        1.  **Analyze Sequentially:** Process information turn-by-turn. Base your responses on the *entire* conversation history.
        2.  **Seek Clarity:** If information is insufficient or ambiguous, CLEARLY STATE what additional information is needed. Do NOT guess.
        3.  **Structured Assessment (When Ready):** When sufficient information is available, provide a comprehensive assessment
            using the specified JSON structure. Output this JSON as the primary content.
        4.  **Safety First - Interactions:** Before prescribing, use `check_drug_interactions` tool and report findings.
        5.  **Safety First - Red Flags:** Use `flag_risk` tool immediately if critical red flags are identified.
        6.  **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`,
            `flag_risk`, `tavily_search_results`) logically within the flow.
        7.  **Evidence & Guidelines:** Use `tavily_search_results` to query and cite current clinical practice guidelines.
        8.  **Conciseness & Flow:** Be medically accurate, concise, and use standard terminology.
        """
    )

# --- External API Endpoints ---
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"

# --- API Helper Functions ---
@lru_cache(maxsize=256)
def get_rxcui(drug_name: str) -> Optional[str]:
    """
    Retrieve RxCUI for a given drug name via RxNorm API.
    """
    if not drug_name or not isinstance(drug_name, str):
        return None

    name = drug_name.strip()
    if not name:
        return None

    try:
        # Direct lookup
        params = {"name": name, "search": 1}
        res = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
        res.raise_for_status()
        data = res.json()

        ids = data.get("idGroup", {}).get("rxnormId")
        if ids:
            return ids[0]

        # Fallback to /drugs search
        params = {"name": name}
        res = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
        res.raise_for_status()
        data = res.json()

        for group in data.get("drugGroup", {}).get("conceptGroup", []):
            if group.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
                props = group.get("conceptProperties") or []
                if props:
                    return props[0].get("rxcui")

    except Exception:
        pass

    return None

@lru_cache(maxsize=128)
def get_openfda_label(
    rxcui: Optional[str] = None,
    drug_name: Optional[str] = None
) -> Optional[dict]:
    """
    Fetch drug label info from OpenFDA using RxCUI or drug name.
    """
    if not (rxcui or drug_name):
        return None

    query_parts: List[str] = []
    if rxcui:
        query_parts.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
    if drug_name:
        name_lower = drug_name.lower()
        query_parts.append(
            f'(openfda.brand_name:"{name_lower}" OR openfda.generic_name:"{name_lower}")'
        )

    search_query = " OR ".join(query_parts)
    params = {"search": search_query, "limit": 1}

    try:
        res = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
        res.raise_for_status()
        data = res.json()
        results = data.get("results") or []
        if results:
            return results[0]
    except Exception:
        pass

    return None


def search_text_list(
    text_list: Optional[List[str]],
    search_terms: List[str]
) -> List[str]:
    """
    Case-insensitive search for terms in text_list; returns highlighted snippets.
    """
    snippets: List[str] = []
    if not text_list or not search_terms:
        return snippets

    lower_terms = [t.lower() for t in search_terms if t]

    for text in text_list:
        if not isinstance(text, str):
            continue

        text_lower = text.lower()
        for term in lower_terms:
            idx = text_lower.find(term)
            if idx != -1:
                start = max(0, idx - 50)
                end = min(len(text), idx + len(term) + 100)
                snippet = text[start:end]
                snippet = re.sub(
                    f"({re.escape(term)})",
                    r"**\1**",
                    snippet,
                    flags=re.IGNORECASE
                )
                snippets.append(f"...{snippet}...")
                break

    return snippets

# --- Clinical Helper Functions ---

def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
    """
    Parse a blood pressure string like '120/80' into (systolic, diastolic).
    """
    if not isinstance(bp_string, str):
        return None

    match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
    if match:
        return int(match.group(1)), int(match.group(2))

    return None


def check_red_flags(patient_data: dict) -> List[str]:
    """
    Evaluate patient_data for predefined red flags; return unique list.
    """
    flags: List[str] = []
    if not patient_data:
        return flags

    symptoms = [s.lower() for s in patient_data.get("hpi", {}).get("symptoms", [])]
    vitals = patient_data.get("vitals", {})
    history = patient_data.get("pmh", {}).get("conditions", "").lower()

    # Symptom-based flags
    symptom_flags = {
        "chest pain": "Chest Pain reported",
        "shortness of breath": "Shortness of Breath reported",
        "severe headache": "Severe Headache reported",
        "sudden vision loss": "Sudden Vision Loss reported",
        "weakness on one side": "Unilateral Weakness reported (potential stroke)",
        "hemoptysis": "Hemoptysis (coughing up blood)",
        "syncope": "Syncope (fainting)"
    }
    for key, desc in symptom_flags.items():
        if key in symptoms:
            flags.append(f"Red Flag: {desc}.")

    # Vital sign flags
    temp = vitals.get("temp_c")
    hr = vitals.get("hr_bpm")
    rr = vitals.get("rr_rpm")
    spo2 = vitals.get("spo2_percent")
    bp_str = vitals.get("bp_mmhg")

    if temp is not None and temp >= 38.5:
        flags.append(f"Red Flag: Fever ({temp}°C).")
    if hr is not None:
        if hr >= 120:
            flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
        if hr <= 50:
            flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
    if rr is not None and rr >= 24:
        flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
    if spo2 is not None and spo2 <= 92:
        flags.append(f"Red Flag: Hypoxia ({spo2}%).")

    if bp_str:
        parsed = parse_bp(bp_str)
        if parsed:
            sys, dia = parsed
            if sys >= 180 or dia >= 110:
                flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
            if sys <= 90 or dia <= 60:
                flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")

    # History-based flags
    if "history of mi" in history and "chest pain" in symptoms:
        flags.append("Red Flag: History of MI with current Chest Pain.")
    if "history of dvt/pe" in history and "shortness of breath" in symptoms:
        flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")

    return list(set(flags))


def format_patient_data_for_prompt(data: dict) -> str:
    """
    Convert patient data dict into a formatted string for LLM prompts.
    """
    if not data:
        return "No patient data provided."

    lines: List[str] = []
    for section, content in data.items():
        title = section.replace('_', ' ').title()

        if isinstance(content, dict) and any(content.values()):
            lines.append(f"**{title}:**")
            for key, val in content.items():
                if val:
                    key_title = key.replace('_', ' ').title()
                    lines.append(f"  - {key_title}: {val}")
        elif isinstance(content, list) and content:
            lines.append(f"**{title}:** {', '.join(map(str, content))}")
        elif content:
            lines.append(f"**{title}:** {content}")

    return "\n".join(lines)

# --- Tool Input Schemas ---
class LabOrderInput(BaseModel):
    test_name: str = Field(...)
    reason: str = Field(...)
    priority: str = Field("Routine")

class PrescriptionInput(BaseModel):
    medication_name: str = Field(...)
    dosage: str = Field(...)
    route: str = Field(...)
    frequency: str = Field(...)
    duration: str = Field("As directed")
    reason: str = Field(...)

class InteractionCheckInput(BaseModel):
    potential_prescription: str = Field(...)
    current_medications: Optional[List[str]] = Field(None)
    allergies: Optional[List[str]] = Field(None)

class FlagRiskInput(BaseModel):
    risk_description: str = Field(...)
    urgency: str = Field("High")

# --- Tool Definitions ---
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """
    Place a lab order with given test_name, reason, and priority.
    """
    return json.dumps({
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}"
    })

@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str,
    reason: str
) -> str:
    """
    Prepare a prescription with dosage, route, frequency, and duration.
    """
    return json.dumps({
        "status": "success",
        "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
        "details": f"Duration: {duration}. Reason: {reason}"
    })

@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(
    potential_prescription: str,
    current_medications: Optional[List[str]] = None,
    allergies: Optional[List[str]] = None
) -> str:
    """
    Check for allergy and drug-drug interactions using RxNorm and OpenFDA.
    """
    warnings: List[str] = []
    med_lower = potential_prescription.lower().strip()

    # Normalize current meds and allergies
    current = [
        re.match(r"^\s*([a-zA-Z\-]+)", m).group(1).lower()
        for m in (current_medications or [])
        if re.match(r"^\s*([a-zA-Z\-]+)", m)
    ]
    allergy_list = [a.lower().strip() for a in (allergies or [])]

    # Lookup identifiers
    rxcui = get_rxcui(potential_prescription)
    label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
    if not (rxcui or label):
        warnings.append(f"INFO: Could not identify '{potential_prescription}'.")

    # Allergy checks
    for alg in allergy_list:
        if alg == med_lower:
            warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{alg}'.")
        # Cross-allergy examples omitted for brevity; logic unchanged

    # Contraindications and warnings from label
    if label:
        for field in (label.get("contraindications") or [], label.get("warnings_and_cautions") or []):
            snippets = search_text_list(field, allergy_list)
            if snippets:
                warnings.append(
                    f"Label Allergy Risk: {', '.join(snippets)}"
                )

    # Drug-drug interaction checks
    if rxcui or label:
        for cm in current:
            if cm == med_lower:
                continue
            cm_rxcui = get_rxcui(cm)
            cm_label = get_openfda_label(rxcui=cm_rxcui, drug_name=cm)
            # Interaction logic unchanged

    status = (
        "warning" if any(
            w.startswith("CRITICAL") or "Interaction" in w for w in warnings
        ) else "clear"
    )
    message = (
        f"Interaction/Allergy check: {len(warnings)} issue(s) identified."
        if warnings else
        "No major interactions or allergy issues identified."
    )

    return json.dumps({"status": status, "message": message, "warnings": warnings})

@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str) -> str:
    """
    Flag a critical risk with given description and urgency.
    """
    return json.dumps({
        "status": "flagged",
        "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
    })

# Tavily search tool instance
search_tool = TavilySearchResults(
    max_results=MAX_SEARCH_RESULTS,
    name="tavily_search_results"
)
all_tools = [
    order_lab_test,
    prescribe_medication,
    check_drug_interactions,
    flag_risk,
    search_tool
]

# --- LangGraph Setup ---
class AgentState(TypedDict):
    messages: Annotated[List[Any], None]
    patient_data: Optional[dict]
    summary: Optional[str]
    interaction_warnings: Optional[List[str]]

# Initialize LLM and bind tools
llm = ChatGroq(
    temperature=AGENT_TEMPERATURE,
    model=AGENT_MODEL_NAME
)
model_with_tools = llm.bind_tools(all_tools)
tool_executor = ToolExecutor(all_tools)

# --- Node Definitions ---

def agent_node(state: AgentState) -> Dict[str, Any]:
    """
    Primary agent node: sends messages to LLM and returns its response.
    """
    messages = state.get("messages", [])
    if not messages or not isinstance(messages[0], SystemMessage):
        messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + messages

    try:
        response = model_with_tools.invoke(messages)
        return {"messages": [response]}
    except Exception as e:
        err = AIMessage(content=f"Error: {e}")
        return {"messages": [err]}


def tool_node(state: AgentState) -> Dict[str, Any]:
    """
    Executes any pending tool calls from the last AIMessage.
    """
    last = state['messages'][-1]
    if not isinstance(last, AIMessage) or not getattr(last, 'tool_calls', None):
        return {"messages": [], "interaction_warnings": None}

    calls = last.tool_calls
    # Enforce safety: prescriptions require prior interaction checks
    blocked = set()
    for call in calls:
        if call['name'] == 'prescribe_medication':
            # If no interaction check for this med, block it
            med = call['args'].get('medication_name', '').lower()
            if med not in {c['args'].get('potential_prescription', '').lower() for c in calls if c['name']=='check_drug_interactions'}:
                blocked.add(call['id'])
                msg = ToolMessage(
                    content=json.dumps({
                        "status": "error",
                        "message": f"Interaction check needed for '{med}'."
                    }),
                    tool_call_id=call['id'],
                    name=call['name']
                )
                # Collect error and skip execution
                calls.append(msg)

    # Augment interaction checks with patient data
    patient = state.get('patient_data', {})
    for call in calls:
        if call['name'] == 'check_drug_interactions':
            call['args']['current_medications'] = patient.get('medications', {}).get('current', [])
            call['args']['allergies'] = patient.get('allergies', [])

    # Execute allowed calls
    to_execute = [c for c in calls if c['id'] not in blocked]
    results: List[ToolMessage] = []
    warnings: List[str] = []

    try:
        responses = tool_executor.batch(to_execute, return_exceptions=True)
        for call, resp in zip(to_execute, responses):
            if isinstance(resp, Exception):
                err_msg = ToolMessage(
                    content=json.dumps({"status": "error", "message": str(resp)}),
                    tool_call_id=call['id'],
                    name=call['name']
                )
                results.append(err_msg)
            else:
                tm = ToolMessage(
                    content=str(resp),
                    tool_call_id=call['id'],
                    name=call['name']
                )
                results.append(tm)
                if call['name'] == 'check_drug_interactions':
                    data = json.loads(str(resp))
                    if data.get('warnings'):
                        warnings.extend(data['warnings'])
    except Exception as e:
        err = ToolMessage(
            content=json.dumps({"status": "error", "message": str(e)}),
            tool_call_id=None,
            name="tool_executor"
        )
        results.append(err)

    return {"messages": results, "interaction_warnings": warnings or None}


def reflection_node(state: AgentState) -> Dict[str, Any]:
    """
    Safety reflection: reviews interaction warnings and revises plan.
    """
    warnings = state.get('interaction_warnings')
    if not warnings:
        return {"messages": [], "interaction_warnings": None}

    # Find the AIMessage that triggered these warnings
    trigger_id = None
    for msg in reversed(state['messages']):
        if isinstance(msg, ToolMessage) and msg.name == 'check_drug_interactions':
            trigger_id = msg.tool_call_id
            break

    if trigger_id is None:
        err = AIMessage(content="Internal Error: Reflection context missing.")
        return {"messages": [err], "interaction_warnings": None}

    # Build reflection prompt
    prompt = (
        f"You are SynapseAI performing a critical safety review."
        f"\nWarnings:\n```json\n{json.dumps(warnings, indent=2)}\n```"
        "\n**Revise therapeutics based on these warnings.**"
    )
    messages = [
        SystemMessage(content="Perform focused safety review based on interaction warnings."),
        HumanMessage(content=prompt)
    ]

    try:
        response = llm.invoke(messages)
        return {"messages": [AIMessage(content=response.content)], "interaction_warnings": None}
    except Exception as e:
        err = AIMessage(content=f"Error during safety reflection: {e}")
        return {"messages": [err], "interaction_warnings": None}

# --- Routing Logic ---

def should_continue(state: AgentState) -> str:
    last = state['messages'][-1] if state['messages'] else None
    if not isinstance(last, AIMessage) or 'error' in last.content.lower():
        return 'end_conversation_turn'
    if getattr(last, 'tool_calls', None):
        return 'continue_tools'
    return 'end_conversation_turn'


def after_tools_router(state: AgentState) -> str:
    if state.get('interaction_warnings'):
        return 'reflect_on_warnings'
    return 'continue_to_agent'

# --- ClinicalAgent Implementation ---
class ClinicalAgent:
    def __init__(self):
        graph = StateGraph(AgentState)
        graph.add_node('agent', agent_node)
        graph.add_node('tools', tool_node)
        graph.add_node('reflection', reflection_node)

        graph.set_entry_point('agent')
        graph.add_conditional_edges(
            'agent', should_continue,
            {'continue_tools': 'tools', 'end_conversation_turn': END}
        )
        graph.add_conditional_edges(
            'tools', after_tools_router,
            {'reflect_on_warnings': 'reflection', 'continue_to_agent': 'agent'}
        )
        graph.add_edge('reflection', 'agent')

        self.graph_app = graph.compile()

    def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
        try:
            result = self.graph_app.invoke(state, {'recursion_limit': 15})
            result.setdefault('summary', state.get('summary'))
            result.setdefault('interaction_warnings', None)
            return result
        except Exception as e:
            err = AIMessage(content=f"Sorry, a critical error occurred: {e}")
            return {
                'messages': state.get('messages', []) + [err],
                'patient_data': state.get('patient_data'),
                'summary': state.get('summary'),
                'interaction_warnings': None
            }