mgbam commited on
Commit
a90f7d4
Β·
verified Β·
1 Parent(s): 86911ce

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +66 -54
agent.py CHANGED
@@ -22,8 +22,8 @@ logger = logging.getLogger(__name__)
22
  logging.basicConfig(level=logging.INFO)
23
 
24
  # ── Environment Variables ─────────────────────────────────────────────────────
25
- UMLS_API_KEY = os.getenv("UMLS_API_KEY")
26
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
27
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
28
 
29
  if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
@@ -31,8 +31,8 @@ if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
31
  raise RuntimeError("Missing required API keys")
32
 
33
  # ── Agent Configuration ───────────────────────────────────────────────────────
34
- AGENT_MODEL_NAME = "llama3-70b-8192"
35
- AGENT_TEMPERATURE = 0.1
36
  MAX_SEARCH_RESULTS = 3
37
 
38
  class ClinicalPrompts:
@@ -42,30 +42,32 @@ class ClinicalPrompts:
42
  """
43
 
44
  # ── Helper Functions ──────────────────────────────────────────────────────────
45
- UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
46
- RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
47
- OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
48
 
49
  @lru_cache(maxsize=256)
50
  def get_rxcui(drug_name: str) -> Optional[str]:
51
- """Lookup RxNorm CUI for a drug name."""
52
  drug_name = (drug_name or "").strip()
53
  if not drug_name:
54
  return None
55
  logger.info(f"Looking up RxCUI for '{drug_name}'")
56
  try:
 
57
  params = {"name": drug_name, "search": 1}
58
  r = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
59
  r.raise_for_status()
60
- data = r.json().get("idGroup", {})
61
- if ids := data.get("rxnormId"):
62
  logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
63
  return ids[0]
64
- # fallback to broader search
65
  r = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
66
  r.raise_for_status()
67
- for group in r.json().get("drugGroup", {}).get("conceptGroup", []):
68
- if props := group.get("conceptProperties"):
 
69
  logger.info(f"Found RxCUI {props[0]['rxcui']} via /drugs for '{drug_name}'")
70
  return props[0]["rxcui"]
71
  except Exception:
@@ -74,7 +76,7 @@ def get_rxcui(drug_name: str) -> Optional[str]:
74
 
75
  @lru_cache(maxsize=128)
76
  def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
77
- """Fetch label data from OpenFDA by RxCUI or drug name."""
78
  if not (rxcui or drug_name):
79
  return None
80
  terms = []
@@ -96,7 +98,7 @@ def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = No
96
  return None
97
 
98
  def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
99
- """Return snippets from texts containing any of the search terms."""
100
  snippets = []
101
  lowers = [t.lower() for t in terms if t]
102
  for text in texts or []:
@@ -105,7 +107,7 @@ def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
105
  if term in tl:
106
  i = tl.find(term)
107
  start = max(0, i - 50)
108
- end = min(len(text), i + len(term) + 100)
109
  snippet = text[start:end]
110
  snippet = re.sub(f"({re.escape(term)})", r"**\1**", snippet, flags=re.IGNORECASE)
111
  snippets.append(f"...{snippet}...")
@@ -113,34 +115,37 @@ def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
113
  return snippets
114
 
115
  def parse_bp(bp: str) -> Optional[tuple[int, int]]:
116
- """Parse a 'SYS/DIA' blood pressure string into a tuple."""
117
  if m := re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", (bp or "").strip()):
118
  return int(m.group(1)), int(m.group(2))
119
  return None
120
 
121
  def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
122
- """Identify immediate red flags in patient data."""
123
  flags: List[str] = []
124
- hpi = patient_data.get("hpi", {})
125
  vitals = patient_data.get("vitals", {})
126
- symptoms = [s.lower() for s in hpi.get("symptoms", []) if isinstance(s, str)]
127
- # Symptom-based
 
128
  mapping = {
129
  "chest pain": "Chest pain reported",
130
  "shortness of breath": "Shortness of breath reported",
131
  "severe headache": "Severe headache reported",
132
- "syncope": "Syncope (fainting) reported",
133
- "hemoptysis": "Hemoptysis (coughing blood) reported"
134
  }
135
  for term, desc in mapping.items():
136
- if term in symptoms:
137
  flags.append(f"Red Flag: {desc}.")
138
- # Vitals-based
 
139
  temp = vitals.get("temp_c")
140
  hr = vitals.get("hr_bpm")
141
  rr = vitals.get("rr_rpm")
142
  spo2 = vitals.get("spo2_percent")
143
  bp = parse_bp(vitals.get("bp_mmhg", ""))
 
144
  if temp is not None and temp >= 38.5:
145
  flags.append(f"Red Flag: Fever ({temp}Β°C).")
146
  if hr is not None:
@@ -158,10 +163,11 @@ def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
158
  flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
159
  if sys <= 90 or dia <= 60:
160
  flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
161
- return list(dict.fromkeys(flags)) # preserve order, dedupe
 
162
 
163
  def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
164
- """Convert patient_data dict into a markdown-like prompt string."""
165
  if not data:
166
  return "No patient data provided."
167
  lines: List[str] = []
@@ -181,29 +187,30 @@ def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
181
  # ── Tool Input Schemas ────────────────────────────────────────────────────────
182
  class LabOrderInput(BaseModel):
183
  test_name: str = Field(...)
184
- reason: str = Field(...)
185
- priority: str = Field("Routine")
186
 
187
  class PrescriptionInput(BaseModel):
188
- medication_name: str = Field(...)
189
- dosage: str = Field(...)
190
- route: str = Field(...)
191
- frequency: str = Field(...)
192
- duration: str = Field("As directed")
193
- reason: str = Field(...)
194
 
195
  class InteractionCheckInput(BaseModel):
196
- potential_prescription: str = Field(...)
197
  current_medications: Optional[List[str]] = Field(None)
198
- allergies: Optional[List[str]] = Field(None)
199
 
200
  class FlagRiskInput(BaseModel):
201
  risk_description: str = Field(...)
202
- urgency: str = Field("High")
203
 
204
  # ── Tool Implementations ──────────────────────────────────────────────────────
205
  @tool("order_lab_test", args_schema=LabOrderInput)
206
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
 
207
  logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
208
  return json.dumps({
209
  "status": "success",
@@ -220,6 +227,7 @@ def prescribe_medication(
220
  duration: str,
221
  reason: str
222
  ) -> str:
 
223
  logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
224
  return json.dumps({
225
  "status": "success",
@@ -233,15 +241,17 @@ def check_drug_interactions(
233
  current_medications: Optional[List[str]] = None,
234
  allergies: Optional[List[str]] = None
235
  ) -> str:
 
236
  logger.info(f"Checking interactions for: {potential_prescription}")
237
  warnings: List[str] = []
238
  pm = [m.lower().strip() for m in (current_medications or []) if m]
239
  al = [a.lower().strip() for a in (allergies or []) if a]
240
 
241
- # Allergy checks
242
  if potential_prescription.lower().strip() in al:
243
  warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")
244
- # RxNorm/OpenFDA lookups
 
245
  rxcui = get_rxcui(potential_prescription)
246
  label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
247
  if not (rxcui or label):
@@ -259,16 +269,15 @@ def check_drug_interactions(
259
  for med in pm:
260
  mrxcui = get_rxcui(med)
261
  mlabel = get_openfda_label(rxcui=mrxcui, drug_name=med)
262
- # check in both labels
263
  for sec in ("drug_interactions",):
264
  for src_label, src_name in ((label, potential_prescription), (mlabel, med)):
265
  items = src_label.get(sec) if src_label else None
266
  if isinstance(items, list):
267
- snippets = search_text_list(items, [med if src_name==potential_prescription else potential_prescription])
268
  if snippets:
269
  warnings.append(f"Interaction ({src_name} label): {'; '.join(snippets)}")
270
 
271
- status = "warning" if warnings else "clear"
272
  message = (
273
  f"{len(warnings)} issue(s) found for '{potential_prescription}'."
274
  if warnings else
@@ -278,19 +287,21 @@ def check_drug_interactions(
278
 
279
  @tool("flag_risk", args_schema=FlagRiskInput)
280
  def flag_risk(risk_description: str, urgency: str = "High") -> str:
 
281
  logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
282
  return json.dumps({
283
  "status": "flagged",
284
  "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
285
  })
286
 
 
287
  search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
288
- all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
289
 
290
  # ── LLM & Tool Executor ──────────────────────────────────────────────────────
291
- llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
292
  model_with_tools = llm.bind_tools(all_tools)
293
- tool_executor = ToolExecutor(all_tools)
294
 
295
  # ── State Definition ──────────────────────────────────────────────────────────
296
  class AgentState(TypedDict):
@@ -319,17 +330,19 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
319
  return {"messages": [], "interaction_warnings": None}
320
 
321
  calls = last.tool_calls
322
- # Safety: require interaction check before prescribing
323
  blocked_ids = set()
324
  for call in calls:
325
  if call["name"] == "prescribe_medication":
326
  med = call["args"].get("medication_name", "").lower()
327
- if not any(c["name"] == "check_drug_interactions" and c["args"].get("potential_prescription","").lower() == med for c in calls):
 
 
 
 
328
  logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
329
  blocked_ids.add(call["id"])
330
 
331
  to_execute = [c for c in calls if c["id"] not in blocked_ids]
332
- # Augment interaction checks with patient data
333
  pd = state.get("patient_data", {})
334
  for call in to_execute:
335
  if call["name"] == "check_drug_interactions":
@@ -337,7 +350,7 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
337
  call["args"].setdefault("allergies", pd.get("allergies", []))
338
 
339
  messages: List[ToolMessage] = []
340
- warnings: List[str] = []
341
  try:
342
  responses = tool_executor.batch(to_execute, return_exceptions=True)
343
  for call, resp in zip(to_execute, responses):
@@ -353,7 +366,6 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
353
  messages.append(ToolMessage(content=content, tool_call_id=call["id"], name=call["name"]))
354
  except Exception as e:
355
  logger.exception("Critical error in tool_node")
356
- # return an error message for each pending call
357
  for call in to_execute:
358
  messages.append(ToolMessage(
359
  content=json.dumps({"status": "error", "message": str(e)}),
@@ -368,8 +380,7 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
368
  logger.warning("reflection_node called without warnings")
369
  return {"messages": [], "interaction_warnings": None}
370
 
371
- # Find the AIMessage that triggered the warnings
372
- triggering: Optional[AIMessage] = None
373
  for msg in reversed(state["messages"]):
374
  if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
375
  triggering = msg
@@ -381,7 +392,8 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
381
  prompt = (
382
  "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
383
  f"{triggering.content}\n\n"
384
- "Highlight any issues based on these warnings:\n" + "\n".join(f"- {w}" for w in warns)
 
385
  )
386
  try:
387
  resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
 
22
  logging.basicConfig(level=logging.INFO)
23
 
24
  # ── Environment Variables ─────────────────────────────────────────────────────
25
+ UMLS_API_KEY = os.getenv("UMLS_API_KEY")
26
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
27
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
28
 
29
  if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
 
31
  raise RuntimeError("Missing required API keys")
32
 
33
  # ── Agent Configuration ───────────────────────────────────────────────────────
34
+ AGENT_MODEL_NAME = "llama3-70b-8192"
35
+ AGENT_TEMPERATURE = 0.1
36
  MAX_SEARCH_RESULTS = 3
37
 
38
  class ClinicalPrompts:
 
42
  """
