File size: 20,257 Bytes
86911ce
 
f1cd1c6
86911ce
 
 
55bb7be
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a90f7d4
 
86911ce
 
 
 
 
 
 
a90f7d4
 
86911ce
 
 
 
 
 
 
 
 
a90f7d4
 
 
86911ce
 
 
a90f7d4
86911ce
 
 
 
 
a90f7d4
86911ce
 
 
a90f7d4
 
86911ce
 
a90f7d4
86911ce
 
a90f7d4
 
 
86911ce
 
 
 
 
 
 
 
a90f7d4
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a90f7d4
86911ce
 
 
 
 
 
 
 
a90f7d4
86911ce
 
 
 
 
 
 
a90f7d4
86911ce
 
 
 
 
a90f7d4
86911ce
a90f7d4
86911ce
a90f7d4
 
 
86911ce
 
 
 
a90f7d4
 
86911ce
 
a90f7d4
86911ce
a90f7d4
 
86911ce
 
 
 
 
a90f7d4
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a90f7d4
 
86911ce
 
a90f7d4
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a90f7d4
 
86911ce
 
a90f7d4
 
 
 
 
 
86911ce
 
a90f7d4
86911ce
a90f7d4
86911ce
 
 
a90f7d4
86911ce
 
 
 
a90f7d4
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a90f7d4
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
a90f7d4
86911ce
 
 
 
 
a90f7d4
86911ce
 
a90f7d4
 
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a90f7d4
86911ce
 
 
a90f7d4
86911ce
 
 
 
 
 
 
 
 
a90f7d4
86911ce
 
 
 
 
 
a90f7d4
86911ce
a90f7d4
86911ce
 
a90f7d4
86911ce
a90f7d4
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a90f7d4
 
 
 
 
86911ce
 
 
 
 
 
 
 
 
 
 
a90f7d4
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a90f7d4
86911ce
 
 
 
 
 
 
 
 
 
 
a90f7d4
 
86911ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7e593a
86911ce
 
 
 
e7e593a
86911ce
 
 
 
 
 
e7e593a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
# agent.py

import json
import logging
import operator
import os
import re
import traceback
from functools import lru_cache
from typing import Annotated, Any, Dict, List, Optional, TypedDict

import requests
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation

# ── Logging Configuration ─────────────────────────────────────────────────────
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig configures the *root* logger at import time. It is a no-op if the
# host application already attached root handlers; otherwise it installs an INFO-level
# stderr handler for the whole process. Consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)

