Update agent.py
agent.py (CHANGED)
@@ -39,6 +39,20 @@ class ClinicalPrompts:
     [SYSTEM PROMPT CONTENT HERE]
     """
 
+# ── Helper: Message Wrapper ──────────────────────────────────────────────
+def wrap_message(msg: Any) -> AIMessage:
+    """
+    Ensures that the given message is an AIMessage.
+    If it is a dict, it extracts the 'content' field (or serializes the dict).
+    Otherwise, it converts the message to a string.
+    """
+    if isinstance(msg, AIMessage):
+        return msg
+    elif isinstance(msg, dict):
+        return AIMessage(content=msg.get("content", json.dumps(msg)))
+    else:
+        return AIMessage(content=str(msg))
+
 # ── Helper Functions ─────────────────────────────────────────────────────
 UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
 RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
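
A note on the new helper for readers skimming the diff: wrap_message normalizes whatever the LLM layer hands back (an AIMessage, a raw dict, or anything else) into an AIMessage, so every downstream node can rely on a single message type. The snippet below is a minimal, illustrative sketch of that contract, not code from this repo; it assumes AIMessage is available from langchain_core, which may differ from the import path agent.py actually uses.

```python
# Illustrative sketch only: mirrors the wrap_message contract added above.
# Assumes langchain_core is installed; adjust the import to match agent.py.
import json
from langchain_core.messages import AIMessage

def wrap_message(msg):
    if isinstance(msg, AIMessage):
        return msg
    if isinstance(msg, dict):
        return AIMessage(content=msg.get("content", json.dumps(msg)))
    return AIMessage(content=str(msg))

print(wrap_message(AIMessage(content="hi")).content)   # "hi" (passed through unchanged)
print(wrap_message({"content": "from dict"}).content)  # "from dict"
print(wrap_message({"foo": 1}).content)                # '{"foo": 1}' (no 'content' key, so JSON-dumped)
print(wrap_message(42).content)                        # "42" (fallback: str())
```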
@@ -318,22 +332,30 @@ def agent_node(state: AgentState) -> Dict[str, Any]:
     logger.info(f"Invoking LLM with {len(msgs)} messages")
     try:
         response = model_with_tools.invoke(msgs)
+        response = wrap_message(response)
         new_state = {"messages": [response]}
         return propagate_state(new_state, state)
     except Exception as e:
         logger.exception("Error in agent_node")
-        new_state = {"messages": [AIMessage(content=f"Error: {e}")]}
+        new_state = {"messages": [wrap_message(AIMessage(content=f"Error: {e}"))]}
         return propagate_state(new_state, state)
 
 def tool_node(state: AgentState) -> Dict[str, Any]:
     if state.get("done", False):
         return state
-    last = state["messages"][-1]
-    if not (isinstance(last, AIMessage) and last.tool_calls):
+    messages_list = state.get("messages", [])
+    if not messages_list:
+        logger.warning("tool_node invoked with no messages")
+        new_state = {"messages": []}
+        return propagate_state(new_state, state)
+    last = wrap_message(messages_list[-1])
+    # Check for pending tool_calls using dict.get if necessary.
+    tool_calls = last.tool_calls if hasattr(last, "tool_calls") else last.__dict__.get("tool_calls")
+    if not (isinstance(last, AIMessage) and tool_calls):
         logger.warning("tool_node invoked without pending tool_calls")
         new_state = {"messages": []}
         return propagate_state(new_state, state)
-    calls = last.tool_calls
+    calls = tool_calls
     blocked_ids = set()
     for call in calls:
         if call["name"] == "prescribe_medication":
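
Context for the hasattr / __dict__ fallback in the new tool_calls check: on current langchain-core releases, tool_calls is an ordinary list attribute of AIMessage (empty when the model made no tool calls), so hasattr alone normally suffices; the __dict__ lookup only matters for dict-like or legacy message objects that wrap_message could not fully normalize. A small self-contained sketch, with made-up tool arguments and a hypothetical call id, assuming langchain_core:

```python
# Illustrative sketch only. The tool call payload below is invented for the example.
from langchain_core.messages import AIMessage

msg = AIMessage(
    content="",
    tool_calls=[{"name": "prescribe_medication", "args": {"drug": "example"}, "id": "call_1"}],
)

# Same extraction pattern as the diff: prefer the attribute, fall back to __dict__.
calls = msg.tool_calls if hasattr(msg, "tool_calls") else msg.__dict__.get("tool_calls")
for call in calls or []:
    if call["name"] == "prescribe_medication":
        print("prescription tool call pending:", call["id"])
```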
@@ -387,8 +409,9 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
         return propagate_state(new_state, state)
     triggering = None
     for msg in reversed(state.get("messages", [])):
-        if isinstance(msg, AIMessage) and msg.tool_calls:
-            triggering = msg
+        wrapped = wrap_message(msg)
+        if isinstance(wrapped, AIMessage) and wrapped.__dict__.get("tool_calls"):
+            triggering = wrapped
             break
     if not triggering:
         new_state = {"messages": [AIMessage(content="Internal Error: reflection context missing.")]}
@@ -401,7 +424,7 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
     )
     try:
         resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
-        new_state = {"messages": [resp]}
+        new_state = {"messages": [wrap_message(resp)]}
         return propagate_state(new_state, state)
     except Exception as e:
         logger.exception("Error during reflection")
@@ -413,7 +436,6 @@ def should_continue(state: AgentState) -> str:
     state.setdefault("iterations", 0)
     state["iterations"] += 1
     logger.info(f"Iteration count: {state['iterations']}")
-    # When iterations exceed threshold, force final output and terminate.
     if state["iterations"] >= 4:
         state.setdefault("messages", []).append(AIMessage(content="Final output: consultation complete."))
         state["done"] = True
@@ -421,11 +443,11 @@ def should_continue(state: AgentState) -> str:
     if not state.get("messages"):
         state["done"] = True
         return "end_conversation_turn"
-    last = state["messages"][-1]
+    last = wrap_message(state["messages"][-1])
     if not isinstance(last, AIMessage):
         state["done"] = True
         return "end_conversation_turn"
-    if last.tool_calls:
+    if last.__dict__.get("tool_calls"):
         return "continue_tools"
     if "consultation complete" in last.content.lower():
         state["done"] = True
@@ -434,7 +456,6 @@ def should_continue(state: AgentState) -> str:
     return "agent"
 
 def after_tools_router(state: AgentState) -> str:
-    # Instead of routing back to agent, route reflection to END to break the cycle.
     if state.get("interaction_warnings"):
         return "reflection"
     return "end_conversation_turn"
@@ -456,13 +477,12 @@ class ClinicalAgent:
             "reflection": "reflection",
             "end_conversation_turn": END
         })
-        # Removed
+        # Removed edge from reflection to agent to break cycle.
        self.graph_app = wf.compile()
         logger.info("ClinicalAgent ready")
 
     def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
         try:
-            # Increase recursion limit if needed.
             result = self.graph_app.invoke(state, {"recursion_limit": 100})
             result.setdefault("summary", state.get("summary"))
             result.setdefault("interaction_warnings", None)
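
For orientation, here is roughly how the routers above appear to plug into the LangGraph workflow. Only the "reflection" / "end_conversation_turn" mapping, the wf.compile() call, and the absence of a reflection-to-agent edge are visible in this diff; the node names "agent" and "tools", the entry point, and every other edge in the sketch are assumptions, and in the repo this wiring lives inside ClinicalAgent.__init__ rather than at module level.

```python
# Rough sketch of the wiring implied by should_continue / after_tools_router.
# Node names "agent" and "tools", and any edge not shown in the diff, are guesses.
from langgraph.graph import StateGraph, END

wf = StateGraph(AgentState)                 # AgentState as defined in agent.py
wf.add_node("agent", agent_node)
wf.add_node("tools", tool_node)
wf.add_node("reflection", reflection_node)
wf.set_entry_point("agent")

wf.add_conditional_edges("agent", should_continue, {
    "continue_tools": "tools",
    "agent": "agent",                       # take another reasoning step
    "end_conversation_turn": END,
})
wf.add_conditional_edges("tools", after_tools_router, {
    "reflection": "reflection",
    "end_conversation_turn": END,
})
# Matches the comment in the diff: reflection no longer routes back to agent,
# so the reflect/act loop cannot recurse indefinitely.
wf.add_edge("reflection", END)

graph_app = wf.compile()
```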