43
 
44
  # ── Helper Functions ──────────────────────────────────────────────────────────
45
+ UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
46
+ RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
47
+ OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
48
 
49
  @lru_cache(maxsize=256)
50
  def get_rxcui(drug_name: str) -> Optional[str]:
51
+ """Lookup RxNorm CUI for a given drug name."""
52
  drug_name = (drug_name or "").strip()
53
  if not drug_name:
54
  return None
55
  logger.info(f"Looking up RxCUI for '{drug_name}'")
56
  try:
57
+ # First attempt
58
  params = {"name": drug_name, "search": 1}
59
  r = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
60
  r.raise_for_status()
61
+ ids = r.json().get("idGroup", {}).get("rxnormId")
62
+ if ids:
63
  logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
64
  return ids[0]
65
+ # Fallback search
66
  r = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
67
  r.raise_for_status()
68
+ for grp in r.json().get("drugGroup", {}).get("conceptGroup", []):
69
+ props = grp.get("conceptProperties")
70
+ if props:
71
  logger.info(f"Found RxCUI {props[0]['rxcui']} via /drugs for '{drug_name}'")
72
  return props[0]["rxcui"]
73
  except Exception:
 
76
 
77
  @lru_cache(maxsize=128)
78
  def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
79
+ """Fetch the OpenFDA label for a drug by RxCUI or name."""
80
  if not (rxcui or drug_name):
