mgbam commited on
Commit
30c063c
Β·
verified Β·
1 Parent(s): b5a480c

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +8 -7
agent.py CHANGED
@@ -198,8 +198,9 @@ class FlagRiskInput(BaseModel):
198
  urgency: str = Field("High")
199
 
200
  # ── Tool Implementations ──────────────────────────────────────────────────────
201
- @tool("order_lab_test", args_schema=LabOrderInput)
202
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
 
203
  logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
204
  return json.dumps({
205
  "status": "success",
@@ -207,7 +208,7 @@ def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> st
207
  "details": f"Reason: {reason}"
208
  })
209
 
210
- @tool("prescribe_medication", args_schema=PrescriptionInput)
211
  def prescribe_medication(
212
  medication_name: str,
213
  dosage: str,
@@ -216,6 +217,7 @@ def prescribe_medication(
216
  duration: str,
217
  reason: str
218
  ) -> str:
 
219
  logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
220
  return json.dumps({
221
  "status": "success",
@@ -223,12 +225,13 @@ def prescribe_medication(
223
  "details": f"Duration: {duration}. Reason: {reason}"
224
  })
225
 
226
- @tool("check_drug_interactions", args_schema=InteractionCheckInput)
227
  def check_drug_interactions(
228
  potential_prescription: str,
229
  current_medications: Optional[List[str]] = None,
230
  allergies: Optional[List[str]] = None
231
  ) -> str:
 
232
  logger.info(f"Checking interactions for: {potential_prescription}")
233
  warnings: List[str] = []
234
  pm = [m.lower().strip() for m in (current_medications or []) if m]
@@ -263,8 +266,9 @@ def check_drug_interactions(
263
  )
264
  return json.dumps({"status": status, "message": message, "warnings": warnings})
265
 
266
- @tool("flag_risk", args_schema=FlagRiskInput)
267
  def flag_risk(risk_description: str, urgency: str = "High") -> str:
 
268
  logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
269
  return json.dumps({
270
  "status": "flagged",
@@ -378,12 +382,10 @@ def reflection_node(state: AgentState) -> Dict[str, Any]:
378
  def should_continue(state: AgentState) -> str:
379
  last = state["messages"][-1] if state["messages"] else None
380
  if not isinstance(last, AIMessage):
381
- # Mark conversation as done
382
  state["done"] = True
383
  return "end_conversation_turn"
384
  if getattr(last, "tool_calls", None):
385
  return "continue_tools"
386
- # No further tool calls – conversation is finished.
387
  state["done"] = True
388
  return "end_conversation_turn"
389
 
@@ -408,7 +410,6 @@ class ClinicalAgent:
408
  "agent": "agent"
409
  })
410
  wf.add_edge("reflection", "agent")
411
- # Set termination condition: stop when state["done"] is True.
412
  wf.set_termination_condition(lambda state: state.get("done", False))
413
  self.graph_app = wf.compile()
414
  logger.info("ClinicalAgent ready")
 
198
  urgency: str = Field("High")
199
 
200
  # ── Tool Implementations ──────────────────────────────────────────────────────
201
+ @tool("order_lab_test", args_schema=LabOrderInput, description="Place an order for a laboratory test.")
202
  def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
203
+ """Place an order for a laboratory test."""
204
  logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
205
  return json.dumps({
206
  "status": "success",
 
208
  "details": f"Reason: {reason}"
209
  })
210
 
211
+ @tool("prescribe_medication", args_schema=PrescriptionInput, description="Prepare a medication prescription.")
212
  def prescribe_medication(
213
  medication_name: str,
214
  dosage: str,
 
217
  duration: str,
218
  reason: str
219
  ) -> str:
220
+ """Prepare a medication prescription."""
221
  logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
222
  return json.dumps({
223
  "status": "success",
 
225
  "details": f"Duration: {duration}. Reason: {reason}"
226
  })
227
 
228
+ @tool("check_drug_interactions", args_schema=InteractionCheckInput, description="Check for drug–drug interactions and allergy risks.")
229
  def check_drug_interactions(
230
  potential_prescription: str,
231
  current_medications: Optional[List[str]] = None,
232
  allergies: Optional[List[str]] = None
233
  ) -> str:
234
+ """Check for drug–drug interactions and allergy risks."""
235
  logger.info(f"Checking interactions for: {potential_prescription}")
236
  warnings: List[str] = []
237
  pm = [m.lower().strip() for m in (current_medications or []) if m]
 
266
  )
267
  return json.dumps({"status": status, "message": message, "warnings": warnings})
268
 
269
+ @tool("flag_risk", args_schema=FlagRiskInput, description="Flag a clinical risk with given urgency.")
270
  def flag_risk(risk_description: str, urgency: str = "High") -> str:
271
+ """Flag a clinical risk with given urgency."""
272
  logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
273
  return json.dumps({
274
  "status": "flagged",
 
382
  def should_continue(state: AgentState) -> str:
383
  last = state["messages"][-1] if state["messages"] else None
384
  if not isinstance(last, AIMessage):
 
385
  state["done"] = True
386
  return "end_conversation_turn"
387
  if getattr(last, "tool_calls", None):
388
  return "continue_tools"
 
389
  state["done"] = True
390
  return "end_conversation_turn"
391
 
 
410
  "agent": "agent"
411
  })
412
  wf.add_edge("reflection", "agent")
 
413
  wf.set_termination_condition(lambda state: state.get("done", False))
414
  self.graph_app = wf.compile()
415
  logger.info("ClinicalAgent ready")