mgbam commited on
Commit
caaaced
Β·
verified Β·
1 Parent(s): e22139a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +33 -13
agent.py CHANGED
@@ -39,6 +39,20 @@ class ClinicalPrompts:
39
  [SYSTEM PROMPT CONTENT HERE]
40
  """
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # ── Helper Functions ─────────────────────────────────────────────────────
43
  UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
44
  RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
@@ -318,22 +332,30 @@ def agent_node(state: AgentState) -> Dict[str, Any]:
318
  logger.info(f"Invoking LLM with {len(msgs)} messages")
319
  try:
320
  response = model_with_tools.invoke(msgs)
 
321
  new_state = {"messages": [response]}
322
  return propagate_state(new_state, state)
323
  except Exception as e:
324
  logger.exception("Error in agent_node")
325
- new_state = {"messages": [AIMessage(content=f"Error: {e}")]}
326
  return propagate_state(new_state, state)
327
 
328
  def tool_node(state: AgentState) -> Dict[str, Any]:
329
  if state.get("done", False):
330
  return state
331
- last = state.get("messages", [])[-1]
332
- if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
 
 
 
 
 
 
 
333
  logger.warning("tool_node invoked without pending tool_calls")
334
  new_state = {"messages": []}
335
  return propagate_state(new_state, state)
336
- calls = last.tool_calls
337
  blocked_ids = set()
338
  for call in calls:
339
  if call["name"] == "prescribe_medication":
@@ -387,8 +409,9 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
387
  return propagate_state(new_state, state)
388
  triggering = None
389
  for msg in reversed(state.get("messages", [])):
390
- if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
391
- triggering = msg
 
392
  break
393
  if not triggering:
394
  new_state = {"messages": [AIMessage(content="Internal Error: reflection context missing.")]}
@@ -401,7 +424,7 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
401
  )
402
  try:
403
  resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
404
- new_state = {"messages": [AIMessage(content=resp.content)]}
405
  return propagate_state(new_state, state)
406
  except Exception as e:
407
  logger.exception("Error during reflection")
@@ -413,7 +436,6 @@ def should_continue(state: AgentState) -> str:
413
  state.setdefault("iterations", 0)
414
  state["iterations"] += 1
415
  logger.info(f"Iteration count: {state['iterations']}")
416
- # When iterations exceed threshold, force final output and terminate.
417
  if state["iterations"] >= 4:
418
  state.setdefault("messages", []).append(AIMessage(content="Final output: consultation complete."))
419
  state["done"] = True
@@ -421,11 +443,11 @@ def should_continue(state: AgentState) -> str:
421
  if not state.get("messages"):
422
  state["done"] = True
423
  return "end_conversation_turn"
424
- last = state["messages"][-1]
425
  if not isinstance(last, AIMessage):
426
  state["done"] = True
427
  return "end_conversation_turn"
428
- if getattr(last, "tool_calls", None):
429
  return "continue_tools"
430
  if "consultation complete" in last.content.lower():
431
  state["done"] = True
@@ -434,7 +456,6 @@ def should_continue(state: AgentState) -> str:
434
  return "agent"
435
 
436
  def after_tools_router(state: AgentState) -> str:
437
- # Instead of routing back to agent, route reflection to END to break the cycle.
438
  if state.get("interaction_warnings"):
439
  return "reflection"
440
  return "end_conversation_turn"
@@ -456,13 +477,12 @@ class ClinicalAgent:
456
  "reflection": "reflection",
457
  "end_conversation_turn": END
458
  })
459
- # Removed the edge from reflection back to agent to break the cycle.
460
  self.graph_app = wf.compile()
461
  logger.info("ClinicalAgent ready")
462
 
463
  def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
464
  try:
465
- # Increase recursion limit if needed.
466
  result = self.graph_app.invoke(state, {"recursion_limit": 100})
467
  result.setdefault("summary", state.get("summary"))
468
  result.setdefault("interaction_warnings", None)
 
39
  [SYSTEM PROMPT CONTENT HERE]
40
  """
41
 