81
  return None
82
  terms = []
 
98
  return None
99
 
100
  def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
101
+ """Return highlighted snippets from a list of texts containing any of the search terms."""
102
  snippets = []
103
  lowers = [t.lower() for t in terms if t]
104
  for text in texts or []:
 
107
  if term in tl:
108
  i = tl.find(term)
109
  start = max(0, i - 50)
110
+ end = min(len(text), i + len(term) + 100)
111
  snippet = text[start:end]
112
  snippet = re.sub(f"({re.escape(term)})", r"**\1**", snippet, flags=re.IGNORECASE)
113
  snippets.append(f"...{snippet}...")
 
115
  return snippets
116
 
117
  def parse_bp(bp: str) -> Optional[tuple[int, int]]:
118
+ """Parse 'SYS/DIA' blood pressure string into a (sys, dia) tuple."""
119
  if m := re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", (bp or "").strip()):
120
  return int(m.group(1)), int(m.group(2))
121
  return None
122
 
123
  def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
124
+ """Identify immediate red flags from patient_data."""
125
  flags: List[str] = []
126
+ hpi = patient_data.get("hpi", {})
127
  vitals = patient_data.get("vitals", {})
128
+ syms = [s.lower() for s in hpi.get("symptoms", []) if isinstance(s, str)]
129
+
130
+ # Symptom-based flags
131
  mapping = {
132
  "chest pain": "Chest pain reported",
133
  "shortness of breath": "Shortness of breath reported",
134
  "severe headache": "Severe headache reported",
135
+ "syncope": "Syncope reported",
136
+ "hemoptysis": "Hemoptysis reported"
137
  }
