Update agent.py
Browse files
agent.py
CHANGED
@@ -198,8 +198,9 @@ class FlagRiskInput(BaseModel):
|
|
198 |
urgency: str = Field("High")
|
199 |
|
200 |
# ββ Tool Implementations ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
201 |
-
@tool("order_lab_test", args_schema=LabOrderInput)
|
202 |
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
|
|
|
203 |
logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
|
204 |
return json.dumps({
|
205 |
"status": "success",
|
@@ -207,7 +208,7 @@ def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> st
|
|
207 |
"details": f"Reason: {reason}"
|
208 |
})
|
209 |
|
210 |
-
@tool("prescribe_medication", args_schema=PrescriptionInput)
|
211 |
def prescribe_medication(
|
212 |
medication_name: str,
|
213 |
dosage: str,
|
@@ -216,6 +217,7 @@ def prescribe_medication(
|
|
216 |
duration: str,
|
217 |
reason: str
|
218 |
) -> str:
|
|
|
219 |
logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
|
220 |
return json.dumps({
|
221 |
"status": "success",
|
@@ -223,12 +225,13 @@ def prescribe_medication(
|
|
223 |
"details": f"Duration: {duration}. Reason: {reason}"
|
224 |
})
|
225 |
|
226 |
-
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
|
227 |
def check_drug_interactions(
|
228 |
potential_prescription: str,
|
229 |
current_medications: Optional[List[str]] = None,
|
230 |
allergies: Optional[List[str]] = None
|
231 |
) -> str:
|
|
|
232 |
logger.info(f"Checking interactions for: {potential_prescription}")
|
233 |
warnings: List[str] = []
|
234 |
pm = [m.lower().strip() for m in (current_medications or []) if m]
|
@@ -263,8 +266,9 @@ def check_drug_interactions(
|
|
263 |
)
|
264 |
return json.dumps({"status": status, "message": message, "warnings": warnings})
|
265 |
|
266 |
-
@tool("flag_risk", args_schema=FlagRiskInput)
|
267 |
def flag_risk(risk_description: str, urgency: str = "High") -> str:
|
|
|
268 |
logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
|
269 |
return json.dumps({
|
270 |
"status": "flagged",
|
@@ -378,12 +382,10 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
|
|
378 |
def should_continue(state: AgentState) -> str:
|
379 |
last = state["messages"][-1] if state["messages"] else None
|
380 |
if not isinstance(last, AIMessage):
|
381 |
-
# Mark conversation as done
|
382 |
state["done"] = True
|
383 |
return "end_conversation_turn"
|
384 |
if getattr(last, "tool_calls", None):
|
385 |
return "continue_tools"
|
386 |
-
# No further tool calls β conversation is finished.
|
387 |
state["done"] = True
|
388 |
return "end_conversation_turn"
|
389 |
|
@@ -408,7 +410,6 @@ class ClinicalAgent:
|
|
408 |
"agent": "agent"
|
409 |
})
|
410 |
wf.add_edge("reflection", "agent")
|
411 |
-
# Set termination condition: stop when state["done"] is True.
|
412 |
wf.set_termination_condition(lambda state: state.get("done", False))
|
413 |
self.graph_app = wf.compile()
|
414 |
logger.info("ClinicalAgent ready")
|
|
|
198 |
urgency: str = Field("High")
|
199 |
|
200 |
# ββ Tool Implementations ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
201 |
+
@tool("order_lab_test", args_schema=LabOrderInput, description="Place an order for a laboratory test.")
|
202 |
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
|
203 |
+
"""Place an order for a laboratory test."""
|
204 |
logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
|
205 |
return json.dumps({
|
206 |
"status": "success",
|
|
|
208 |
"details": f"Reason: {reason}"
|
209 |
})
|
210 |
|
211 |
+
@tool("prescribe_medication", args_schema=PrescriptionInput, description="Prepare a medication prescription.")
|
212 |
def prescribe_medication(
|
213 |
medication_name: str,
|
214 |
dosage: str,
|
|
|
217 |
duration: str,
|
218 |
reason: str
|
219 |
) -> str:
|
220 |
+
"""Prepare a medication prescription."""
|
221 |
logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
|
222 |
return json.dumps({
|
223 |
"status": "success",
|
|
|
225 |
"details": f"Duration: {duration}. Reason: {reason}"
|
226 |
})
|
227 |
|
228 |
+
@tool("check_drug_interactions", args_schema=InteractionCheckInput, description="Check for drugβdrug interactions and allergy risks.")
|
229 |
def check_drug_interactions(
|
230 |
potential_prescription: str,
|
231 |
current_medications: Optional[List[str]] = None,
|
232 |
allergies: Optional[List[str]] = None
|
233 |
) -> str:
|
234 |
+
"""Check for drugβdrug interactions and allergy risks."""
|
235 |
logger.info(f"Checking interactions for: {potential_prescription}")
|
236 |
warnings: List[str] = []
|
237 |
pm = [m.lower().strip() for m in (current_medications or []) if m]
|
|
|
266 |
)
|
267 |
return json.dumps({"status": status, "message": message, "warnings": warnings})
|
268 |
|
269 |
+
@tool("flag_risk", args_schema=FlagRiskInput, description="Flag a clinical risk with given urgency.")
|
270 |
def flag_risk(risk_description: str, urgency: str = "High") -> str:
|
271 |
+
"""Flag a clinical risk with given urgency."""
|
272 |
logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
|
273 |
return json.dumps({
|
274 |
"status": "flagged",
|
|
|
382 |
def should_continue(state: AgentState) -> str:
|
383 |
last = state["messages"][-1] if state["messages"] else None
|
384 |
if not isinstance(last, AIMessage):
|
|
|
385 |
state["done"] = True
|
386 |
return "end_conversation_turn"
|
387 |
if getattr(last, "tool_calls", None):
|
388 |
return "continue_tools"
|
|
|
389 |
state["done"] = True
|
390 |
return "end_conversation_turn"
|
391 |
|
|
|
410 |
"agent": "agent"
|
411 |
})
|
412 |
wf.add_edge("reflection", "agent")
|
|
|
413 |
wf.set_termination_condition(lambda state: state.get("done", False))
|
414 |
self.graph_app = wf.compile()
|
415 |
logger.info("ClinicalAgent ready")
|