Update agent.py
Browse files
agent.py
CHANGED
@@ -43,8 +43,8 @@ class ClinicalPrompts:
|
|
43 |
def wrap_message(msg: Any) -> AIMessage:
|
44 |
"""
|
45 |
Ensures the given message is an AIMessage.
|
46 |
-
If it is a dict, extracts the 'content' field (or serializes the dict).
|
47 |
-
Otherwise, converts the message to a string.
|
48 |
"""
|
49 |
if isinstance(msg, AIMessage):
|
50 |
return msg
|
@@ -358,7 +358,6 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
|
|
358 |
new_state = {"messages": []}
|
359 |
return propagate_state(new_state, state)
|
360 |
last = wrap_message(messages_list[-1])
|
361 |
-
# Safely retrieve pending tool_calls
|
362 |
tool_calls = last.__dict__.get("tool_calls")
|
363 |
if not (isinstance(last, AIMessage) and tool_calls):
|
364 |
logger.warning("tool_node invoked without pending tool_calls")
|
@@ -464,7 +463,7 @@ def should_continue(state: AgentState) -> str:
|
|
464 |
state["done"] = True
|
465 |
return "end_conversation_turn"
|
466 |
state["done"] = False
|
467 |
-
return "
|
468 |
|
469 |
def after_tools_router(state: AgentState) -> str:
|
470 |
if state.get("interaction_warnings"):
|
@@ -476,11 +475,11 @@ class ClinicalAgent:
|
|
476 |
def __init__(self):
|
477 |
logger.info("Building ClinicalAgent workflow")
|
478 |
wf = StateGraph(AgentState)
|
479 |
-
wf.add_node("
|
480 |
wf.add_node("tools", tool_node)
|
481 |
wf.add_node("reflection", reflection_node)
|
482 |
-
wf.set_entry_point("
|
483 |
-
wf.add_conditional_edges("
|
484 |
"continue_tools": "tools",
|
485 |
"end_conversation_turn": END
|
486 |
})
|
@@ -488,7 +487,7 @@ class ClinicalAgent:
|
|
488 |
"reflection": "reflection",
|
489 |
"end_conversation_turn": END
|
490 |
})
|
491 |
-
# Removed edge from reflection back to
|
492 |
self.graph_app = wf.compile()
|
493 |
logger.info("ClinicalAgent ready")
|
494 |
|
|
|
43 |
def wrap_message(msg: Any) -> AIMessage:
|
44 |
"""
|
45 |
Ensures the given message is an AIMessage.
|
46 |
+
If it is a dict, it extracts the 'content' field (or serializes the dict).
|
47 |
+
Otherwise, it converts the message to a string.
|
48 |
"""
|
49 |
if isinstance(msg, AIMessage):
|
50 |
return msg
|
|
|
358 |
new_state = {"messages": []}
|
359 |
return propagate_state(new_state, state)
|
360 |
last = wrap_message(messages_list[-1])
|
|
|
361 |
tool_calls = last.__dict__.get("tool_calls")
|
362 |
if not (isinstance(last, AIMessage) and tool_calls):
|
363 |
logger.warning("tool_node invoked without pending tool_calls")
|
|
|
463 |
state["done"] = True
|
464 |
return "end_conversation_turn"
|
465 |
state["done"] = False
|
466 |
+
return "start"
|
467 |
|
468 |
def after_tools_router(state: AgentState) -> str:
|
469 |
if state.get("interaction_warnings"):
|
|
|
475 |
def __init__(self):
|
476 |
logger.info("Building ClinicalAgent workflow")
|
477 |
wf = StateGraph(AgentState)
|
478 |
+
wf.add_node("start", agent_node)
|
479 |
wf.add_node("tools", tool_node)
|
480 |
wf.add_node("reflection", reflection_node)
|
481 |
+
wf.set_entry_point("start")
|
482 |
+
wf.add_conditional_edges("start", should_continue, {
|
483 |
"continue_tools": "tools",
|
484 |
"end_conversation_turn": END
|
485 |
})
|
|
|
487 |
"reflection": "reflection",
|
488 |
"end_conversation_turn": END
|
489 |
})
|
490 |
+
# Removed edge from reflection back to start.
|
491 |
self.graph_app = wf.compile()
|
492 |
logger.info("ClinicalAgent ready")
|
493 |
|