Update agent.py
Browse files
agent.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
-
# agent.py
|
2 |
-
import requests
|
3 |
-
import json
|
4 |
-
import re
|
5 |
import os
|
6 |
-
import
|
|
|
|
|
7 |
import traceback
|
|
|
8 |
from functools import lru_cache
|
|
|
9 |
|
10 |
from langchain_groq import ChatGroq
|
11 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
@@ -15,333 +15,490 @@ from langchain_core.tools import tool
|
|
15 |
from langgraph.prebuilt import ToolExecutor
|
16 |
from langgraph.graph import StateGraph, END
|
17 |
|
18 |
-
from typing import Optional, List, Dict, Any, TypedDict, Annotated
|
19 |
|
20 |
-
# ---
|
21 |
-
# Keys are primarily used here, but checked in app.py for UI feedback
|
22 |
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
|
23 |
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
|
24 |
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
|
25 |
|
26 |
-
# --- Configuration & Constants ---
|
27 |
AGENT_MODEL_NAME = "llama3-70b-8192"
|
28 |
AGENT_TEMPERATURE = 0.1
|
29 |
MAX_SEARCH_RESULTS = 3
|
30 |
|
31 |
-
class ClinicalPrompts:
|
32 |
-
# The comprehensive system prompt defining agent behavior
|
33 |
-
SYSTEM_PROMPT = """
|
34 |
-
You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
|
35 |
-
Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.
|
36 |
-
|
37 |
-
**Core Directives for this Conversation:**
|
38 |
-
1. **Analyze Sequentially:** Process information turn-by-turn. Base your responses on the *entire* conversation history.
|
39 |
-
2. **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
|
40 |
-
3. **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions, guideline searches), provide a comprehensive assessment using the following JSON structure. Output this JSON structure as the primary content of your response when you are providing the full analysis. Do NOT output incomplete JSON. If you need to ask a question or perform a tool call first, do that instead of outputting this structure.
|
41 |
-
```json
|
42 |
-
{
|
43 |
-
"assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
|
44 |
-
"differential_diagnosis": [
|
45 |
-
{"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence from conversation..."},
|
46 |
-
{"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
|
47 |
-
{"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
|
48 |
-
],
|
49 |
-
"risk_assessment": {
|
50 |
-
"identified_red_flags": ["List any triggered red flags based on input and analysis"],
|
51 |
-
"immediate_concerns": ["Specific urgent issues requiring attention (e.g., sepsis risk, ACS rule-out)"],
|
52 |
-
"potential_complications": ["Possible future issues based on presentation"]
|
53 |
-
},
|
54 |
-
"recommended_plan": {
|
55 |
-
"investigations": ["List specific lab tests or imaging required. Use 'order_lab_test' tool."],
|
56 |
-
"therapeutics": ["Suggest specific treatments or prescriptions. Use 'prescribe_medication' tool. MUST check interactions first using 'check_drug_interactions'."],
|
57 |
-
"consultations": ["Recommend specialist consultations if needed."],
|
58 |
-
"patient_education": ["Key points for patient communication."]
|
59 |
-
},
|
60 |
-
"rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.** Include summary of guideline findings here.",
|
61 |
-
"interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
|
62 |
-
}
|
63 |
-
```
|
64 |
-
4. **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions` in a preceding or concurrent tool call. Report the findings from the interaction check. If significant interactions exist, modify the plan or state the contraindication clearly.
|
65 |
-
5. **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point in the conversation.
|
66 |
-
6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription in the structured JSON).
|
67 |
-
7. **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
|
68 |
-
8. **Conciseness & Flow:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation (asking questions, acknowledging info) until ready for the full structured JSON output.
|
69 |
-
"""
|
70 |
-
|
71 |
-
# --- API Constants & Helper Functions ---
|
72 |
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
|
73 |
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
|
74 |
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
@lru_cache(maxsize=256)
|
77 |
def get_rxcui(drug_name: str) -> Optional[str]:
|
78 |
-
"""
|
79 |
-
if not drug_name
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
@lru_cache(maxsize=128)
|
97 |
-
def get_openfda_label(
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
try:
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
""" Case-insensitive search for any search_term within a list of text strings. Returns snippets. """
|
113 |
-
found_snippets = [];
|
114 |
-
if not text_list or not search_terms: return found_snippets; search_terms_lower = [str(term).lower() for term in search_terms if term];
|
115 |
-
for text_item in text_list:
|
116 |
-
if not isinstance(text_item, str): continue; text_item_lower = text_item.lower();
|
117 |
-
for term in search_terms_lower:
|
118 |
-
if term in text_item_lower:
|
119 |
-
start_index = text_item_lower.find(term); snippet_start = max(0, start_index - 50); snippet_end = min(len(text_item), start_index + len(term) + 100); snippet = text_item[snippet_start:snippet_end];
|
120 |
-
snippet = re.sub(f"({re.escape(term)})", r"**\1**", snippet, count=1, flags=re.IGNORECASE) # Highlight match
|
121 |
-
found_snippets.append(f"...{snippet}...")
|
122 |
-
break # Only report first match per text item
|
123 |
-
return found_snippets
|
124 |
-
|
125 |
-
|
126 |
-
# --- Clinical Helper Functions ---
|
127 |
-
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
|
128 |
-
"""Parses BP string like '120/80' into (systolic, diastolic) integers."""
|
129 |
-
if not isinstance(bp_string, str): return None
|
130 |
-
match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
|
131 |
-
if match: return int(match.group(1)), int(match.group(2))
|
132 |
return None
|
133 |
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
flags = []
|
137 |
-
if not patient_data:
|
138 |
-
|
|
|
|
|
139 |
vitals = patient_data.get("vitals", {})
|
140 |
-
history = patient_data.get("pmh", {}).get("conditions", "")
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
@tool("order_lab_test", args_schema=LabOrderInput)
|
195 |
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
@tool("prescribe_medication", args_schema=PrescriptionInput)
|
198 |
-
def prescribe_medication(
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
|
201 |
-
def check_drug_interactions(
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
if
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
@tool("flag_risk", args_schema=FlagRiskInput)
|
235 |
def flag_risk(risk_description: str, urgency: str) -> str:
|
236 |
-
|
237 |
-
|
|
|
|
|
|
|
|
|
|
|
238 |
search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
|
239 |
all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
|
240 |
|
241 |
-
# --- LangGraph State & Nodes ---
|
242 |
-
class AgentState(TypedDict): messages: Annotated[list[Any], operator.add]; patient_data: Optional[dict]; summary: Optional[str]; interaction_warnings: Optional[List[str]]
|
243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
|
245 |
model_with_tools = llm.bind_tools(all_tools)
|
246 |
tool_executor = ToolExecutor(all_tools)
|
247 |
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
for msg in reversed(state['messages']):
|
296 |
-
if isinstance(msg, ToolMessage) and msg.name ==
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
def should_continue(state: AgentState) -> str:
|
312 |
-
|
313 |
-
if not isinstance(
|
314 |
-
|
315 |
-
if getattr(
|
|
|
|
|
|
|
316 |
|
317 |
def after_tools_router(state: AgentState) -> str:
|
318 |
-
|
319 |
-
|
320 |
-
|
|
|
321 |
|
322 |
-
# --- ClinicalAgent Class ---
|
323 |
class ClinicalAgent:
|
324 |
def __init__(self):
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
|
|
|
|
|
|
|
|
|
|
335 |
|
336 |
def invoke_turn(self, state: Dict) -> Dict:
|
337 |
-
"""Invokes the LangGraph app for one turn."""
|
338 |
-
print(f"Invoking graph with state keys: {state.keys()}")
|
339 |
try:
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
return
|
344 |
except Exception as e:
|
345 |
-
|
346 |
-
|
347 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
import requests
|
5 |
import traceback
|
6 |
+
import operator
|
7 |
from functools import lru_cache
|
8 |
+
from typing import Any, Dict, List, Optional, TypedDict, Annotated
|
9 |
|
10 |
from langchain_groq import ChatGroq
|
11 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
|
|
15 |
from langgraph.prebuilt import ToolExecutor
|
16 |
from langgraph.graph import StateGraph, END
|
17 |
|
|
|
18 |
|
19 |
+
# --- Configuration & Constants ---
|
|
|
20 |
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
|
21 |
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
|
22 |
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
|
23 |
|
|
|
24 |
AGENT_MODEL_NAME = "llama3-70b-8192"
|
25 |
AGENT_TEMPERATURE = 0.1
|
26 |
MAX_SEARCH_RESULTS = 3
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
|
29 |
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
|
30 |
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
|
31 |
|
32 |
+
|
33 |
+
class ClinicalPrompts:
|
34 |
+
SYSTEM_PROMPT = (
|
35 |
+
"""
|
36 |
+
You are SynapseAI, an expert AI clinical assistant in an interactive consultation.
|
37 |
+
Analyze patient data, provide differential diagnoses, suggest management plans,
|
38 |
+
and identify risks according to current standards of care.
|
39 |
+
|
40 |
+
1. Process information sequentially; use full conversation history.
|
41 |
+
2. Ask for clarification if data is insufficient; do not guess.
|
42 |
+
3. When ready, output a complete JSON assessment as specified.
|
43 |
+
4. Before prescribing, run drug-interaction checks and report results.
|
44 |
+
5. Flag urgent red flags immediately.
|
45 |
+
6. Use tools logically; await results when needed.
|
46 |
+
7. Query clinical guidelines via tavily_search_results and cite them.
|
47 |
+
8. Be concise, accurate, and use standard terminology.
|
48 |
+
"""
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
# --- Helper Functions ---
|
53 |
@lru_cache(maxsize=256)
|
54 |
def get_rxcui(drug_name: str) -> Optional[str]:
|
55 |
+
"""Return RxNorm CUI for a given drug name."""
|
56 |
+
if not drug_name:
|
57 |
+
return None
|
58 |
+
name = drug_name.strip()
|
59 |
+
if not name:
|
60 |
+
return None
|
61 |
+
|
62 |
+
try:
|
63 |
+
# Primary lookup
|
64 |
+
params = {"name": name, "search": 1}
|
65 |
+
resp = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
|
66 |
+
resp.raise_for_status()
|
67 |
+
data = resp.json()
|
68 |
+
ids = data.get("idGroup", {}).get("rxnormId", [])
|
69 |
+
if ids:
|
70 |
+
return ids[0]
|
71 |
+
|
72 |
+
# Fallback lookup
|
73 |
+
params = {"name": name}
|
74 |
+
resp = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
|
75 |
+
resp.raise_for_status()
|
76 |
+
data = resp.json()
|
77 |
+
groups = data.get("drugGroup", {}).get("conceptGroup", [])
|
78 |
+
for grp in groups:
|
79 |
+
if grp.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
|
80 |
+
props = grp.get("conceptProperties", [])
|
81 |
+
if props:
|
82 |
+
return props[0].get("rxcui")
|
83 |
+
except Exception:
|
84 |
+
traceback.print_exc()
|
85 |
+
return None
|
86 |
+
|
87 |
|
88 |
@lru_cache(maxsize=128)
|
89 |
+
def get_openfda_label(
|
90 |
+
rxcui: Optional[str] = None,
|
91 |
+
drug_name: Optional[str] = None
|
92 |
+
) -> Optional[dict]:
|
93 |
+
"""Fetch OpenFDA drug label by RxCUI or name."""
|
94 |
+
if not (rxcui or drug_name):
|
95 |
+
return None
|
96 |
+
|
97 |
+
terms = []
|
98 |
+
if rxcui:
|
99 |
+
terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
|
100 |
+
if drug_name:
|
101 |
+
name = drug_name.lower()
|
102 |
+
terms.append(f'(openfda.brand_name:"{name}" OR openfda.generic_name:"{name}")')
|
103 |
+
|
104 |
+
query = " OR ".join(terms)
|
105 |
+
params = {"search": query, "limit": 1}
|
106 |
+
|
107 |
try:
|
108 |
+
resp = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
|
109 |
+
resp.raise_for_status()
|
110 |
+
data = resp.json()
|
111 |
+
results = data.get("results", [])
|
112 |
+
if results:
|
113 |
+
return results[0]
|
114 |
+
except Exception:
|
115 |
+
traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
return None
|
117 |
|
118 |
+
|
119 |
+
def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
|
120 |
+
"""Return snippets where any term appears in texts."""
|
121 |
+
snippets = []
|
122 |
+
lowers = [t.lower() for t in terms if t]
|
123 |
+
for txt in texts or []:
|
124 |
+
if not isinstance(txt, str):
|
125 |
+
continue
|
126 |
+
low_txt = txt.lower()
|
127 |
+
for term in lowers:
|
128 |
+
idx = low_txt.find(term)
|
129 |
+
if idx >= 0:
|
130 |
+
start = max(0, idx - 50)
|
131 |
+
end = min(len(txt), idx + len(term) + 100)
|
132 |
+
snippet = txt[start:end]
|
133 |
+
snippet = re.sub(
|
134 |
+
f"({re.escape(term)})",
|
135 |
+
r"**\1**",
|
136 |
+
snippet,
|
137 |
+
count=1,
|
138 |
+
flags=re.IGNORECASE,
|
139 |
+
)
|
140 |
+
snippets.append(f"...{snippet}...")
|
141 |
+
break
|
142 |
+
return snippets
|
143 |
+
|
144 |
+
|
145 |
+
def parse_bp(bp_str: str) -> Optional[tuple[int, int]]:
|
146 |
+
"""Parse blood pressure string 'systolic/diastolic'."""
|
147 |
+
match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_str or "")
|
148 |
+
if match:
|
149 |
+
return int(match.group(1)), int(match.group(2))
|
150 |
+
return None
|
151 |
+
|
152 |
+
|
153 |
+
def check_red_flags(patient_data: Dict) -> List[str]:
|
154 |
+
"""Identify critical red flags from patient data."""
|
155 |
flags = []
|
156 |
+
if not patient_data:
|
157 |
+
return flags
|
158 |
+
|
159 |
+
symptoms = [s.lower() for s in patient_data.get("hpi", {}).get("symptoms", [])]
|
160 |
vitals = patient_data.get("vitals", {})
|
161 |
+
history = patient_data.get("pmh", {}).get("conditions", "").lower()
|
162 |
+
|
163 |
+
# Symptom-based flags
|
164 |
+
mapping = {
|
165 |
+
"chest pain": "Chest Pain reported.",
|
166 |
+
"shortness of breath": "Shortness of Breath reported.",
|
167 |
+
"severe headache": "Severe Headache reported.",
|
168 |
+
"sudden vision loss": "Sudden Vision Loss reported.",
|
169 |
+
"weakness on one side": "Unilateral Weakness reported (potential stroke).",
|
170 |
+
"hemoptysis": "Hemoptysis (coughing up blood).",
|
171 |
+
"syncope": "Syncope (fainting).",
|
172 |
+
}
|
173 |
+
for term, desc in mapping.items():
|
174 |
+
if term in symptoms:
|
175 |
+
flags.append(f"Red Flag: {desc}")
|
176 |
+
|
177 |
+
# Vital sign flags
|
178 |
+
temp = vitals.get("temp_c")
|
179 |
+
hr = vitals.get("hr_bpm")
|
180 |
+
rr = vitals.get("rr_rpm")
|
181 |
+
spo2 = vitals.get("spo2_percent")
|
182 |
+
bp = parse_bp(vitals.get("bp_mmhg", ""))
|
183 |
+
|
184 |
+
if temp and temp >= 38.5:
|
185 |
+
flags.append(f"Red Flag: Fever ({temp}°C).")
|
186 |
+
if hr:
|
187 |
+
if hr >= 120:
|
188 |
+
flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
|
189 |
+
if hr <= 50:
|
190 |
+
flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
|
191 |
+
if rr and rr >= 24:
|
192 |
+
flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
|
193 |
+
if spo2 and spo2 <= 92:
|
194 |
+
flags.append(f"Red Flag: Hypoxia ({spo2}%).")
|
195 |
+
if bp:
|
196 |
+
sys, dia = bp
|
197 |
+
if sys >= 180 or dia >= 110:
|
198 |
+
flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {sys}/{dia} mmHg).")
|
199 |
+
if sys <= 90 or dia <= 60:
|
200 |
+
flags.append(f"Red Flag: Hypotension (BP: {sys}/{dia} mmHg).")
|
201 |
+
|
202 |
+
# History-based flags
|
203 |
+
if "history of mi" in history and "chest pain" in symptoms:
|
204 |
+
flags.append("Red Flag: History of MI with current Chest Pain.")
|
205 |
+
if "history of dvt/pe" in history and "shortness of breath" in symptoms:
|
206 |
+
flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
|
207 |
+
|
208 |
+
return list(set(flags))
|
209 |
+
|
210 |
+
|
211 |
+
def format_patient_data_for_prompt(data: Dict) -> str:
|
212 |
+
"""Convert patient data dict into a human-readable prompt section."""
|
213 |
+
if not data:
|
214 |
+
return "No patient data provided."
|
215 |
+
|
216 |
+
sections = []
|
217 |
+
for key, val in data.items():
|
218 |
+
title = key.replace("_", " ").title()
|
219 |
+
if isinstance(val, dict) and any(val.values()):
|
220 |
+
lines = [f"**{title}:**"]
|
221 |
+
for subk, subv in val.items():
|
222 |
+
if subv:
|
223 |
+
lines.append(f"- {subk.replace('_', ' ').title()}: {subv}")
|
224 |
+
sections.append("\n".join(lines))
|
225 |
+
elif isinstance(val, list) and val:
|
226 |
+
sections.append(f"**{title}:** {', '.join(map(str, val))}")
|
227 |
+
elif val:
|
228 |
+
sections.append(f"**{title}:** {val}")
|
229 |
+
|
230 |
+
return "\n\n".join(sections)
|
231 |
+
|
232 |
+
|
233 |
+
# --- Tool Schemas & Definitions ---
|
234 |
+
class LabOrderInput(BaseModel):
|
235 |
+
test_name: str = Field(...)
|
236 |
+
reason: str = Field(...)
|
237 |
+
priority: str = Field("Routine")
|
238 |
+
|
239 |
+
|
240 |
+
class PrescriptionInput(BaseModel):
|
241 |
+
medication_name: str = Field(...)
|
242 |
+
dosage: str = Field(...)
|
243 |
+
route: str = Field(...)
|
244 |
+
frequency: str = Field(...)
|
245 |
+
duration: str = Field("As directed")
|
246 |
+
reason: str = Field(...)
|
247 |
+
|
248 |
+
|
249 |
+
class InteractionCheckInput(BaseModel):
|
250 |
+
potential_prescription: str = Field(...)
|
251 |
+
current_medications: Optional[List[str]] = Field(None)
|
252 |
+
allergies: Optional[List[str]] = Field(None)
|
253 |
+
|
254 |
+
|
255 |
+
class FlagRiskInput(BaseModel):
|
256 |
+
risk_description: str = Field(...)
|
257 |
+
urgency: str = Field("High")
|
258 |
+
|
259 |
|
260 |
@tool("order_lab_test", args_schema=LabOrderInput)
|
261 |
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
|
262 |
+
result = {
|
263 |
+
"status": "success",
|
264 |
+
"message": f"Lab Ordered: {test_name} ({priority})",
|
265 |
+
"details": f"Reason: {reason}"
|
266 |
+
}
|
267 |
+
return json.dumps(result)
|
268 |
+
|
269 |
+
|
270 |
@tool("prescribe_medication", args_schema=PrescriptionInput)
|
271 |
+
def prescribe_medication(
|
272 |
+
medication_name: str,
|
273 |
+
dosage: str,
|
274 |
+
route: str,
|
275 |
+
frequency: str,
|
276 |
+
duration: str,
|
277 |
+
reason: str
|
278 |
+
) -> str:
|
279 |
+
result = {
|
280 |
+
"status": "success",
|
281 |
+
"message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
|
282 |
+
"details": f"Duration: {duration}. Reason: {reason}"
|
283 |
+
}
|
284 |
+
return json.dumps(result)
|
285 |
+
|
286 |
+
|
287 |
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
|
288 |
+
def check_drug_interactions(
|
289 |
+
potential_prescription: str,
|
290 |
+
current_medications: Optional[List[str]] = None,
|
291 |
+
allergies: Optional[List[str]] = None
|
292 |
+
) -> str:
|
293 |
+
warnings: List[str] = []
|
294 |
+
presc_lower = potential_prescription.lower().strip()
|
295 |
+
current = [m.lower().strip() for m in (current_medications or [])]
|
296 |
+
allergy_list = [a.lower().strip() for a in (allergies or [])]
|
297 |
+
|
298 |
+
# Normalize and lookup
|
299 |
+
rxcui = get_rxcui(potential_prescription)
|
300 |
+
label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
|
301 |
+
if not rxcui and not label:
|
302 |
+
warnings.append(f"INFO: Could not identify '{potential_prescription}'.")
|
303 |
+
|
304 |
+
# Allergy checks
|
305 |
+
for alg in allergy_list:
|
306 |
+
if alg == presc_lower:
|
307 |
+
warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{alg}'.")
|
308 |
+
# Additional cross-allergy logic...
|
309 |
+
|
310 |
+
# Drug-drug interactions
|
311 |
+
if rxcui or label:
|
312 |
+
for med in current:
|
313 |
+
if med and med != presc_lower:
|
314 |
+
# interaction search on label sections
|
315 |
+
interactions = []
|
316 |
+
if label and label.get("drug_interactions"):
|
317 |
+
interactions = search_text_list(label["drug_interactions"], [med])
|
318 |
+
if interactions:
|
319 |
+
warnings.append(
|
320 |
+
f"Potential Interaction: '{potential_prescription}' & '{med}'. Snippets: {'; '.join(interactions)}"
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
warnings.append(f"INFO: Skipped interaction check for '{potential_prescription}'.")
|
324 |
+
|
325 |
+
status = "warning" if warnings else "clear"
|
326 |
+
message = (
|
327 |
+
f"Interaction/Allergy check for '{potential_prescription}': {len(warnings)} issue(s)."
|
328 |
+
if warnings else
|
329 |
+
f"No major issues for '{potential_prescription}'."
|
330 |
+
)
|
331 |
+
return json.dumps({"status": status, "message": message, "warnings": warnings})
|
332 |
+
|
333 |
+
|
334 |
@tool("flag_risk", args_schema=FlagRiskInput)
|
335 |
def flag_risk(risk_description: str, urgency: str) -> str:
|
336 |
+
return json.dumps({
|
337 |
+
"status": "flagged",
|
338 |
+
"message": f"Risk '{risk_description}' flagged with {urgency} urgency."
|
339 |
+
})
|
340 |
+
|
341 |
+
|
342 |
+
# Initialize search tool and tool list
|
343 |
search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
|
344 |
all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
|
345 |
|
|
|
|
|
346 |
|
347 |
+
# --- LangGraph Setup ---
|
348 |
+
class AgentState(TypedDict):
|
349 |
+
messages: Annotated[List[Any], operator.add]
|
350 |
+
patient_data: Optional[Dict]
|
351 |
+
summary: Optional[str]
|
352 |
+
interaction_warnings: Optional[List[str]]
|
353 |
+
|
354 |
+
# LLM and executor
|
355 |
llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
|
356 |
model_with_tools = llm.bind_tools(all_tools)
|
357 |
tool_executor = ToolExecutor(all_tools)
|
358 |
|
359 |
+
|
360 |
+
def agent_node(state: AgentState) -> Dict:
|
361 |
+
"""Invoke the LLM agent node."""
|
362 |
+
msgs = state['messages'][:]
|
363 |
+
if not msgs or not isinstance(msgs[0], SystemMessage):
|
364 |
+
msgs.insert(0, SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT))
|
365 |
+
|
366 |
+
try:
|
367 |
+
response = model_with_tools.invoke(msgs)
|
368 |
+
return {"messages": [response]}
|
369 |
+
except Exception as e:
|
370 |
+
traceback.print_exc()
|
371 |
+
err = AIMessage(content=f"Error: {e}")
|
372 |
+
return {"messages": [err]}
|
373 |
+
|
374 |
+
|
375 |
+
def tool_node(state: AgentState) -> Dict:
|
376 |
+
"""Execute any pending tool calls from the last AI message."""
|
377 |
+
last = state['messages'][-1]
|
378 |
+
if not isinstance(last, AIMessage) or not getattr(last, 'tool_calls', None):
|
379 |
+
return {"messages": [], "interaction_warnings": None}
|
380 |
+
|
381 |
+
calls = last.tool_calls
|
382 |
+
# Enforce safety: require interaction check before prescribing
|
383 |
+
blocked_ids = set()
|
384 |
+
for call in calls:
|
385 |
+
if call['name'] == 'prescribe_medication':
|
386 |
+
# block if no interaction check for this med
|
387 |
+
med = call['args'].get('medication_name', '').lower()
|
388 |
+
if not any(
|
389 |
+
c['name'] == 'check_drug_interactions' and
|
390 |
+
c['args'].get('potential_prescription', '').lower() == med
|
391 |
+
for c in calls
|
392 |
+
):
|
393 |
+
blocked_ids.add(call['id'])
|
394 |
+
|
395 |
+
valid_calls = [c for c in calls if c['id'] not in blocked_ids]
|
396 |
+
|
397 |
+
# Augment interaction checks with patient data
|
398 |
+
for c in valid_calls:
|
399 |
+
if c['name'] == 'check_drug_interactions':
|
400 |
+
c['args']['current_medications'] = state.get('patient_data', {}).get('medications', {}).get('current', [])
|
401 |
+
c['args']['allergies'] = state.get('patient_data', {}).get('allergies', [])
|
402 |
+
|
403 |
+
results = []
|
404 |
+
warnings: List[str] = []
|
405 |
+
try:
|
406 |
+
responses = tool_executor.batch(valid_calls, return_exceptions=True)
|
407 |
+
for call, resp in zip(valid_calls, responses):
|
408 |
+
if isinstance(resp, Exception):
|
409 |
+
traceback.print_exc()
|
410 |
+
content = json.dumps({"status": "error", "message": str(resp)})
|
411 |
+
else:
|
412 |
+
content = str(resp)
|
413 |
+
if call['name'] == 'check_drug_interactions':
|
414 |
+
data = json.loads(content)
|
415 |
+
if data.get('warnings'):
|
416 |
+
warnings.extend(data['warnings'])
|
417 |
+
results.append(ToolMessage(content=content, tool_call_id=call['id'], name=call['name']))
|
418 |
+
except Exception as e:
|
419 |
+
traceback.print_exc()
|
420 |
+
content = json.dumps({"status": "error", "message": str(e)})
|
421 |
+
for c in valid_calls:
|
422 |
+
results.append(ToolMessage(content=content, tool_call_id=c['id'], name=c['name']))
|
423 |
+
|
424 |
+
return {"messages": results, "interaction_warnings": warnings or None}
|
425 |
+
|
426 |
+
|
427 |
+
def reflection_node(state: AgentState) -> Dict:
|
428 |
+
"""Review interaction warnings and adjust plan if needed."""
|
429 |
+
warnings = state.get('interaction_warnings')
|
430 |
+
if not warnings:
|
431 |
+
return {"messages": [], "interaction_warnings": None}
|
432 |
+
|
433 |
+
# Find the AI message that triggered the warnings
|
434 |
+
trigger_id = None
|
435 |
for msg in reversed(state['messages']):
|
436 |
+
if isinstance(msg, ToolMessage) and msg.name == 'check_drug_interactions':
|
437 |
+
trigger_id = msg.tool_call_id
|
438 |
+
break
|
439 |
+
|
440 |
+
prompt = (
|
441 |
+
f"Interaction warnings:\n{json.dumps(warnings, indent=2)}\n"
|
442 |
+
"Provide a revised therapeutics plan addressing these issues."
|
443 |
+
)
|
444 |
+
msgs = [
|
445 |
+
SystemMessage(content="Safety reflection on drug interactions."),
|
446 |
+
HumanMessage(content=prompt)
|
447 |
+
]
|
448 |
+
|
449 |
+
try:
|
450 |
+
resp = llm.invoke(msgs)
|
451 |
+
return {"messages": [AIMessage(content=resp.content)], "interaction_warnings": None}
|
452 |
+
except Exception as e:
|
453 |
+
traceback.print_exc()
|
454 |
+
return {"messages": [AIMessage(content=f"Reflection error: {e}")], "interaction_warnings": None}
|
455 |
+
|
456 |
+
|
457 |
def should_continue(state: AgentState) -> str:
|
458 |
+
last = state['messages'][-1] if state['messages'] else None
|
459 |
+
if not isinstance(last, AIMessage):
|
460 |
+
return 'end_conversation_turn'
|
461 |
+
if getattr(last, 'tool_calls', None):
|
462 |
+
return 'continue_tools'
|
463 |
+
return 'end_conversation_turn'
|
464 |
+
|
465 |
|
466 |
def after_tools_router(state: AgentState) -> str:
|
467 |
+
if state.get('interaction_warnings'):
|
468 |
+
return 'reflect_on_warnings'
|
469 |
+
return 'continue_to_agent'
|
470 |
+
|
471 |
|
|
|
472 |
class ClinicalAgent:
|
473 |
def __init__(self):
|
474 |
+
graph = StateGraph(AgentState)
|
475 |
+
graph.add_node('agent', agent_node)
|
476 |
+
graph.add_node('tools', tool_node)
|
477 |
+
graph.add_node('reflection', reflection_node)
|
478 |
+
graph.set_entry_point('agent')
|
479 |
+
graph.add_conditional_edges(
|
480 |
+
'agent', should_continue,
|
481 |
+
{'continue_tools': 'tools', 'end_conversation_turn': END}
|
482 |
+
)
|
483 |
+
graph.add_conditional_edges(
|
484 |
+
'tools', after_tools_router,
|
485 |
+
{'reflect_on_warnings': 'reflection', 'continue_to_agent': 'agent'}
|
486 |
+
)
|
487 |
+
graph.add_edge('reflection', 'agent')
|
488 |
+
self.app = graph.compile()
|
489 |
|
490 |
def invoke_turn(self, state: Dict) -> Dict:
|
|
|
|
|
491 |
try:
|
492 |
+
result = self.app.invoke(state, {'recursion_limit': 15})
|
493 |
+
result.setdefault('summary', state.get('summary'))
|
494 |
+
result.setdefault('interaction_warnings', None)
|
495 |
+
return result
|
496 |
except Exception as e:
|
497 |
+
traceback.print_exc()
|
498 |
+
err = AIMessage(content=f"Critical error: {e}")
|
499 |
+
return {
|
500 |
+
'messages': state.get('messages', []) + [err],
|
501 |
+
'patient_data': state.get('patient_data'),
|
502 |
+
'summary': state.get('summary'),
|
503 |
+
'interaction_warnings': None
|
504 |
+
}
|