Update agent.py
Browse files
agent.py
CHANGED
@@ -18,7 +18,6 @@ from langgraph.graph import StateGraph, END
|
|
18 |
from typing import Optional, List, Dict, Any, TypedDict, Annotated
|
19 |
|
20 |
# --- Environment Variable Loading ---
|
21 |
-
# Keys are primarily used here, but checked in app.py for UI feedback
|
22 |
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
|
23 |
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
|
24 |
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
|
@@ -29,59 +28,21 @@ AGENT_TEMPERATURE = 0.1
|
|
29 |
MAX_SEARCH_RESULTS = 3
|
30 |
|
31 |
class ClinicalPrompts:
|
32 |
-
# The comprehensive system prompt defining agent behavior
|
33 |
SYSTEM_PROMPT = """
|
34 |
-
You are SynapseAI, an expert AI clinical assistant
|
35 |
-
Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.
|
36 |
-
|
37 |
-
**Core Directives for this Conversation:**
|
38 |
-
1. **Analyze Sequentially:** Process information turn-by-turn. Base your responses on the *entire* conversation history.
|
39 |
-
2. **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
|
40 |
-
3. **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions, guideline searches), provide a comprehensive assessment using the following JSON structure. Output this JSON structure as the primary content of your response when you are providing the full analysis. Do NOT output incomplete JSON. If you need to ask a question or perform a tool call first, do that instead of outputting this structure.
|
41 |
-
```json
|
42 |
-
{
|
43 |
-
"assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
|
44 |
-
"differential_diagnosis": [
|
45 |
-
{"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence from conversation..."},
|
46 |
-
{"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
|
47 |
-
{"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
|
48 |
-
],
|
49 |
-
"risk_assessment": {
|
50 |
-
"identified_red_flags": ["List any triggered red flags based on input and analysis"],
|
51 |
-
"immediate_concerns": ["Specific urgent issues requiring attention (e.g., sepsis risk, ACS rule-out)"],
|
52 |
-
"potential_complications": ["Possible future issues based on presentation"]
|
53 |
-
},
|
54 |
-
"recommended_plan": {
|
55 |
-
"investigations": ["List specific lab tests or imaging required. Use 'order_lab_test' tool."],
|
56 |
-
"therapeutics": ["Suggest specific treatments or prescriptions. Use 'prescribe_medication' tool. MUST check interactions first using 'check_drug_interactions'."],
|
57 |
-
"consultations": ["Recommend specialist consultations if needed."],
|
58 |
-
"patient_education": ["Key points for patient communication."]
|
59 |
-
},
|
60 |
-
"rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.** Include summary of guideline findings here.",
|
61 |
-
"interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
|
62 |
-
}
|
63 |
-
```
|
64 |
-
4. **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions` in a preceding or concurrent tool call. Report the findings from the interaction check. If significant interactions exist, modify the plan or state the contraindication clearly.
|
65 |
-
5. **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point in the conversation.
|
66 |
-
6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription in the structured JSON).
|
67 |
-
7. **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
|
68 |
-
8. **Conciseness & Flow:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation (asking questions, acknowledging info) until ready for the full structured JSON output.
|
69 |
"""
|
70 |
|
71 |
# --- API Constants & Helper Functions ---
|
72 |
-
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
|
73 |
-
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
|
74 |
-
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
|
75 |
-
|
76 |
@lru_cache(maxsize=256)
|
77 |
def get_rxcui(drug_name: str) -> Optional[str]:
|
78 |
-
|
79 |
if not drug_name or not isinstance(drug_name, str): return None; drug_name = drug_name.strip();
|
80 |
if not drug_name: return None; print(f"RxNorm Lookup for: '{drug_name}'");
|
81 |
-
try:
|
82 |
params = {"name": drug_name, "search": 1}; response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
|
83 |
if data and "idGroup" in data and "rxnormId" in data["idGroup"]: rxcui = data["idGroup"]["rxnormId"][0]; print(f" Found RxCUI: {rxcui} for '{drug_name}'"); return rxcui
|
84 |
-
else:
|
85 |
params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
|
86 |
if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
|
87 |
for group in data["drugGroup"]["conceptGroup"]:
|
@@ -95,7 +56,7 @@ def get_rxcui(drug_name: str) -> Optional[str]:
|
|
95 |
|
96 |
@lru_cache(maxsize=128)
|
97 |
def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
|
98 |
-
|
99 |
if not rxcui and not drug_name: return None; print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}"); search_terms = []
|
100 |
if rxcui: search_terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
|
101 |
if drug_name: search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
|
@@ -109,7 +70,7 @@ def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = No
|
|
109 |
except Exception as e: print(f" Unexpected error in get_openfda_label: {e}"); return None
|
110 |
|
111 |
def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
|
112 |
-
|
113 |
found_snippets = [];
|
114 |
if not text_list or not search_terms: return found_snippets; search_terms_lower = [str(term).lower() for term in search_terms if term];
|
115 |
for text_item in text_list:
|
@@ -125,22 +86,14 @@ def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) ->
|
|
125 |
|
126 |
# --- Clinical Helper Functions ---
|
127 |
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
|
128 |
-
|
129 |
-
if not isinstance(bp_string, str): return None
|
130 |
-
match
|
131 |
-
if match: return int(match.group(1)), int(match.group(2))
|
132 |
-
return None
|
133 |
|
134 |
def check_red_flags(patient_data: dict) -> List[str]:
|
135 |
-
|
136 |
-
flags = []
|
137 |
-
if not patient_data: return flags
|
138 |
-
symptoms = patient_data.get("hpi", {}).get("symptoms", [])
|
139 |
-
vitals = patient_data.get("vitals", {})
|
140 |
-
history = patient_data.get("pmh", {}).get("conditions", "")
|
141 |
-
symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)]
|
142 |
-
|
143 |
-
# Symptom Flags
|
144 |
if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
|
145 |
if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
|
146 |
if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
|
@@ -148,40 +101,38 @@ def check_red_flags(patient_data: dict) -> List[str]:
|
|
148 |
if "weakness on one side" in symptoms_lower: flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
|
149 |
if "hemoptysis" in symptoms_lower: flags.append("Red Flag: Hemoptysis (coughing up blood).")
|
150 |
if "syncope" in symptoms_lower: flags.append("Red Flag: Syncope (fainting).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
-
|
153 |
-
if vitals:
|
154 |
-
temp = vitals.get("temp_c"); hr = vitals.get("hr_bpm"); rr = vitals.get("rr_rpm")
|
155 |
-
spo2 = vitals.get("spo2_percent"); bp_str = vitals.get("bp_mmhg")
|
156 |
-
if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever ({temp}°C).")
|
157 |
-
if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
|
158 |
-
if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
|
159 |
-
if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
|
160 |
-
if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia ({spo2}%).")
|
161 |
-
if bp_str:
|
162 |
-
bp = parse_bp(bp_str)
|
163 |
-
if bp:
|
164 |
-
if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
|
165 |
-
if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")
|
166 |
-
|
167 |
-
# History Flags
|
168 |
-
if history and isinstance(history, str):
|
169 |
-
history_lower = history.lower()
|
170 |
-
if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
|
171 |
-
if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
|
172 |
-
|
173 |
-
return list(set(flags)) # Unique flags
|
174 |
-
|
175 |
def format_patient_data_for_prompt(data: dict) -> str:
|
176 |
"""Formats the patient dictionary into a readable string for the LLM."""
|
177 |
-
if not data: return "No patient data provided."
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
if
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
return prompt_str.strip()
|
186 |
|
187 |
|
@@ -199,6 +150,7 @@ def prescribe_medication(medication_name: str, dosage: str, route: str, frequenc
|
|
199 |
print(f"Executing prescribe_medication: {medication_name} {dosage}..."); return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
|
200 |
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
|
201 |
def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
|
|
|
202 |
print(f"\n--- Executing REAL check_drug_interactions ---"); print(f"Checking potential prescription: '{potential_prescription}'"); warnings = []; potential_med_lower = potential_prescription.lower().strip();
|
203 |
current_meds_list = current_medications or []; allergies_list = allergies or []; current_med_names_lower = [];
|
204 |
for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", str(med));
|
@@ -233,28 +185,24 @@ def check_drug_interactions(potential_prescription: str, current_medications: Op
|
|
233 |
return json.dumps({"status": status, "message": message, "warnings": final_warnings})
|
234 |
@tool("flag_risk", args_schema=FlagRiskInput)
|
235 |
def flag_risk(risk_description: str, urgency: str) -> str:
|
236 |
-
print(f"Executing flag_risk: {risk_description}, Urgency: {urgency}");
|
237 |
-
return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
|
238 |
search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
|
239 |
all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
|
240 |
|
241 |
# --- LangGraph State & Nodes ---
|
242 |
class AgentState(TypedDict): messages: Annotated[list[Any], operator.add]; patient_data: Optional[dict]; summary: Optional[str]; interaction_warnings: Optional[List[str]]
|
243 |
-
|
244 |
-
llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
|
245 |
-
model_with_tools = llm.bind_tools(all_tools)
|
246 |
-
tool_executor = ToolExecutor(all_tools)
|
247 |
-
|
248 |
def agent_node(state: AgentState):
|
|
|
249 |
print("\n---AGENT NODE---"); current_messages = state['messages'];
|
250 |
if not current_messages or not isinstance(current_messages[0], SystemMessage): print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages;
|
251 |
print(f"Invoking LLM with {len(current_messages)} messages.");
|
252 |
try: response = model_with_tools.invoke(current_messages); print(f"Agent Raw Response Type: {type(response)}");
|
253 |
if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}"); else: print("Agent Response: No tool calls.");
|
254 |
except Exception as e: print(f"ERROR in agent_node: {e}"); traceback.print_exc(); error_message = AIMessage(content=f"Error: {e}"); return {"messages": [error_message]};
|
255 |
-
return {"messages": [response]}
|
256 |
-
|
257 |
def tool_node(state: AgentState):
|
|
|
258 |
print("\n---TOOL NODE---"); tool_messages = []; last_message = state['messages'][-1]; interaction_warnings_found = [];
|
259 |
if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None): print("Warning: Tool node called unexpectedly."); return {"messages": [], "interaction_warnings": None};
|
260 |
tool_calls = last_message.tool_calls; print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}"); prescriptions_requested = {}; interaction_checks_requested = {};
|
@@ -288,6 +236,7 @@ def tool_node(state: AgentState):
|
|
288 |
return {"messages": tool_messages, "interaction_warnings": interaction_warnings_found or None} # Return messages AND warnings
|
289 |
|
290 |
def reflection_node(state: AgentState):
|
|
|
291 |
print("\n---REFLECTION NODE---")
|
292 |
interaction_warnings = state.get("interaction_warnings")
|
293 |
if not interaction_warnings: print("Warning: Reflection node called without warnings."); return {"messages": [], "interaction_warnings": None};
|
@@ -298,23 +247,22 @@ def reflection_node(state: AgentState):
|
|
298 |
if any(tc['id'] in relevant_tool_call_ids for tc in msg.tool_calls): triggering_ai_message = msg; break;
|
299 |
if not triggering_ai_message: print("Error: Could not find triggering AI message for reflection."); return {"messages": [AIMessage(content="Internal Error: Reflection context missing.")], "interaction_warnings": None};
|
300 |
original_plan_proposal_context = triggering_ai_message.content;
|
301 |
-
reflection_prompt_text = f"""You are SynapseAI, performing a critical safety review...
|
302 |
-
Previous Context:\n{original_plan_proposal_context}\n---\nInteraction Warnings:\n```json\n{json.dumps(interaction_warnings, indent=2)}\n```\n**CRITICAL REFLECTION STEP:** Analyze warnings, decide if revision is needed, respond ONLY about therapeutics revision based on these warnings."""
|
303 |
reflection_messages = [SystemMessage(content="Perform focused safety review based on interaction warnings."), HumanMessage(content=reflection_prompt_text)];
|
304 |
print("Invoking LLM for reflection...");
|
305 |
try: reflection_response = llm.invoke(reflection_messages); print(f"Reflection Response: {reflection_response.content}"); final_ai_message = AIMessage(content=reflection_response.content);
|
306 |
except Exception as e: print(f"ERROR during reflection: {e}"); traceback.print_exc(); final_ai_message = AIMessage(content=f"Error during safety reflection: {e}");
|
307 |
return {"messages": [final_ai_message], "interaction_warnings": None} # Return reflection response, clear warnings
|
308 |
|
309 |
-
|
310 |
# --- Graph Routing Logic ---
|
311 |
def should_continue(state: AgentState) -> str:
|
|
|
312 |
print("\n---ROUTING DECISION (Agent Output)---"); last_message = state['messages'][-1] if state['messages'] else None;
|
313 |
if not isinstance(last_message, AIMessage): return "end_conversation_turn";
|
314 |
if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn";
|
315 |
if getattr(last_message, 'tool_calls', None): return "continue_tools"; else: return "end_conversation_turn";
|
316 |
-
|
317 |
def after_tools_router(state: AgentState) -> str:
|
|
|
318 |
print("\n---ROUTING DECISION (After Tools)---");
|
319 |
if state.get("interaction_warnings"): print("Routing: Warnings found -> Reflection"); return "reflect_on_warnings";
|
320 |
else: print("Routing: No warnings -> Agent"); return "continue_to_agent";
|
@@ -322,26 +270,13 @@ def after_tools_router(state: AgentState) -> str:
|
|
322 |
# --- ClinicalAgent Class ---
|
323 |
class ClinicalAgent:
|
324 |
def __init__(self):
|
325 |
-
|
326 |
-
workflow.add_node("agent", agent_node)
|
327 |
-
workflow.
|
328 |
-
workflow.add_node("reflection", reflection_node)
|
329 |
-
workflow.set_entry_point("agent")
|
330 |
-
workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
|
331 |
workflow.add_conditional_edges("tools", after_tools_router, {"reflect_on_warnings": "reflection", "continue_to_agent": "agent"})
|
332 |
-
workflow.add_edge("reflection", "agent")
|
333 |
-
self.graph_app = workflow.compile()
|
334 |
-
print("ClinicalAgent initialized and LangGraph compiled.")
|
335 |
-
|
336 |
def invoke_turn(self, state: Dict) -> Dict:
|
337 |
-
|
338 |
-
print(f"Invoking graph with state keys: {state.keys()}")
|
339 |
-
try:
|
340 |
-
|
341 |
-
final_state.setdefault('summary', state.get('summary')) # Ensure keys exist
|
342 |
-
final_state.setdefault('interaction_warnings', None)
|
343 |
-
return final_state
|
344 |
-
except Exception as e:
|
345 |
-
print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}"); traceback.print_exc();
|
346 |
-
error_msg = AIMessage(content=f"Sorry, a critical error occurred during processing: {e}");
|
347 |
-
return {"messages": state.get('messages', []) + [error_msg], "patient_data": state.get('patient_data'), "summary": state.get('summary'), "interaction_warnings": None}
|
|
|
18 |
from typing import Optional, List, Dict, Any, TypedDict, Annotated
|
19 |
|
20 |
# --- Environment Variable Loading ---
|
|
|
21 |
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
|
22 |
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
|
23 |
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
|
|
|
28 |
MAX_SEARCH_RESULTS = 3
|
29 |
|
30 |
class ClinicalPrompts:
|
|
|
31 |
SYSTEM_PROMPT = """
|
32 |
+
You are SynapseAI, an expert AI clinical assistant... [SYSTEM PROMPT OMITTED FOR BREVITY]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
"""
|
34 |
|
35 |
# --- API Constants & Helper Functions ---
|
36 |
+
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"; RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"; OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
|
|
|
|
|
|
|
37 |
@lru_cache(maxsize=256)
|
38 |
def get_rxcui(drug_name: str) -> Optional[str]:
|
39 |
+
# ... (Keep implementation) ...
|
40 |
if not drug_name or not isinstance(drug_name, str): return None; drug_name = drug_name.strip();
|
41 |
if not drug_name: return None; print(f"RxNorm Lookup for: '{drug_name}'");
|
42 |
+
try:
|
43 |
params = {"name": drug_name, "search": 1}; response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
|
44 |
if data and "idGroup" in data and "rxnormId" in data["idGroup"]: rxcui = data["idGroup"]["rxnormId"][0]; print(f" Found RxCUI: {rxcui} for '{drug_name}'"); return rxcui
|
45 |
+
else:
|
46 |
params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
|
47 |
if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
|
48 |
for group in data["drugGroup"]["conceptGroup"]:
|
|
|
56 |
|
57 |
@lru_cache(maxsize=128)
|
58 |
def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
|
59 |
+
# ... (Keep implementation) ...
|
60 |
if not rxcui and not drug_name: return None; print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}"); search_terms = []
|
61 |
if rxcui: search_terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
|
62 |
if drug_name: search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
|
|
|
70 |
except Exception as e: print(f" Unexpected error in get_openfda_label: {e}"); return None
|
71 |
|
72 |
def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
|
73 |
+
# ... (Keep implementation) ...
|
74 |
found_snippets = [];
|
75 |
if not text_list or not search_terms: return found_snippets; search_terms_lower = [str(term).lower() for term in search_terms if term];
|
76 |
for text_item in text_list:
|
|
|
86 |
|
87 |
# --- Clinical Helper Functions ---
|
88 |
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
|
89 |
+
# ... (Keep implementation) ...
|
90 |
+
if not isinstance(bp_string, str): return None; match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip());
|
91 |
+
if match: return int(match.group(1)), int(match.group(2)); return None
|
|
|
|
|
92 |
|
93 |
def check_red_flags(patient_data: dict) -> List[str]:
|
94 |
+
# ... (Keep implementation with multi-line ifs) ...
|
95 |
+
flags = [];
|
96 |
+
if not patient_data: return flags; symptoms = patient_data.get("hpi", {}).get("symptoms", []); vitals = patient_data.get("vitals", {}); history = patient_data.get("pmh", {}).get("conditions", ""); symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)];
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
|
98 |
if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
|
99 |
if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
|
|
|
101 |
if "weakness on one side" in symptoms_lower: flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
|
102 |
if "hemoptysis" in symptoms_lower: flags.append("Red Flag: Hemoptysis (coughing up blood).")
|
103 |
if "syncope" in symptoms_lower: flags.append("Red Flag: Syncope (fainting).")
|
104 |
+
if vitals: temp = vitals.get("temp_c"); hr = vitals.get("hr_bpm"); rr = vitals.get("rr_rpm"); spo2 = vitals.get("spo2_percent"); bp_str = vitals.get("bp_mmhg");
|
105 |
+
if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever ({temp}°C)."); if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia ({hr} bpm)."); if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia ({hr} bpm)."); if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea ({rr} rpm)."); if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia ({spo2}%).");
|
106 |
+
if bp_str: bp = parse_bp(bp_str);
|
107 |
+
if bp:
|
108 |
+
if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).");
|
109 |
+
if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).");
|
110 |
+
if history and isinstance(history, str): history_lower = history.lower();
|
111 |
+
if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.");
|
112 |
+
if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.");
|
113 |
+
return list(set(flags))
|
114 |
|
115 |
+
# CORRECTED format_patient_data_for_prompt function indentation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
def format_patient_data_for_prompt(data: dict) -> str:
|
117 |
"""Formats the patient dictionary into a readable string for the LLM."""
|
118 |
+
if not data: return "No patient data provided."
|
119 |
+
prompt_str = ""
|
120 |
+
for key, value in data.items():
|
121 |
+
section_title = key.replace('_', ' ').title()
|
122 |
+
# Check if the value is a dictionary and has content
|
123 |
+
if isinstance(value, dict) and value:
|
124 |
+
has_content = any(sub_value for sub_value in value.values())
|
125 |
+
if has_content:
|
126 |
+
prompt_str += f"**{section_title}:**\n"
|
127 |
+
for sub_key, sub_value in value.items():
|
128 |
+
if sub_value: # Only add if sub-value is truthy
|
129 |
+
prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
|
130 |
+
# Check if the value is a non-empty list
|
131 |
+
elif isinstance(value, list) and value: # <-- Correct indentation
|
132 |
+
prompt_str += f"**{section_title}:** {', '.join(map(str, value))}\n"
|
133 |
+
# Check if the value is truthy and not a dictionary (handles strings, numbers, etc.)
|
134 |
+
elif value and not isinstance(value, dict): # <-- Correct indentation
|
135 |
+
prompt_str += f"**{section_title}:** {value}\n"
|
136 |
return prompt_str.strip()
|
137 |
|
138 |
|
|
|
150 |
print(f"Executing prescribe_medication: {medication_name} {dosage}..."); return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
|
151 |
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
|
152 |
def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
|
153 |
+
# ... (Keep the FULL implementation of the NEW check_drug_interactions using API helpers) ...
|
154 |
print(f"\n--- Executing REAL check_drug_interactions ---"); print(f"Checking potential prescription: '{potential_prescription}'"); warnings = []; potential_med_lower = potential_prescription.lower().strip();
|
155 |
current_meds_list = current_medications or []; allergies_list = allergies or []; current_med_names_lower = [];
|
156 |
for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", str(med));
|
|
|
185 |
return json.dumps({"status": status, "message": message, "warnings": final_warnings})
|
186 |
@tool("flag_risk", args_schema=FlagRiskInput)
|
187 |
def flag_risk(risk_description: str, urgency: str) -> str:
|
188 |
+
print(f"Executing flag_risk: {risk_description}, Urgency: {urgency}"); return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
|
|
|
189 |
search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
|
190 |
all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
|
191 |
|
192 |
# --- LangGraph State & Nodes ---
|
193 |
class AgentState(TypedDict): messages: Annotated[list[Any], operator.add]; patient_data: Optional[dict]; summary: Optional[str]; interaction_warnings: Optional[List[str]]
|
194 |
+
llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME); model_with_tools = llm.bind_tools(all_tools); tool_executor = ToolExecutor(all_tools)
|
|
|
|
|
|
|
|
|
195 |
def agent_node(state: AgentState):
|
196 |
+
# ... (Keep implementation) ...
|
197 |
print("\n---AGENT NODE---"); current_messages = state['messages'];
|
198 |
if not current_messages or not isinstance(current_messages[0], SystemMessage): print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages;
|
199 |
print(f"Invoking LLM with {len(current_messages)} messages.");
|
200 |
try: response = model_with_tools.invoke(current_messages); print(f"Agent Raw Response Type: {type(response)}");
|
201 |
if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}"); else: print("Agent Response: No tool calls.");
|
202 |
except Exception as e: print(f"ERROR in agent_node: {e}"); traceback.print_exc(); error_message = AIMessage(content=f"Error: {e}"); return {"messages": [error_message]};
|
203 |
+
return {"messages": [response]}
|
|
|
204 |
def tool_node(state: AgentState):
|
205 |
+
# ... (Keep implementation) ...
|
206 |
print("\n---TOOL NODE---"); tool_messages = []; last_message = state['messages'][-1]; interaction_warnings_found = [];
|
207 |
if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None): print("Warning: Tool node called unexpectedly."); return {"messages": [], "interaction_warnings": None};
|
208 |
tool_calls = last_message.tool_calls; print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}"); prescriptions_requested = {}; interaction_checks_requested = {};
|
|
|
236 |
return {"messages": tool_messages, "interaction_warnings": interaction_warnings_found or None} # Return messages AND warnings
|
237 |
|
238 |
def reflection_node(state: AgentState):
|
239 |
+
# ... (Keep implementation) ...
|
240 |
print("\n---REFLECTION NODE---")
|
241 |
interaction_warnings = state.get("interaction_warnings")
|
242 |
if not interaction_warnings: print("Warning: Reflection node called without warnings."); return {"messages": [], "interaction_warnings": None};
|
|
|
247 |
if any(tc['id'] in relevant_tool_call_ids for tc in msg.tool_calls): triggering_ai_message = msg; break;
|
248 |
if not triggering_ai_message: print("Error: Could not find triggering AI message for reflection."); return {"messages": [AIMessage(content="Internal Error: Reflection context missing.")], "interaction_warnings": None};
|
249 |
original_plan_proposal_context = triggering_ai_message.content;
|
250 |
+
reflection_prompt_text = f"""You are SynapseAI, performing a critical safety review... [PROMPT OMITTED FOR BREVITY]""" # Use full prompt
|
|
|
251 |
reflection_messages = [SystemMessage(content="Perform focused safety review based on interaction warnings."), HumanMessage(content=reflection_prompt_text)];
|
252 |
print("Invoking LLM for reflection...");
|
253 |
try: reflection_response = llm.invoke(reflection_messages); print(f"Reflection Response: {reflection_response.content}"); final_ai_message = AIMessage(content=reflection_response.content);
|
254 |
except Exception as e: print(f"ERROR during reflection: {e}"); traceback.print_exc(); final_ai_message = AIMessage(content=f"Error during safety reflection: {e}");
|
255 |
return {"messages": [final_ai_message], "interaction_warnings": None} # Return reflection response, clear warnings
|
256 |
|
|
|
257 |
# --- Graph Routing Logic ---
|
258 |
def should_continue(state: AgentState) -> str:
|
259 |
+
# ... (Keep implementation) ...
|
260 |
print("\n---ROUTING DECISION (Agent Output)---"); last_message = state['messages'][-1] if state['messages'] else None;
|
261 |
if not isinstance(last_message, AIMessage): return "end_conversation_turn";
|
262 |
if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn";
|
263 |
if getattr(last_message, 'tool_calls', None): return "continue_tools"; else: return "end_conversation_turn";
|
|
|
264 |
def after_tools_router(state: AgentState) -> str:
|
265 |
+
# ... (Keep implementation) ...
|
266 |
print("\n---ROUTING DECISION (After Tools)---");
|
267 |
if state.get("interaction_warnings"): print("Routing: Warnings found -> Reflection"); return "reflect_on_warnings";
|
268 |
else: print("Routing: No warnings -> Agent"); return "continue_to_agent";
|
|
|
270 |
# --- ClinicalAgent Class ---
|
271 |
class ClinicalAgent:
|
272 |
def __init__(self):
|
273 |
+
# ... (Keep graph compilation) ...
|
274 |
+
workflow = StateGraph(AgentState); workflow.add_node("agent", agent_node); workflow.add_node("tools", tool_node); workflow.add_node("reflection", reflection_node)
|
275 |
+
workflow.set_entry_point("agent"); workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
|
|
|
|
|
|
|
276 |
workflow.add_conditional_edges("tools", after_tools_router, {"reflect_on_warnings": "reflection", "continue_to_agent": "agent"})
|
277 |
+
workflow.add_edge("reflection", "agent"); self.graph_app = workflow.compile(); print("ClinicalAgent initialized and LangGraph compiled.")
|
|
|
|
|
|
|
278 |
def invoke_turn(self, state: Dict) -> Dict:
|
279 |
+
# ... (Keep implementation) ...
|
280 |
+
print(f"Invoking graph with state keys: {state.keys()}");
|
281 |
+
try: final_state = self.graph_app.invoke(state, {"recursion_limit": 15}); final_state.setdefault('summary', state.get('summary')); final_state.setdefault('interaction_warnings', None); return final_state
|
282 |
+
except Exception as e: print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}"); traceback.print_exc(); error_msg = AIMessage(content=f"Sorry, error occurred: {e}"); return {"messages": state.get('messages', []) + [error_msg], "patient_data": state.get('patient_data'), "summary": state.get('summary'), "interaction_warnings": None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|