mgbam commited on
Commit
bf46e9b
·
verified ·
1 Parent(s): caa98a8

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +7 -8
agent.py CHANGED
@@ -43,8 +43,8 @@ class ClinicalPrompts:
43
  def wrap_message(msg: Any) -> AIMessage:
44
  """
45
  Ensures the given message is an AIMessage.
46
- If it is a dict, extracts the 'content' field (or serializes the dict).
47
- Otherwise, converts the message to a string.
48
  """
49
  if isinstance(msg, AIMessage):
50
  return msg
@@ -358,7 +358,6 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
358
  new_state = {"messages": []}
359
  return propagate_state(new_state, state)
360
  last = wrap_message(messages_list[-1])
361
- # Safely retrieve pending tool_calls
362
  tool_calls = last.__dict__.get("tool_calls")
363
  if not (isinstance(last, AIMessage) and tool_calls):
364
  logger.warning("tool_node invoked without pending tool_calls")
@@ -464,7 +463,7 @@ def should_continue(state: AgentState) -> str:
464
  state["done"] = True
465
  return "end_conversation_turn"
466
  state["done"] = False
467
- return "agent"
468
 
469
  def after_tools_router(state: AgentState) -> str:
470
  if state.get("interaction_warnings"):
@@ -476,11 +475,11 @@ class ClinicalAgent:
476
  def __init__(self):
477
  logger.info("Building ClinicalAgent workflow")
478
  wf = StateGraph(AgentState)
479
- wf.add_node("agent", agent_node)
480
  wf.add_node("tools", tool_node)
481
  wf.add_node("reflection", reflection_node)
482
- wf.set_entry_point("agent")
483
- wf.add_conditional_edges("agent", should_continue, {
484
  "continue_tools": "tools",
485
  "end_conversation_turn": END
486
  })
@@ -488,7 +487,7 @@ class ClinicalAgent:
488
  "reflection": "reflection",
489
  "end_conversation_turn": END
490
  })
491
- # Removed edge from reflection back to agent.
492
  self.graph_app = wf.compile()
493
  logger.info("ClinicalAgent ready")
494
 
 
43
  def wrap_message(msg: Any) -> AIMessage:
44
  """
45
  Ensures the given message is an AIMessage.
46
+ If it is a dict, it extracts the 'content' field (or serializes the dict).
47
+ Otherwise, it converts the message to a string.
48
  """
49
  if isinstance(msg, AIMessage):
50
  return msg
 
358
  new_state = {"messages": []}
359
  return propagate_state(new_state, state)
360
  last = wrap_message(messages_list[-1])
 
361
  tool_calls = last.__dict__.get("tool_calls")
362
  if not (isinstance(last, AIMessage) and tool_calls):
363
  logger.warning("tool_node invoked without pending tool_calls")
 
463
  state["done"] = True
464
  return "end_conversation_turn"
465
  state["done"] = False
466
+ return "start"
467
 
468
  def after_tools_router(state: AgentState) -> str:
469
  if state.get("interaction_warnings"):
 
475
  def __init__(self):
476
  logger.info("Building ClinicalAgent workflow")
477
  wf = StateGraph(AgentState)
478
+ wf.add_node("start", agent_node)
479
  wf.add_node("tools", tool_node)
480
  wf.add_node("reflection", reflection_node)
481
+ wf.set_entry_point("start")
482
+ wf.add_conditional_edges("start", should_continue, {
483
  "continue_tools": "tools",
484
  "end_conversation_turn": END
485
  })
 
487
  "reflection": "reflection",
488
  "end_conversation_turn": END
489
  })
490
+ # Removed edge from reflection back to start.
491
  self.graph_app = wf.compile()
492
  logger.info("ClinicalAgent ready")
493