42
+ # ── Helper: Message Wrapper ─────────────────────────────────────────────
43
+ def wrap_message(msg: Any) -> AIMessage:
44
+ """
45
+ Ensures that the given message is an AIMessage.
46
+ If it is a dict, it extracts the 'content' field (or serializes the dict).
47
+ Otherwise, it converts the message to a string.
48
+ """
49
+ if isinstance(msg, AIMessage):
50
+ return msg
51
+ elif isinstance(msg, dict):
52
+ return AIMessage(content=msg.get("content", json.dumps(msg)))
53
+ else:
54
+ return AIMessage(content=str(msg))
55
+
56
  # ── Helper Functions ─────────────────────────────────────────────────────
57
  UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
58
  RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
 
332
  logger.info(f"Invoking LLM with {len(msgs)} messages")
333
  try:
334
  response = model_with_tools.invoke(msgs)
335
+ response = wrap_message(response)
336
  new_state = {"messages": [response]}
337
  return propagate_state(new_state, state)
338
  except Exception as e:
339
  logger.exception("Error in agent_node")
340
+ new_state = {"messages": [wrap_message(AIMessage(content=f"Error: {e}"))]}
341
  return propagate_state(new_state, state)
342
 
343
  def tool_node(state: AgentState) -> Dict[str, Any]:
344
  if state.get("done", False):
345
  return state
346
+ messages_list = state.get("messages", [])
347
+ if not messages_list:
348
+ logger.warning("tool_node invoked with no messages")
349
+ new_state = {"messages": []}
350
+ return propagate_state(new_state, state)
351
+ last = wrap_message(messages_list[-1])
352
+ # Check for pending tool_calls using dict.get if necessary.
353
+ tool_calls = last.tool_calls if hasattr(last, "tool_calls") else last.__dict__.get("tool_calls")
354
+ if not (isinstance(last, AIMessage) and tool_calls):
355
  logger.warning("tool_node invoked without pending tool_calls")
356
  new_state = {"messages": []}
357
  return propagate_state(new_state, state)
358
+ calls = tool_calls
359
  blocked_ids = set()
360
  for call in calls:
361
  if call["name"] == "prescribe_medication":
 
409
  return propagate_state(new_state, state)
410
  triggering = None
411
  for msg in reversed(state.get("messages", [])):
412
+ wrapped = wrap_message(msg)
413
+ if isinstance(wrapped, AIMessage) and wrapped.__dict__.get("tool_calls"):
414
+ triggering = wrapped
415
  break
416
  if not triggering:
417
  new_state = {"messages": [AIMessage(content="Internal Error: reflection context missing.")]}
 
424
  )
425
  try:
426
  resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
427
+ new_state = {"messages": [wrap_message(resp)]}
428
  return propagate_state(new_state, state)
429
  except Exception as e:
430
  logger.exception("Error during reflection")
 
436
  state.setdefault("iterations", 0)
437
  state["iterations"] += 1
438
  logger.info(f"Iteration count: {state['iterations']}")
 
439
  if state["iterations"] >= 4:
440
  state.setdefault("messages", []).append(AIMessage(content="Final output: consultation complete."))
441
  state["done"] = True
 
443
  if not state.get("messages"):
444
  state["done"] = True
445
  return "end_conversation_turn"
446
+ last = wrap_message(state["messages"][-1])
447
  if not isinstance(last, AIMessage):
448
  state["done"] = True
449
  return "end_conversation_turn"
450
+ if last.__dict__.get("tool_calls"):
451
  return "continue_tools"
452
  if "consultation complete" in last.content.lower():
453
  state["done"] = True
 
456
  return "agent"
457
 
458
  def after_tools_router(state: AgentState) -> str:
 
459
  if state.get("interaction_warnings"):
460
  return "reflection"
461
  return "end_conversation_turn"
 
477
  "reflection": "reflection",
478
  "end_conversation_turn": END
479
  })
480
+ # Removed edge from reflection to agent to break cycle.
481
  self.graph_app = wf.compile()
482
  logger.info("ClinicalAgent ready")
483
 
484
  def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
485
  try:
 
486
  result = self.graph_app.invoke(state, {"recursion_limit": 100})
487
  result.setdefault("summary", state.get("summary"))
488
  result.setdefault("interaction_warnings", None)