# ── Environment Variables ─────────────────────────────────────────────────────
UMLS_API_KEY   = os.getenv("UMLS_API_KEY")
GROQ_API_KEY   = os.getenv("GROQ_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Fail fast at import time. Name exactly which keys are absent (empty values count as
# missing) so a misconfigured deployment is easy to diagnose.
# NOTE(review): UMLS_API_KEY is not used anywhere else in this file; confirm another
# module needs it before keeping it mandatory.
_MISSING_API_KEYS = [
    name
    for name, value in (
        ("UMLS_API_KEY", UMLS_API_KEY),
        ("GROQ_API_KEY", GROQ_API_KEY),
        ("TAVILY_API_KEY", TAVILY_API_KEY),
    )
    if not value
]
if _MISSING_API_KEYS:
    logger.error(f"Missing required API keys: {', '.join(_MISSING_API_KEYS)}")
    raise RuntimeError(f"Missing required API keys: {', '.join(_MISSING_API_KEYS)}")

# ── Agent Configuration ───────────────────────────────────────────────────────
# Groq model id passed to ChatGroq when the shared `llm` is constructed.
AGENT_MODEL_NAME   = "llama3-70b-8192"
# Sampling temperature for that model; kept low for more repeatable answers.
AGENT_TEMPERATURE  = 0.1
# `max_results` for the Tavily web-search tool.
MAX_SEARCH_RESULTS = 3

class ClinicalPrompts:
    """Namespace for prompt text used by the agent graph."""

    # Prepended by agent_node whenever the conversation does not already start with a
    # SystemMessage.
    # NOTE(review): the text below is a placeholder ("[SYSTEM PROMPT CONTENT HERE]").
    # Confirm the real prompt is supplied before deployment.
    SYSTEM_PROMPT = """
    You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation...
    [SYSTEM PROMPT CONTENT HERE]
    """

# ── Helper Functions ──────────────────────────────────────────────────────────
# NOTE(review): UMLS_AUTH_ENDPOINT is not referenced anywhere else in this file.
UMLS_AUTH_ENDPOINT  = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
# RxNav REST base URL; the RxCUI lookup appends /rxcui.json and /drugs.json.
RXNORM_API_BASE     = "https://rxnav.nlm.nih.gov/REST"
# openFDA drug-label search endpoint used for label lookups.
OPENFDA_API_BASE    = "https://api.fda.gov/drug/label.json"

def _query_rxcui(drug_name: str) -> Optional[str]:
    """Query RxNav for the RxCUI of ``drug_name`` (uncached).

    Tries ``/rxcui.json`` first, then falls back to the first concept returned by
    ``/drugs.json``. Returns None when neither endpoint knows the name.

    Raises:
        requests.RequestException: network or HTTP failure.
        ValueError: the response body was not valid JSON.
    """
    resp = requests.get(
        f"{RXNORM_API_BASE}/rxcui.json",
        params={"name": drug_name, "search": 1},
        timeout=10,
    )
    resp.raise_for_status()
    ids = (resp.json().get("idGroup") or {}).get("rxnormId")
    if ids:
        logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
        return ids[0]

    # Fallback search: /drugs.json groups concepts by term type; take the first concept found.
    resp = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
    resp.raise_for_status()
    for grp in (resp.json().get("drugGroup") or {}).get("conceptGroup") or []:
        props = grp.get("conceptProperties")
        if props:
            logger.info(f"Found RxCUI {props[0]['rxcui']} via /drugs for '{drug_name}'")
            return props[0]["rxcui"]
    return None


@lru_cache(maxsize=256)
def _cached_rxcui(drug_name: str) -> Optional[str]:
    """Memoize definitive lookups (found or not found).

    Exceptions propagate and are therefore never cached.
    """
    logger.info(f"Looking up RxCUI for '{drug_name}'")
    return _query_rxcui(drug_name)


def get_rxcui(drug_name: str) -> Optional[str]:
    """Lookup RxNorm CUI for a given drug name.

    Args:
        drug_name: Free-text drug name. Surrounding whitespace is ignored, and
            empty or None input returns None without a network call.

    Returns:
        The RxCUI string, or None if the drug is unknown or the lookup failed.
        Failures (network, HTTP, bad JSON) are logged and NOT cached, so a later
        call retries instead of reusing a stale "not found".
    """
    drug_name = (drug_name or "").strip()
    if not drug_name:
        return None
    try:
        return _cached_rxcui(drug_name)
    except Exception:
        logger.exception(f"Error fetching RxCUI for '{drug_name}'")
        return None

@lru_cache(maxsize=128)
def _cached_openfda_label(rxcui: Optional[str], drug_name: Optional[str]) -> Optional[Dict[str, Any]]:
    """Fetch the first openFDA label matching the normalized identifiers.

    Definitive answers (a label, or "no matches") are memoized. Transport and HTTP
    errors propagate and are therefore not cached.
    """
    terms = []
    if rxcui:
        # NOTE(review): confirm ``spl_rxnorm_code`` is a searchable drug/label field.
        terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
    if drug_name:
        terms.append(f'(openfda.brand_name:"{drug_name}" OR openfda.generic_name:"{drug_name}")')
    query = " OR ".join(terms)
    logger.info(f"Looking up OpenFDA label with query: {query}")
    r = requests.get(OPENFDA_API_BASE, params={"search": query, "limit": 1}, timeout=15)
    if r.status_code == 404:
        # openFDA answers a search with no matches with HTTP 404. That is a normal
        # "not found", not an error.
        return None
    r.raise_for_status()
    results = r.json().get("results") or []
    return results[0] if results else None


def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Fetch the OpenFDA label for a drug by RxCUI or name.

    Args:
        rxcui: RxNorm CUI, if known.
        drug_name: Brand or generic name. It is stripped and lowercased, and double
            quotes are removed so the name cannot break out of the quoted query phrase.

    Returns:
        The first matching label record (a dict as returned by openFDA), or None
        when no identifier is given, nothing matches, or the request fails.
        Failures are logged and not cached.
    """
    rxcui = str(rxcui).strip() if rxcui else ""
    drug_name = (drug_name or "").strip().lower().replace('"', "")
    if not (rxcui or drug_name):
        return None
    try:
        return _cached_openfda_label(rxcui or None, drug_name or None)
    except Exception:
        logger.exception("Error fetching OpenFDA label")
        return None

def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
    """Collect one highlighted excerpt per text that mentions any search term.

    Terms are tried case-insensitively in the order given, and the first term
    present in a text wins. The excerpt covers up to 50 characters before the
    first occurrence and 100 after it. Every case-insensitive occurrence of the
    term inside the excerpt is wrapped in ``**``, and the excerpt is framed with
    ``...``. Empty or None terms are ignored, and ``texts=None`` yields ``[]``.
    """
    needles = [term.lower() for term in terms if term]
    excerpts: List[str] = []
    for text in texts or []:
        haystack = text.lower()
        hit = next(
            ((needle, haystack.find(needle)) for needle in needles if needle in haystack),
            None,
        )
        if hit is None:
            continue
        needle, pos = hit
        window = text[max(0, pos - 50):min(len(text), pos + len(needle) + 100)]
        highlighted = re.sub(f"({re.escape(needle)})", r"**\1**", window, flags=re.IGNORECASE)
        excerpts.append(f"...{highlighted}...")
    return excerpts

def parse_bp(bp: str) -> Optional[tuple[int, int]]:
    """Extract ``(systolic, diastolic)`` from a leading ``SYS/DIA`` reading.

    Accepts 1-3 digits on each side of the slash, with optional whitespace around
    it. Anything after the reading (e.g. units) is ignored. The reading must start
    the string (after stripping). Empty, None, or unparseable input returns None.
    """
    text = (bp or "").strip()
    match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", text)
    if match is None:
        return None
    systolic, diastolic = match.groups()
    return int(systolic), int(diastolic)

def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
    """Identify immediate red flags from patient_data.

    Args:
        patient_data: Dict that may contain ``"hpi"`` (with a ``"symptoms"`` list of
            strings) and ``"vitals"`` (``temp_c``, ``hr_bpm``, ``rr_rpm``,
            ``spo2_percent``, ``bp_mmhg`` as a "SYS/DIA" string). Missing, None, or
            non-dict sections are treated as empty. Vital values may be numbers or
            numeric strings; other values are ignored.

    Returns:
        "Red Flag: ..." strings, de-duplicated, in detection order. Vital values are
        reported as recorded.
    """
    def _as_number(value: Any) -> Optional[float]:
        """Coerce a vital-sign reading to float; None for missing, bool, or non-numeric."""
        if value is None or isinstance(value, bool):
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    data = patient_data or {}
    hpi = data.get("hpi")
    vitals = data.get("vitals")
    hpi = hpi if isinstance(hpi, dict) else {}
    vitals = vitals if isinstance(vitals, dict) else {}
    symptoms = [s.strip().lower() for s in (hpi.get("symptoms") or []) if isinstance(s, str)]

    flags: List[str] = []

    # Symptom-based flags. Use substring matching so free-text entries such as
    # "severe chest pain" or "chest pain radiating to arm" still trigger.
    symptom_flags = {
        "chest pain": "Chest pain reported",
        "shortness of breath": "Shortness of breath reported",
        "severe headache": "Severe headache reported",
        "syncope": "Syncope reported",
        "hemoptysis": "Hemoptysis reported",
    }
    for term, desc in symptom_flags.items():
        if any(term in s for s in symptoms):
            flags.append(f"Red Flag: {desc}.")

    # Vitals-based flags: compare numerically, report the value as recorded.
    temp_raw = vitals.get("temp_c")
    hr_raw = vitals.get("hr_bpm")
    rr_raw = vitals.get("rr_rpm")
    spo2_raw = vitals.get("spo2_percent")
    temp = _as_number(temp_raw)
    hr = _as_number(hr_raw)
    rr = _as_number(rr_raw)
    spo2 = _as_number(spo2_raw)

    if temp is not None and temp >= 38.5:
        flags.append(f"Red Flag: Fever ({temp_raw}°C).")
    if hr is not None:
        if hr >= 120:
            flags.append(f"Red Flag: Tachycardia ({hr_raw} bpm).")
        if hr <= 50:
            flags.append(f"Red Flag: Bradycardia ({hr_raw} bpm).")
    if rr is not None and rr >= 24:
        flags.append(f"Red Flag: Tachypnea ({rr_raw} rpm).")
    if spo2 is not None and spo2 <= 92:
        flags.append(f"Red Flag: Hypoxia ({spo2_raw}%).")

    bp_raw = vitals.get("bp_mmhg")
    bp = parse_bp(bp_raw) if isinstance(bp_raw, str) and bp_raw.strip() else None
    if bp:
        systolic, diastolic = bp
        if systolic >= 180 or diastolic >= 110:
            flags.append(f"Red Flag: Hypertensive urgency/emergency ({systolic}/{diastolic} mmHg).")
        if systolic <= 90 or diastolic <= 60:
            flags.append(f"Red Flag: Hypotension ({systolic}/{diastolic} mmHg).")

    return list(dict.fromkeys(flags))  # dedupe, preserve order

def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
    """Render patient data as markdown-style lines for the LLM prompt.

    Each top-level key becomes a bold, title-cased heading (underscores become spaces).
    - A dict with at least one truthy value is rendered as a bullet per truthy entry.
    - A non-empty list is rendered comma-joined on the heading line.
    - Any other truthy value is rendered inline. This includes a dict whose values are
      all falsy, which falls through to its ``str()`` form.

    Falsy values are skipped. Empty or None ``data`` returns a fixed placeholder sentence.
    """
    if not data:
        return "No patient data provided."

    def _pretty(key: str) -> str:
        return key.replace("_", " ").title()

    out: List[str] = []
    for key, val in data.items():
        heading = _pretty(key)
        if isinstance(val, dict) and any(val.values()):
            out.append(f"**{heading}:**")
            out.extend(f"- {_pretty(sub_key)}: {sub_val}" for sub_key, sub_val in val.items() if sub_val)
            continue
        if isinstance(val, list) and val:
            out.append(f"**{heading}:** {', '.join(str(item) for item in val)}")
            continue
        if val:
            out.append(f"**{heading}:** {val}")
    return "\n".join(out)

# ── Tool Input Schemas ────────────────────────────────────────────────────────
class LabOrderInput(BaseModel):
    """Arguments for the ``order_lab_test`` tool."""

    # Descriptions are emitted into the tool's JSON schema, so the model sees them.
    test_name: str = Field(..., description="Name of the laboratory test to order, e.g. 'CBC'.")
    reason: str    = Field(..., description="Clinical indication for ordering the test.")
    priority: str  = Field("Routine", description="Order priority, e.g. 'Routine', 'Urgent' or 'STAT'.")

class PrescriptionInput(BaseModel):
    """Arguments for the ``prescribe_medication`` tool."""

    # Descriptions are emitted into the tool's JSON schema, so the model sees them.
    medication_name: str = Field(..., description="Medication to prescribe (generic name preferred).")
    dosage: str          = Field(..., description="Dose per administration, e.g. '500 mg'.")
    route: str           = Field(..., description="Route of administration, e.g. 'PO', 'IV'.")
    frequency: str       = Field(..., description="Dosing frequency, e.g. 'BID', 'every 8 hours'.")
    duration: str        = Field("As directed", description="Course length, e.g. '7 days'.")
    reason: str          = Field(..., description="Clinical indication for the prescription.")

class InteractionCheckInput(BaseModel):
    """Arguments for the ``check_drug_interactions`` tool."""

    # Descriptions are emitted into the tool's JSON schema, so the model sees them.
    potential_prescription: str = Field(
        ..., description="Name of the drug being considered for prescription."
    )
    current_medications: Optional[List[str]] = Field(
        None, description="Patient's current medications. If omitted, the patient record is used."
    )
    allergies: Optional[List[str]] = Field(
        None, description="Patient's known allergies. If omitted, the patient record is used."
    )

class FlagRiskInput(BaseModel):
    """Arguments for the ``flag_risk`` tool."""

    # Descriptions are emitted into the tool's JSON schema, so the model sees them.
    risk_description: str = Field(..., description="Concise description of the clinical risk.")
    urgency: str          = Field("High", description="Urgency level, e.g. 'High', 'Medium', 'Low'.")

# ── Tool Implementations ──────────────────────────────────────────────────────
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """Place an order for a laboratory test."""
    # `@tool` builds the description the model sees from the docstring above, so it is kept verbatim.
    # Simulated order: the only side effect is the log line; the JSON reply echoes the request.
    logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
    payload = {
        "status": "success",
        "message": f"Lab Ordered: {test_name} ({priority})",
        "details": f"Reason: {reason}",
    }
    return json.dumps(payload)

@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(
    medication_name: str,
    dosage: str,
    route: str,
    frequency: str,
    duration: str,
    reason: str
) -> str:
    """Prepare a medication prescription."""
    # `@tool` builds the description the model sees from the docstring above, so it is kept verbatim.
    # Nothing is dispensed: this logs the request and returns a JSON confirmation.
    # tool_node only lets this run when the same batch also checks interactions for the drug.
    logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
    summary_line = f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}"
    detail_line = f"Duration: {duration}. Reason: {reason}"
    return json.dumps({"status": "success", "message": summary_line, "details": detail_line})

@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(
    potential_prescription: str,
    current_medications: Optional[List[str]] = None,
    allergies: Optional[List[str]] = None
) -> str:
    """Check for drug–drug interactions and allergy risks."""
    # `@tool` builds the description the model sees from the docstring above, so it is kept verbatim.
    #
    # Returns a JSON string {"status": "warning" | "clear", "message": str, "warnings": [str]}.
    # tool_node reads "status" and "warnings" from it to decide whether to run a safety reflection.
    logger.info(f"Checking interactions for: {potential_prescription}")
    warnings: List[str] = []
    target = potential_prescription.lower().strip()
    # Normalize; drop entries that are empty after stripping so "" can never match.
    pm = [m for m in (x.lower().strip() for x in (current_medications or []) if x) if m]
    al = [a for a in (x.lower().strip() for x in (allergies or []) if x) if a]

    # Allergy by name. An exact match keeps the original wording. Otherwise, an allergen
    # that appears inside the prescription text (e.g. allergy "amoxicillin" vs
    # "amoxicillin 500 mg") is still reported.
    if target in al:
        warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")
    elif target:
        for allergen in al:
            if allergen in target:
                warnings.append(
                    f"CRITICAL ALLERGY: Patient allergic to '{allergen}', which matches '{potential_prescription}'."
                )

    # Identify the drug via RxNorm / openFDA.
    rxcui = get_rxcui(potential_prescription)
    label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
    if not (rxcui or label):
        warnings.append(f"INFO: Could not identify '{potential_prescription}'. Checks may be incomplete.")

    # Allergy terms mentioned in the label's contraindication / warning sections.
    if label:
        for section in ("contraindications", "warnings_and_cautions", "warnings"):
            items = label.get(section)
            if isinstance(items, list):
                snippets = search_text_list(items, al)
                if snippets:
                    warnings.append(f"ALLERGY RISK ({section}): {'; '.join(snippets)}")

    # Drug–drug interactions: look for each drug's name in the other drug's label.
    for med in pm:
        med_label = get_openfda_label(rxcui=get_rxcui(med), drug_name=med)
        # (label to search, whose label it is, name to look for in it)
        for src_label, owner, other in (
            (label, potential_prescription, med),
            (med_label, med, potential_prescription),
        ):
            items = src_label.get("drug_interactions") if src_label else None
            if isinstance(items, list):
                snippets = search_text_list(items, [other])
                if snippets:
                    warnings.append(f"Interaction ({owner} label): {'; '.join(snippets)}")

    # Both labels can yield identical text; keep the first occurrence only.
    warnings = list(dict.fromkeys(warnings))

    status  = "warning" if warnings else "clear"
    message = (
        f"{len(warnings)} issue(s) found for '{potential_prescription}'."
        if warnings else
        f"No major interactions or allergy issues identified for '{potential_prescription}'."
    )
    return json.dumps({"status": status, "message": message, "warnings": warnings})

@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str = "High") -> str:
    """Flag a clinical risk with given urgency."""
    # `@tool` builds the description the model sees from the docstring above, so it is kept verbatim.
    # The only side effect is the log line; the JSON reply confirms the flag back to the model.
    logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
    confirmation = f"Risk '{risk_description}' flagged with {urgency} urgency."
    return json.dumps({"status": "flagged", "message": confirmation})

# Web-search tool exposed to the model as "tavily_search_results".
# Presumably it reads TAVILY_API_KEY from the environment (TODO confirm); that key is
# required at import time above.
search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
# Every tool the model may call. The same list is bound to the model and given to the ToolExecutor.
all_tools   = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]

# ── LLM & Tool Executor ──────────────────────────────────────────────────────
# Plain chat model (no tools bound); used directly by reflection_node.
# NOTE(review): ChatGroq presumably picks up GROQ_API_KEY from the environment; confirm.
llm              = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
# Same model with all tool schemas attached; agent_node uses it, so replies may carry tool_calls.
model_with_tools = llm.bind_tools(all_tools)
# Dispatches tool invocations by tool name; used by tool_node.
tool_executor    = ToolExecutor(all_tools)

# ── State Definition ──────────────────────────────────────────────────────────
class AgentState(TypedDict):
    """Shared LangGraph state for one consultation turn.

    ``messages`` carries an ``operator.add`` reducer, so LangGraph *appends* each node's
    returned messages to the history. Without a reducer the key is a last-value channel:
    every node update would replace the whole conversation. That would discard, among
    other things, the AIMessage whose tool_calls the following ToolMessages answer.
    The other keys keep last-value semantics.
    """

    messages: Annotated[List[Any], operator.add]
    # Patient record read by tool_node (medications, allergies); may be None.
    patient_data: Optional[Dict[str, Any]]
    summary: Optional[str]
    # Warnings from the latest check_drug_interactions round; non-empty routes to reflection.
    interaction_warnings: Optional[List[str]]

# ── Graph Nodes ───────────────────────────────────────────────────────────────
def agent_node(state: AgentState) -> Dict[str, Any]:
    """Run one tool-enabled LLM step over the conversation.

    The clinical system prompt is prepended (for this call only) unless the history
    already starts with a SystemMessage. Returns a partial state update whose
    "messages" holds the model reply, which may carry tool_calls. Any exception from
    the model call is logged and returned as an ``AIMessage("Error: <exception>")``
    instead of propagating.
    """
    history = state["messages"]
    has_system_prompt = bool(history) and isinstance(history[0], SystemMessage)
    if has_system_prompt:
        conversation = history
    else:
        conversation = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + history
    logger.info(f"Invoking LLM with {len(conversation)} messages")
    try:
        reply = model_with_tools.invoke(conversation)
    except Exception as e:
        logger.exception("Error in agent_node")
        return {"messages": [AIMessage(content=f"Error: {e}")]}
    return {"messages": [reply]}

def _blocked_prescriptions(calls: List[Dict[str, Any]]) -> Dict[str, str]:
    """Return {tool_call_id: medication} for prescriptions lacking a same-batch interaction check."""
    checked = {
        str(c["args"].get("potential_prescription") or "").lower().strip()
        for c in calls
        if c["name"] == "check_drug_interactions"
    }
    blocked: Dict[str, str] = {}
    for call in calls:
        if call["name"] == "prescribe_medication":
            med = str(call["args"].get("medication_name") or "").lower().strip()
            if med not in checked:
                blocked[call["id"]] = med
    return blocked


def _current_medications(patient_data: Dict[str, Any]) -> List[Any]:
    """Read the current-medication list, tolerating a missing, None, or flat-list field."""
    meds = patient_data.get("medications")
    if isinstance(meds, dict):
        return meds.get("current") or []
    if isinstance(meds, list):
        return meds
    return []


def _with_patient_context(args: Dict[str, Any], patient_data: Dict[str, Any]) -> Dict[str, Any]:
    """Copy check_drug_interactions args, filling missing or None meds/allergies from patient_data."""
    filled = dict(args)  # copy: never mutate the tool_calls stored on the AIMessage
    if filled.get("current_medications") is None:
        filled["current_medications"] = _current_medications(patient_data)
    if filled.get("allergies") is None:
        filled["allergies"] = patient_data.get("allergies") or []
    return filled


def _parse_interaction_warnings(content: str) -> List[str]:
    """Extract warnings from a check_drug_interactions JSON reply (empty if none or unparseable)."""
    try:
        data = json.loads(content)
    except ValueError:
        logger.warning("check_drug_interactions returned non-JSON output")
        return []
    if isinstance(data, dict) and data.get("status") == "warning":
        return list(data.get("warnings") or [])
    return []


def tool_node(state: AgentState) -> Dict[str, Any]:
    """Execute the tool calls requested by the most recent AIMessage.

    Safety gate: a ``prescribe_medication`` call runs only if the same batch also
    contains a ``check_drug_interactions`` call for that medication (case-insensitive,
    surrounding whitespace ignored). Otherwise it gets a "blocked" ToolMessage.
    ``check_drug_interactions`` calls with missing or None ``current_medications`` /
    ``allergies`` are filled from ``state["patient_data"]``.

    Returns:
        {"messages": one ToolMessage per requested tool_call_id, in request order,
         "interaction_warnings": warnings from check_drug_interactions, or None}
    """
    history = state.get("messages") or []
    last = history[-1] if history else None
    if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
        logger.warning("tool_node invoked without pending tool_calls")
        return {"messages": [], "interaction_warnings": None}

    calls = last.tool_calls
    patient_data = state.get("patient_data") or {}
    contents: Dict[str, str] = {}

    blocked = _blocked_prescriptions(calls)
    for call_id, med in blocked.items():
        logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
        contents[call_id] = json.dumps({
            "status": "blocked",
            "message": (
                f"Prescription for '{med}' was not prepared: call check_drug_interactions "
                "for this medication in the same step as prescribe_medication."
            ),
        })

    to_execute = [c for c in calls if c["id"] not in blocked]
    # ToolExecutor reads `.tool` / `.tool_input`, so raw tool-call dicts must be wrapped.
    invocations = [
        ToolInvocation(
            tool=c["name"],
            tool_input=(
                _with_patient_context(c["args"], patient_data)
                if c["name"] == "check_drug_interactions"
                else c["args"]
            ),
        )
        for c in to_execute
    ]

    warnings: List[str] = []
    try:
        responses = tool_executor.batch(invocations, return_exceptions=True) if invocations else []
    except Exception as e:
        logger.exception("Critical error in tool_node")
        responses = [e] * len(to_execute)

    for call, resp in zip(to_execute, responses):
        if isinstance(resp, Exception):
            # Not inside an except block, so pass the exception explicitly for the traceback.
            logger.error(f"Error executing tool {call['name']}", exc_info=resp)
            contents[call["id"]] = json.dumps({"status": "error", "message": str(resp)})
            continue
        content = str(resp)
        contents[call["id"]] = content
        if call["name"] == "check_drug_interactions":
            warnings.extend(_parse_interaction_warnings(content))

    # Exactly one ToolMessage per tool_call_id. OpenAI-compatible chat APIs (Groq included)
    # reject a follow-up request in which any tool call is left without a tool response.
    no_result = json.dumps({"status": "error", "message": "Tool produced no result."})
    messages = [
        ToolMessage(content=contents.get(c["id"], no_result), tool_call_id=c["id"], name=c["name"])
        for c in calls
    ]
    return {"messages": messages, "interaction_warnings": warnings or None}

def reflection_node(state: AgentState) -> Dict[str, Any]:
    """Ask the plain ``llm`` for a focused safety review after interaction warnings.

    The plan under review is the most recent AIMessage carrying tool_calls. Both its
    text content and its tool calls are rendered into the review prompt, together
    with the warnings. Returns the review as an AIMessage, or an error AIMessage if
    context is missing or the model call fails. ``interaction_warnings`` is always
    reset to None.
    """
    warns = state.get("interaction_warnings")
    if not warns:
        logger.warning("reflection_node called without warnings")
        return {"messages": [], "interaction_warnings": None}

    triggering = next(
        (
            msg
            for msg in reversed(state.get("messages") or [])
            if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None)
        ),
        None,
    )
    if triggering is None:
        return {"messages": [AIMessage(content="Internal Error: reflection context missing.")], "interaction_warnings": None}

    # A tool-calling AIMessage often has empty text content; the concrete plan lives in
    # its tool_calls, so render them or the reviewer has nothing to review.
    # default=str is for prompt rendering only: any non-JSON arg value is shown as text.
    planned_calls = "\n".join(
        f"- {call.get('name')}: {json.dumps(call.get('args') or {}, ensure_ascii=False, default=str)}"
        for call in triggering.tool_calls
    )
    prompt = (
        "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
        f"{triggering.content}\n\n"
        "Planned tool calls:\n"
        f"{planned_calls}\n\n"
        "Highlight any issues based on these warnings:\n"
        + "\n".join(f"- {w}" for w in warns)
    )
    try:
        resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
        return {"messages": [AIMessage(content=resp.content)], "interaction_warnings": None}
    except Exception as e:
        logger.exception("Error during reflection")
        return {"messages": [AIMessage(content=f"Error during reflection: {e}")], "interaction_warnings": None}

# ── Routing Functions ────────────────────────────────────────────────────────
def should_continue(state: AgentState) -> str:
    """Route after the agent step.

    Returns "continue_tools" when the last message is an AIMessage with pending
    tool_calls, and "end_conversation_turn" otherwise (including empty history).
    """
    history = state["messages"]
    last = history[-1] if history else None
    wants_tools = isinstance(last, AIMessage) and bool(getattr(last, "tool_calls", None))
    return "continue_tools" if wants_tools else "end_conversation_turn"

def after_tools_router(state: AgentState) -> str:
    """Send the graph to "reflection" when the tool round left interaction warnings, else back to "agent"."""
    if state.get("interaction_warnings"):
        return "reflection"
    return "agent"

# ── ClinicalAgent ────────────────────────────────────────────────────────────
class ClinicalAgent:
    """Owns the compiled LangGraph consultation workflow.

    Flow: "agent" -> "tools" when the reply requests tool calls (otherwise the turn
    ends); "tools" -> "reflection" when interaction warnings were produced, else
    back to "agent"; "reflection" -> "agent".
    """

    # Passed to LangGraph as `recursion_limit` on every turn.
    _RECURSION_LIMIT = 15

    def __init__(self):
        logger.info("Building ClinicalAgent workflow")
        self.graph_app = self._build_workflow().compile()
        logger.info("ClinicalAgent ready")

    @staticmethod
    def _build_workflow():
        """Assemble (but do not compile) the StateGraph with its nodes and routing."""
        workflow = StateGraph(AgentState)
        for node_name, node_fn in (
            ("agent", agent_node),
            ("tools", tool_node),
            ("reflection", reflection_node),
        ):
            workflow.add_node(node_name, node_fn)
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            should_continue,
            {"continue_tools": "tools", "end_conversation_turn": END},
        )
        workflow.add_conditional_edges(
            "tools",
            after_tools_router,
            {"reflection": "reflection", "agent": "agent"},
        )
        workflow.add_edge("reflection", "agent")
        return workflow

    def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run one consultation turn through the graph.

        On success, returns the graph's final state. If absent, ``summary`` is
        carried over from the input and ``interaction_warnings`` is set to None.
        On any exception, logs it and returns the input messages plus an
        ``"Error: <exception>"`` AIMessage, along with the input's patient_data and
        summary.
        """
        try:
            final_state = self.graph_app.invoke(state, {"recursion_limit": self._RECURSION_LIMIT})
            final_state.setdefault("summary", state.get("summary"))
            final_state.setdefault("interaction_warnings", None)
            return final_state
        except Exception as e:
            logger.exception("Error during graph invocation")
            fallback_messages = state.get("messages", []) + [AIMessage(content=f"Error: {e}")]
            return {
                "messages": fallback_messages,
                "patient_data": state.get("patient_data"),
                "summary": state.get("summary"),
                "interaction_warnings": None,
            }