138
  for term, desc in mapping.items():
139
+ if term in syms:
140
  flags.append(f"Red Flag: {desc}.")
141
+
142
+ # Vitals-based flags
143
  temp = vitals.get("temp_c")
144
  hr = vitals.get("hr_bpm")
145
  rr = vitals.get("rr_rpm")
146
  spo2 = vitals.get("spo2_percent")
147
  bp = parse_bp(vitals.get("bp_mmhg", ""))
148
+
149
  if temp is not None and temp >= 38.5:
150
  flags.append(f"Red Flag: Fever ({temp}Β°C).")
151
  if hr is not None:
 
163
  flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
164
  if sys <= 90 or dia <= 60:
165
  flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
166
+
167
+ return list(dict.fromkeys(flags)) # dedupe, preserve order
168
 
169
  def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
170
+ """Format patient_data dict into a markdown-like prompt section."""
171
  if not data:
172
  return "No patient data provided."
173
  lines: List[str] = []
 
187
  # ── Tool Input Schemas ────────────────────────────────────────────────────────
188
  class LabOrderInput(BaseModel):
189
  test_name: str = Field(...)
190
+ reason: str = Field(...)
191
+ priority: str = Field("Routine")
192
 
193
  class PrescriptionInput(BaseModel):
194
+ medication_name: str = Field(...)
195
+ dosage: str = Field(...)
196
+ route: str = Field(...)
197
+ frequency: str = Field(...)
198
+ duration: str = Field("As directed")
199
+ reason: str = Field(...)
200
 
201
  class InteractionCheckInput(BaseModel):
202
+ potential_prescription: str
203
  current_medications: Optional[List[str]] = Field(None)
204
+ allergies: Optional[List[str]] = Field(None)
205
 
206
  class FlagRiskInput(BaseModel):
207
  risk_description: str = Field(...)
208
+ urgency: str = Field("High")
209
 
210
  # ── Tool Implementations ──────────────────────────────────────────────────────
211
  @tool("order_lab_test", args_schema=LabOrderInput)
212
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
213
+ """Place an order for a laboratory test."""
214
  logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
