mgbam commited on
Commit
b5a480c
Β·
verified Β·
1 Parent(s): 53eb85a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +7 -29
agent.py CHANGED
@@ -1,5 +1,3 @@
1
- # agent.py
2
-
3
  import os
4
  import re
5
  import json
@@ -54,7 +52,6 @@ def get_rxcui(drug_name: str) -> Optional[str]:
54
  return None
55
  logger.info(f"Looking up RxCUI for '{drug_name}'")
56
  try:
57
- # First attempt
58
  params = {"name": drug_name, "search": 1}
59
  r = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
60
  r.raise_for_status()
@@ -62,7 +59,6 @@ def get_rxcui(drug_name: str) -> Optional[str]:
62
  if ids:
63
  logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
64
  return ids[0]
65
- # Fallback search
66
  r = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
67
  r.raise_for_status()
68
  for grp in r.json().get("drugGroup", {}).get("conceptGroup", []):
@@ -126,8 +122,6 @@ def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
126
  hpi = patient_data.get("hpi", {})
127
  vitals = patient_data.get("vitals", {})
128
  syms = [s.lower() for s in hpi.get("symptoms", []) if isinstance(s, str)]
129
-
130
- # Symptom-based flags
131
  mapping = {
132
  "chest pain": "Chest pain reported",
133
  "shortness of breath": "Shortness of breath reported",
@@ -138,14 +132,11 @@ def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
138
  for term, desc in mapping.items():
139
  if term in syms:
140
  flags.append(f"Red Flag: {desc}.")
141
-
142
- # Vitals-based flags
143
  temp = vitals.get("temp_c")
144
  hr = vitals.get("hr_bpm")
145
  rr = vitals.get("rr_rpm")
146
  spo2 = vitals.get("spo2_percent")
147
  bp = parse_bp(vitals.get("bp_mmhg", ""))
148
-
149
  if temp is not None and temp >= 38.5:
150
  flags.append(f"Red Flag: Fever ({temp}Β°C).")
151
  if hr is not None:
@@ -163,7 +154,6 @@ def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
163
  flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
164
  if sys <= 90 or dia <= 60:
165
  flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
166
-
167
  return list(dict.fromkeys(flags)) # dedupe, preserve order
168
 
169
  def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
@@ -210,7 +200,6 @@ class FlagRiskInput(BaseModel):
210
  # ── Tool Implementations ──────────────────────────────────────────────────────
211
  @tool("order_lab_test", args_schema=LabOrderInput)
212
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
213
- """Place an order for a laboratory test."""
214
  logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
215
  return json.dumps({
216
  "status": "success",
@@ -227,7 +216,6 @@ def prescribe_medication(
227
  duration: str,
228
  reason: str
229
  ) -> str:
230
- """Prepare a medication prescription."""
231
  logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
232
  return json.dumps({
233
  "status": "success",
@@ -241,31 +229,22 @@ def check_drug_interactions(
241
  current_medications: Optional[List[str]] = None,
242
  allergies: Optional[List[str]] = None
243
  ) -> str:
244
- """Check for drug–drug interactions and allergy risks."""
245
  logger.info(f"Checking interactions for: {potential_prescription}")
246
  warnings: List[str] = []
247
  pm = [m.lower().strip() for m in (current_medications or []) if m]
248
  al = [a.lower().strip() for a in (allergies or []) if a]
249
-
250
- # Allergy exact match
251
  if potential_prescription.lower().strip() in al:
252
  warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")
253
-
254
- # Identify drug via RxNorm/OpenFDA
255
  rxcui = get_rxcui(potential_prescription)
256
  label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
257
  if not (rxcui or label):
258
  warnings.append(f"INFO: Could not identify '{potential_prescription}'. Checks may be incomplete.")
259
-
260
- # Contraindications & warnings sections
261
  for section in ("contraindications", "warnings_and_cautions", "warnings"):
262
  items = label.get(section) if label else None
263
  if isinstance(items, list):
264
  snippets = search_text_list(items, al)
265
  if snippets:
266
  warnings.append(f"ALLERGY RISK ({section}): {'; '.join(snippets)}")
267
-
268
- # Drug–drug interactions
269
  for med in pm:
270
  mrxcui = get_rxcui(med)
271
  mlabel = get_openfda_label(rxcui=mrxcui, drug_name=med)
@@ -276,7 +255,6 @@ def check_drug_interactions(
276
  snippets = search_text_list(items, [med if src_name == potential_prescription else potential_prescription])
277
  if snippets:
278
  warnings.append(f"Interaction ({src_name} label): {'; '.join(snippets)}")
279
-
280
  status = "warning" if warnings else "clear"
281
  message = (
282
  f"{len(warnings)} issue(s) found for '{potential_prescription}'."
@@ -287,7 +265,6 @@ def check_drug_interactions(
287
 
288
  @tool("flag_risk", args_schema=FlagRiskInput)
289
  def flag_risk(risk_description: str, urgency: str = "High") -> str:
290
- """Flag a clinical risk with given urgency."""
291
  logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
292
  return json.dumps({
293
  "status": "flagged",
@@ -309,6 +286,7 @@ class AgentState(TypedDict):
309
  patient_data: Optional[Dict[str, Any]]
310
  summary: Optional[str]
311
  interaction_warnings: Optional[List[str]]
 
312
 
313
  # ── Graph Nodes ───────────────────────────────────────────────────────────────
314
  def agent_node(state: AgentState) -> Dict[str, Any]:
@@ -328,7 +306,6 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
328
  if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
329
  logger.warning("tool_node invoked without pending tool_calls")
330
  return {"messages": [], "interaction_warnings": None}
331
-
332
  calls = last.tool_calls
333
  blocked_ids = set()
334
  for call in calls:
@@ -341,14 +318,12 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
341
  ):
342
  logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
343
  blocked_ids.add(call["id"])
344
-
345
  to_execute = [c for c in calls if c["id"] not in blocked_ids]
346
  pd = state.get("patient_data", {})
347
  for call in to_execute:
348
  if call["name"] == "check_drug_interactions":
349
  call["args"].setdefault("current_medications", pd.get("medications", {}).get("current", []))
350
  call["args"].setdefault("allergies", pd.get("allergies", []))
351
-
352
  messages: List[ToolMessage] = []
353
  warnings: List[str] = []
354
  try:
@@ -379,16 +354,13 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
379
  if not warns:
380
  logger.warning("reflection_node called without warnings")
381
  return {"messages": [], "interaction_warnings": None}
382
-
383
  triggering = None
384
  for msg in reversed(state["messages"]):
385
  if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
386
  triggering = msg
387
  break
388
-
389
  if not triggering:
390
  return {"messages": [AIMessage(content="Internal Error: reflection context missing.")], "interaction_warnings": None}
391
-
392
  prompt = (
393
  "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
394
  f"{triggering.content}\n\n"
@@ -406,9 +378,13 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
406
  def should_continue(state: AgentState) -> str:
407
  last = state["messages"][-1] if state["messages"] else None
408
  if not isinstance(last, AIMessage):
 
 
409
  return "end_conversation_turn"
410
  if getattr(last, "tool_calls", None):
411
  return "continue_tools"
 
 
412
  return "end_conversation_turn"
413
 
414
  def after_tools_router(state: AgentState) -> str:
@@ -432,6 +408,8 @@ class ClinicalAgent:
432
  "agent": "agent"
433
  })
434
  wf.add_edge("reflection", "agent")
 
 
435
  self.graph_app = wf.compile()
436
  logger.info("ClinicalAgent ready")
437
 
 
 
 
1
  import os
2
  import re
3
  import json
 
52
  return None
53
  logger.info(f"Looking up RxCUI for '{drug_name}'")
54
  try:
 
55
  params = {"name": drug_name, "search": 1}
56
  r = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
57
  r.raise_for_status()
 
59
  if ids:
60
  logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
61
  return ids[0]
 
62
  r = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
63
  r.raise_for_status()
64
  for grp in r.json().get("drugGroup", {}).get("conceptGroup", []):
 
122
  hpi = patient_data.get("hpi", {})
123
  vitals = patient_data.get("vitals", {})
124
  syms = [s.lower() for s in hpi.get("symptoms", []) if isinstance(s, str)]
 
 
125
  mapping = {
126
  "chest pain": "Chest pain reported",
127
  "shortness of breath": "Shortness of breath reported",
 
132
  for term, desc in mapping.items():
133
  if term in syms:
134
  flags.append(f"Red Flag: {desc}.")
 
 
135
  temp = vitals.get("temp_c")
136
  hr = vitals.get("hr_bpm")
137
  rr = vitals.get("rr_rpm")
138
  spo2 = vitals.get("spo2_percent")
139
  bp = parse_bp(vitals.get("bp_mmhg", ""))
 
140
  if temp is not None and temp >= 38.5:
141
  flags.append(f"Red Flag: Fever ({temp}Β°C).")
142
  if hr is not None:
 
154
  flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
155
  if sys <= 90 or dia <= 60:
156
  flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
 
157
  return list(dict.fromkeys(flags)) # dedupe, preserve order
158
 
159
  def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
 
200
  # ── Tool Implementations ──────────────────────────────────────────────────────
201
  @tool("order_lab_test", args_schema=LabOrderInput)
202
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
 
203
  logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
204
  return json.dumps({
205
  "status": "success",
 
216
  duration: str,
217
  reason: str
218
  ) -> str:
 
219
  logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
220
  return json.dumps({
221
  "status": "success",
 
229
  current_medications: Optional[List[str]] = None,
230
  allergies: Optional[List[str]] = None
231
  ) -> str:
 
232
  logger.info(f"Checking interactions for: {potential_prescription}")
233
  warnings: List[str] = []
234
  pm = [m.lower().strip() for m in (current_medications or []) if m]
235
  al = [a.lower().strip() for a in (allergies or []) if a]
 
 
236
  if potential_prescription.lower().strip() in al:
237
  warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")
 
 
238
  rxcui = get_rxcui(potential_prescription)
239
  label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
240
  if not (rxcui or label):
241
  warnings.append(f"INFO: Could not identify '{potential_prescription}'. Checks may be incomplete.")
 
 
242
  for section in ("contraindications", "warnings_and_cautions", "warnings"):
243
  items = label.get(section) if label else None
244
  if isinstance(items, list):
245
  snippets = search_text_list(items, al)
246
  if snippets:
247
  warnings.append(f"ALLERGY RISK ({section}): {'; '.join(snippets)}")
 
 
248
  for med in pm:
249
  mrxcui = get_rxcui(med)
250
  mlabel = get_openfda_label(rxcui=mrxcui, drug_name=med)
 
255
  snippets = search_text_list(items, [med if src_name == potential_prescription else potential_prescription])
256
  if snippets:
257
  warnings.append(f"Interaction ({src_name} label): {'; '.join(snippets)}")
 
258
  status = "warning" if warnings else "clear"
259
  message = (
260
  f"{len(warnings)} issue(s) found for '{potential_prescription}'."
 
265
 
266
  @tool("flag_risk", args_schema=FlagRiskInput)
267
  def flag_risk(risk_description: str, urgency: str = "High") -> str:
 
268
  logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
269
  return json.dumps({
270
  "status": "flagged",
 
286
  patient_data: Optional[Dict[str, Any]]
287
  summary: Optional[str]
288
  interaction_warnings: Optional[List[str]]
289
+ done: Optional[bool]
290
 
291
  # ── Graph Nodes ───────────────────────────────────────────────────────────────
292
  def agent_node(state: AgentState) -> Dict[str, Any]:
 
306
  if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
307
  logger.warning("tool_node invoked without pending tool_calls")
308
  return {"messages": [], "interaction_warnings": None}
 
309
  calls = last.tool_calls
310
  blocked_ids = set()
311
  for call in calls:
 
318
  ):
319
  logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
320
  blocked_ids.add(call["id"])
 
321
  to_execute = [c for c in calls if c["id"] not in blocked_ids]
322
  pd = state.get("patient_data", {})
323
  for call in to_execute:
324
  if call["name"] == "check_drug_interactions":
325
  call["args"].setdefault("current_medications", pd.get("medications", {}).get("current", []))
326
  call["args"].setdefault("allergies", pd.get("allergies", []))
 
327
  messages: List[ToolMessage] = []
328
  warnings: List[str] = []
329
  try:
 
354
  if not warns:
355
  logger.warning("reflection_node called without warnings")
356
  return {"messages": [], "interaction_warnings": None}
 
357
  triggering = None
358
  for msg in reversed(state["messages"]):
359
  if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
360
  triggering = msg
361
  break
 
362
  if not triggering:
363
  return {"messages": [AIMessage(content="Internal Error: reflection context missing.")], "interaction_warnings": None}
 
364
  prompt = (
365
  "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
366
  f"{triggering.content}\n\n"
 
378
  def should_continue(state: AgentState) -> str:
379
  last = state["messages"][-1] if state["messages"] else None
380
  if not isinstance(last, AIMessage):
381
+ # Mark conversation as done
382
+ state["done"] = True
383
  return "end_conversation_turn"
384
  if getattr(last, "tool_calls", None):
385
  return "continue_tools"
386
+ # No further tool calls – conversation is finished.
387
+ state["done"] = True
388
  return "end_conversation_turn"
389
 
390
  def after_tools_router(state: AgentState) -> str:
 
408
  "agent": "agent"
409
  })
410
  wf.add_edge("reflection", "agent")
411
+ # Set termination condition: stop when state["done"] is True.
412
+ wf.set_termination_condition(lambda state: state.get("done", False))
413
  self.graph_app = wf.compile()
414
  logger.info("ClinicalAgent ready")
415