215
  return json.dumps({
216
  "status": "success",
 
227
  duration: str,
228
  reason: str
229
  ) -> str:
230
+ """Prepare a medication prescription."""
231
  logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
232
  return json.dumps({
233
  "status": "success",
 
241
  current_medications: Optional[List[str]] = None,
242
  allergies: Optional[List[str]] = None
243
  ) -> str:
244
+ """Check for drug–drug interactions and allergy risks."""
245
  logger.info(f"Checking interactions for: {potential_prescription}")
246
  warnings: List[str] = []
247
  pm = [m.lower().strip() for m in (current_medications or []) if m]
248
  al = [a.lower().strip() for a in (allergies or []) if a]
249
 
250
+ # Allergy exact match
251
  if potential_prescription.lower().strip() in al:
252
  warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")
253
+
254
+ # Identify drug via RxNorm/OpenFDA
255
  rxcui = get_rxcui(potential_prescription)
256
  label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
257
  if not (rxcui or label):
 
269
  for med in pm:
270
  mrxcui = get_rxcui(med)
271
  mlabel = get_openfda_label(rxcui=mrxcui, drug_name=med)
 
272
  for sec in ("drug_interactions",):
273
  for src_label, src_name in ((label, potential_prescription), (mlabel, med)):
274
  items = src_label.get(sec) if src_label else None
275
  if isinstance(items, list):
276
+ snippets = search_text_list(items, [med if src_name == potential_prescription else potential_prescription])
277
  if snippets:
278
  warnings.append(f"Interaction ({src_name} label): {'; '.join(snippets)}")
279
 
280
+ status = "warning" if warnings else "clear"
281
  message = (
282
  f"{len(warnings)} issue(s) found for '{potential_prescription}'."
283
  if warnings else
 
287
 
288
  @tool("flag_risk", args_schema=FlagRiskInput)
289
  def flag_risk(risk_description: str, urgency: str = "High") -> str:
290
+ """Flag a clinical risk with given urgency."""
291
  logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
292
  return json.dumps({
293
  "status": "flagged",
294
  "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
295
  })
296
 
297
+ # Include the Tavily search tool
298
  search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
299
+ all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
300
 
301
  # ── LLM & Tool Executor ──────────────────────────────────────────────────────
302
+ llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
303
  model_with_tools = llm.bind_tools(all_tools)
304
+ tool_executor = ToolExecutor(all_tools)
305
 
306
  # ── State Definition ──────────────────────────────────────────────────────────
307
  class AgentState(TypedDict):
 
330
  return {"messages": [], "interaction_warnings": None}
331
 
332
  calls = last.tool_calls
 
333
  blocked_ids = set()
334
  for call in calls:
335
  if call["name"] == "prescribe_medication":
336
  med = call["args"].get("medication_name", "").lower()
337
+ if not any(
338
+ c["name"] == "check_drug_interactions" and
339
+ c["args"].get("potential_prescription","").lower() == med
340
+ for c in calls
341
+ ):
342
  logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
343
  blocked_ids.add(call["id"])
344
 
345
  to_execute = [c for c in calls if c["id"] not in blocked_ids]
 
346
  pd = state.get("patient_data", {})
347
  for call in to_execute:
348
  if call["name"] == "check_drug_interactions":
 
350
  call["args"].setdefault("allergies", pd.get("allergies", []))
351
 
352
  messages: List[ToolMessage] = []
353
+ warnings: List[str] = []
354
  try:
355
  responses = tool_executor.batch(to_execute, return_exceptions=True)
356
  for call, resp in zip(to_execute, responses):
 
366
  messages.append(ToolMessage(content=content, tool_call_id=call["id"], name=call["name"]))
367
  except Exception as e:
368
  logger.exception("Critical error in tool_node")
 
369
  for call in to_execute:
370
  messages.append(ToolMessage(
371
  content=json.dumps({"status": "error", "message": str(e)}),
 
380
  logger.warning("reflection_node called without warnings")
381
  return {"messages": [], "interaction_warnings": None}
382
 
383
+ triggering = None
 
384
  for msg in reversed(state["messages"]):
385
  if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
386
  triggering = msg
 
392
  prompt = (
393
  "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
394
  f"{triggering.content}\n\n"
395
+ "Highlight any issues based on these warnings:\n" +
396
+ "\n".join(f"- {w}" for w in warns)
397
  )
398
  try:
399
  resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])