mgbam committed on
Commit 7bcacfa · verified · 1 Parent(s): 4b23857

Update agent.py

Files changed (1)
  1. agent.py +60 -125
agent.py CHANGED
@@ -18,7 +18,6 @@ from langgraph.graph import StateGraph, END
18
  from typing import Optional, List, Dict, Any, TypedDict, Annotated
19
 
20
  # --- Environment Variable Loading ---
21
- # Keys are primarily used here, but checked in app.py for UI feedback
22
  UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
23
  GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
24
  TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
@@ -29,59 +28,21 @@ AGENT_TEMPERATURE = 0.1
29
  MAX_SEARCH_RESULTS = 3
30
 
31
  class ClinicalPrompts:
32
- # The comprehensive system prompt defining agent behavior
33
  SYSTEM_PROMPT = """
34
- You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
35
- Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.
36
-
37
- **Core Directives for this Conversation:**
38
- 1. **Analyze Sequentially:** Process information turn-by-turn. Base your responses on the *entire* conversation history.
39
- 2. **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
40
- 3. **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions, guideline searches), provide a comprehensive assessment using the following JSON structure. Output this JSON structure as the primary content of your response when you are providing the full analysis. Do NOT output incomplete JSON. If you need to ask a question or perform a tool call first, do that instead of outputting this structure.
41
- ```json
42
- {
43
- "assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
44
- "differential_diagnosis": [
45
- {"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence from conversation..."},
46
- {"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
47
- {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
48
- ],
49
- "risk_assessment": {
50
- "identified_red_flags": ["List any triggered red flags based on input and analysis"],
51
- "immediate_concerns": ["Specific urgent issues requiring attention (e.g., sepsis risk, ACS rule-out)"],
52
- "potential_complications": ["Possible future issues based on presentation"]
53
- },
54
- "recommended_plan": {
55
- "investigations": ["List specific lab tests or imaging required. Use 'order_lab_test' tool."],
56
- "therapeutics": ["Suggest specific treatments or prescriptions. Use 'prescribe_medication' tool. MUST check interactions first using 'check_drug_interactions'."],
57
- "consultations": ["Recommend specialist consultations if needed."],
58
- "patient_education": ["Key points for patient communication."]
59
- },
60
- "rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.** Include summary of guideline findings here.",
61
- "interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
62
- }
63
- ```
64
- 4. **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions` in a preceding or concurrent tool call. Report the findings from the interaction check. If significant interactions exist, modify the plan or state the contraindication clearly.
65
- 5. **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point in the conversation.
66
- 6. **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription in the structured JSON).
67
- 7. **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
68
- 8. **Conciseness & Flow:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation (asking questions, acknowledging info) until ready for the full structured JSON output.
69
  """
70
 
71
  # --- API Constants & Helper Functions ---
72
- UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
73
- RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
74
- OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
75
-
76
  @lru_cache(maxsize=256)
77
  def get_rxcui(drug_name: str) -> Optional[str]:
78
- """Uses RxNorm API to find the RxCUI for a given drug name."""
79
  if not drug_name or not isinstance(drug_name, str): return None; drug_name = drug_name.strip();
80
  if not drug_name: return None; print(f"RxNorm Lookup for: '{drug_name}'");
81
- try: # Try direct lookup first
82
  params = {"name": drug_name, "search": 1}; response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
83
  if data and "idGroup" in data and "rxnormId" in data["idGroup"]: rxcui = data["idGroup"]["rxnormId"][0]; print(f" Found RxCUI: {rxcui} for '{drug_name}'"); return rxcui
84
- else: # Fallback to /drugs search
85
  params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
86
  if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
87
  for group in data["drugGroup"]["conceptGroup"]:
@@ -95,7 +56,7 @@ def get_rxcui(drug_name: str) -> Optional[str]:
95
 
96
  @lru_cache(maxsize=128)
97
  def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
98
- """Fetches drug label information from OpenFDA using RxCUI or drug name."""
99
  if not rxcui and not drug_name: return None; print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}"); search_terms = []
100
  if rxcui: search_terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
101
  if drug_name: search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
@@ -109,7 +70,7 @@ def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = No
109
  except Exception as e: print(f" Unexpected error in get_openfda_label: {e}"); return None
110
 
111
  def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
112
- """ Case-insensitive search for any search_term within a list of text strings. Returns snippets. """
113
  found_snippets = [];
114
  if not text_list or not search_terms: return found_snippets; search_terms_lower = [str(term).lower() for term in search_terms if term];
115
  for text_item in text_list:
@@ -125,22 +86,14 @@ def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) ->
125
 
126
  # --- Clinical Helper Functions ---
127
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
128
- """Parses BP string like '120/80' into (systolic, diastolic) integers."""
129
- if not isinstance(bp_string, str): return None
130
- match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
131
- if match: return int(match.group(1)), int(match.group(2))
132
- return None
133
 
134
  def check_red_flags(patient_data: dict) -> List[str]:
135
- """Checks patient data against predefined red flags."""
136
- flags = []
137
- if not patient_data: return flags
138
- symptoms = patient_data.get("hpi", {}).get("symptoms", [])
139
- vitals = patient_data.get("vitals", {})
140
- history = patient_data.get("pmh", {}).get("conditions", "")
141
- symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)]
142
-
143
- # Symptom Flags
144
  if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
145
  if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
146
  if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
@@ -148,40 +101,38 @@ def check_red_flags(patient_data: dict) -> List[str]:
148
  if "weakness on one side" in symptoms_lower: flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
149
  if "hemoptysis" in symptoms_lower: flags.append("Red Flag: Hemoptysis (coughing up blood).")
150
  if "syncope" in symptoms_lower: flags.append("Red Flag: Syncope (fainting).")
 
 
 
 
 
 
 
 
 
 
151
 
152
- # Vital Sign Flags
153
- if vitals:
154
- temp = vitals.get("temp_c"); hr = vitals.get("hr_bpm"); rr = vitals.get("rr_rpm")
155
- spo2 = vitals.get("spo2_percent"); bp_str = vitals.get("bp_mmhg")
156
- if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever ({temp}°C).")
157
- if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
158
- if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
159
- if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
160
- if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia ({spo2}%).")
161
- if bp_str:
162
- bp = parse_bp(bp_str)
163
- if bp:
164
- if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
165
- if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")
166
-
167
- # History Flags
168
- if history and isinstance(history, str):
169
- history_lower = history.lower()
170
- if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
171
- if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
172
-
173
- return list(set(flags)) # Unique flags
174
-
175
  def format_patient_data_for_prompt(data: dict) -> str:
176
  """Formats the patient dictionary into a readable string for the LLM."""
177
- if not data: return "No patient data provided."; prompt_str = "";
178
- for key, value in data.items(): section_title = key.replace('_', ' ').title();
179
- if isinstance(value, dict) and value: has_content = any(sub_value for sub_value in value.values());
180
- if has_content: prompt_str += f"**{section_title}:**\n";
181
- for sub_key, sub_value in value.items():
182
- if sub_value: prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
183
- elif isinstance(value, list) and value: prompt_str += f"**{section_title}:** {', '.join(map(str, value))}\n"
184
- elif value and not isinstance(value, dict): prompt_str += f"**{section_title}:** {value}\n";
185
  return prompt_str.strip()
186
 
187
 
@@ -199,6 +150,7 @@ def prescribe_medication(medication_name: str, dosage: str, route: str, frequenc
199
  print(f"Executing prescribe_medication: {medication_name} {dosage}..."); return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
200
  @tool("check_drug_interactions", args_schema=InteractionCheckInput)
201
  def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
 
202
  print(f"\n--- Executing REAL check_drug_interactions ---"); print(f"Checking potential prescription: '{potential_prescription}'"); warnings = []; potential_med_lower = potential_prescription.lower().strip();
203
  current_meds_list = current_medications or []; allergies_list = allergies or []; current_med_names_lower = [];
204
  for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", str(med));
@@ -233,28 +185,24 @@ def check_drug_interactions(potential_prescription: str, current_medications: Op
233
  return json.dumps({"status": status, "message": message, "warnings": final_warnings})
234
  @tool("flag_risk", args_schema=FlagRiskInput)
235
  def flag_risk(risk_description: str, urgency: str) -> str:
236
- print(f"Executing flag_risk: {risk_description}, Urgency: {urgency}"); # UI part in app.py
237
- return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
238
  search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
239
  all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
240
 
241
  # --- LangGraph State & Nodes ---
242
  class AgentState(TypedDict): messages: Annotated[list[Any], operator.add]; patient_data: Optional[dict]; summary: Optional[str]; interaction_warnings: Optional[List[str]]
243
-
244
- llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
245
- model_with_tools = llm.bind_tools(all_tools)
246
- tool_executor = ToolExecutor(all_tools)
247
-
248
  def agent_node(state: AgentState):
 
249
  print("\n---AGENT NODE---"); current_messages = state['messages'];
250
  if not current_messages or not isinstance(current_messages[0], SystemMessage): print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages;
251
  print(f"Invoking LLM with {len(current_messages)} messages.");
252
  try: response = model_with_tools.invoke(current_messages); print(f"Agent Raw Response Type: {type(response)}");
253
  if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}"); else: print("Agent Response: No tool calls.");
254
  except Exception as e: print(f"ERROR in agent_node: {e}"); traceback.print_exc(); error_message = AIMessage(content=f"Error: {e}"); return {"messages": [error_message]};
255
- return {"messages": [response]} # Only return messages
256
-
257
  def tool_node(state: AgentState):
 
258
  print("\n---TOOL NODE---"); tool_messages = []; last_message = state['messages'][-1]; interaction_warnings_found = [];
259
  if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None): print("Warning: Tool node called unexpectedly."); return {"messages": [], "interaction_warnings": None};
260
  tool_calls = last_message.tool_calls; print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}"); prescriptions_requested = {}; interaction_checks_requested = {};
@@ -288,6 +236,7 @@ def tool_node(state: AgentState):
288
  return {"messages": tool_messages, "interaction_warnings": interaction_warnings_found or None} # Return messages AND warnings
289
 
290
  def reflection_node(state: AgentState):
 
291
  print("\n---REFLECTION NODE---")
292
  interaction_warnings = state.get("interaction_warnings")
293
  if not interaction_warnings: print("Warning: Reflection node called without warnings."); return {"messages": [], "interaction_warnings": None};
@@ -298,23 +247,22 @@ def reflection_node(state: AgentState):
298
  if any(tc['id'] in relevant_tool_call_ids for tc in msg.tool_calls): triggering_ai_message = msg; break;
299
  if not triggering_ai_message: print("Error: Could not find triggering AI message for reflection."); return {"messages": [AIMessage(content="Internal Error: Reflection context missing.")], "interaction_warnings": None};
300
  original_plan_proposal_context = triggering_ai_message.content;
301
- reflection_prompt_text = f"""You are SynapseAI, performing a critical safety review...
302
- Previous Context:\n{original_plan_proposal_context}\n---\nInteraction Warnings:\n```json\n{json.dumps(interaction_warnings, indent=2)}\n```\n**CRITICAL REFLECTION STEP:** Analyze warnings, decide if revision is needed, respond ONLY about therapeutics revision based on these warnings."""
303
  reflection_messages = [SystemMessage(content="Perform focused safety review based on interaction warnings."), HumanMessage(content=reflection_prompt_text)];
304
  print("Invoking LLM for reflection...");
305
  try: reflection_response = llm.invoke(reflection_messages); print(f"Reflection Response: {reflection_response.content}"); final_ai_message = AIMessage(content=reflection_response.content);
306
  except Exception as e: print(f"ERROR during reflection: {e}"); traceback.print_exc(); final_ai_message = AIMessage(content=f"Error during safety reflection: {e}");
307
  return {"messages": [final_ai_message], "interaction_warnings": None} # Return reflection response, clear warnings
308
 
309
-
310
  # --- Graph Routing Logic ---
311
  def should_continue(state: AgentState) -> str:
 
312
  print("\n---ROUTING DECISION (Agent Output)---"); last_message = state['messages'][-1] if state['messages'] else None;
313
  if not isinstance(last_message, AIMessage): return "end_conversation_turn";
314
  if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn";
315
  if getattr(last_message, 'tool_calls', None): return "continue_tools"; else: return "end_conversation_turn";
316
-
317
  def after_tools_router(state: AgentState) -> str:
 
318
  print("\n---ROUTING DECISION (After Tools)---");
319
  if state.get("interaction_warnings"): print("Routing: Warnings found -> Reflection"); return "reflect_on_warnings";
320
  else: print("Routing: No warnings -> Agent"); return "continue_to_agent";
@@ -322,26 +270,13 @@ def after_tools_router(state: AgentState) -> str:
322
  # --- ClinicalAgent Class ---
323
  class ClinicalAgent:
324
  def __init__(self):
325
- workflow = StateGraph(AgentState)
326
- workflow.add_node("agent", agent_node)
327
- workflow.add_node("tools", tool_node)
328
- workflow.add_node("reflection", reflection_node)
329
- workflow.set_entry_point("agent")
330
- workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
331
  workflow.add_conditional_edges("tools", after_tools_router, {"reflect_on_warnings": "reflection", "continue_to_agent": "agent"})
332
- workflow.add_edge("reflection", "agent")
333
- self.graph_app = workflow.compile()
334
- print("ClinicalAgent initialized and LangGraph compiled.")
335
-
336
  def invoke_turn(self, state: Dict) -> Dict:
337
- """Invokes the LangGraph app for one turn."""
338
- print(f"Invoking graph with state keys: {state.keys()}")
339
- try:
340
- final_state = self.graph_app.invoke(state, {"recursion_limit": 15})
341
- final_state.setdefault('summary', state.get('summary')) # Ensure keys exist
342
- final_state.setdefault('interaction_warnings', None)
343
- return final_state
344
- except Exception as e:
345
- print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}"); traceback.print_exc();
346
- error_msg = AIMessage(content=f"Sorry, a critical error occurred during processing: {e}");
347
- return {"messages": state.get('messages', []) + [error_msg], "patient_data": state.get('patient_data'), "summary": state.get('summary'), "interaction_warnings": None}
 
18
  from typing import Optional, List, Dict, Any, TypedDict, Annotated
19
 
20
  # --- Environment Variable Loading ---
 
21
  UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
22
  GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
23
  TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
 
28
  MAX_SEARCH_RESULTS = 3
29
 
30
  class ClinicalPrompts:
 
31
  SYSTEM_PROMPT = """
32
+ You are SynapseAI, an expert AI clinical assistant... [SYSTEM PROMPT OMITTED FOR BREVITY]
33
  """
34
 
35
  # --- API Constants & Helper Functions ---
36
+ UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"; RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"; OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
 
 
 
37
  @lru_cache(maxsize=256)
38
  def get_rxcui(drug_name: str) -> Optional[str]:
39
+ # ... (Keep implementation) ...
40
  if not drug_name or not isinstance(drug_name, str): return None; drug_name = drug_name.strip();
41
  if not drug_name: return None; print(f"RxNorm Lookup for: '{drug_name}'");
42
+ try:
43
  params = {"name": drug_name, "search": 1}; response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
44
  if data and "idGroup" in data and "rxnormId" in data["idGroup"]: rxcui = data["idGroup"]["rxnormId"][0]; print(f" Found RxCUI: {rxcui} for '{drug_name}'"); return rxcui
45
+ else:
46
  params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
47
  if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
48
  for group in data["drugGroup"]["conceptGroup"]:
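The hunk is truncated here and the committed body is collapsed onto single lines, so for reference this is a standalone sketch of the direct RxNorm lookup it performs (the `rxcui.json` call only, without the `/drugs.json` fallback shown above; the function name is illustrative):

```python
import requests
from typing import Optional

RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"  # same constant as in the diff

def lookup_rxcui(drug_name: str) -> Optional[str]:
    """Direct lookup: GET /rxcui.json?name=<drug>&search=1, return first RxCUI or None."""
    name = (drug_name or "").strip()
    if not name:
        return None
    resp = requests.get(f"{RXNORM_API_BASE}/rxcui.json",
                        params={"name": name, "search": 1}, timeout=10)
    resp.raise_for_status()
    ids = resp.json().get("idGroup", {}).get("rxnormId", [])
    return ids[0] if ids else None
```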
 
56
 
57
  @lru_cache(maxsize=128)
58
  def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
59
+ # ... (Keep implementation) ...
60
  if not rxcui and not drug_name: return None; print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}"); search_terms = []
61
  if rxcui: search_terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
62
  if drug_name: search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
 
70
  except Exception as e: print(f" Unexpected error in get_openfda_label: {e}"); return None
71
 
72
  def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
73
+ # ... (Keep implementation) ...
74
  found_snippets = [];
75
  if not text_list or not search_terms: return found_snippets; search_terms_lower = [str(term).lower() for term in search_terms if term];
76
  for text_item in text_list:
 
86
 
87
  # --- Clinical Helper Functions ---
88
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
89
+ # ... (Keep implementation) ...
90
+ if not isinstance(bp_string, str): return None; match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip());
91
+ if match: return int(match.group(1)), int(match.group(2)); return None
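As committed, these two `parse_bp` one-liners do not do what they intend: the `match = re.match(...)` assignment sits inside the `isinstance` guard together with `return None`, so `match` is never bound for a valid string and the second line raises `NameError`. The multi-line version removed earlier in this diff is the working form, reproduced here with only the imports added:

```python
import re
from typing import Optional

def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
    """Parses a BP string like '120/80' into (systolic, diastolic) integers."""
    if not isinstance(bp_string, str):
        return None
    match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
    if match:
        return int(match.group(1)), int(match.group(2))
    return None
```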
 
 
92
 
93
  def check_red_flags(patient_data: dict) -> List[str]:
94
+ # ... (Keep implementation with multi-line ifs) ...
95
+ flags = [];
96
+ if not patient_data: return flags; symptoms = patient_data.get("hpi", {}).get("symptoms", []); vitals = patient_data.get("vitals", {}); history = patient_data.get("pmh", {}).get("conditions", ""); symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)];
 
97
  if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
98
  if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
99
  if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
 
101
  if "weakness on one side" in symptoms_lower: flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
102
  if "hemoptysis" in symptoms_lower: flags.append("Red Flag: Hemoptysis (coughing up blood).")
103
  if "syncope" in symptoms_lower: flags.append("Red Flag: Syncope (fainting).")
104
+ if vitals: temp = vitals.get("temp_c"); hr = vitals.get("hr_bpm"); rr = vitals.get("rr_rpm"); spo2 = vitals.get("spo2_percent"); bp_str = vitals.get("bp_mmhg");
105
+ if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever ({temp}°C)."); if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia ({hr} bpm)."); if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia ({hr} bpm)."); if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea ({rr} rpm)."); if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia ({spo2}%).");
106
+ if bp_str: bp = parse_bp(bp_str);
107
+ if bp:
108
+ if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).");
109
+ if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).");
110
+ if history and isinstance(history, str): history_lower = history.lower();
111
+ if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.");
112
+ if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.");
113
+ return list(set(flags))
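The semicolon-chained `if` statements in the vitals and history lines above are not valid Python as written (a compound `if` cannot follow a `;` on the same line); the removed multi-line version earlier in this diff shows the intended checks. For orientation, the patient dictionary the function reads looks roughly like this (shape inferred from its `.get(...)` calls; values are illustrative only):

```python
# Illustrative input for check_red_flags, not a schema from the repo.
sample_patient_data = {
    "hpi": {"symptoms": ["chest pain", "shortness of breath"]},
    "vitals": {
        "temp_c": 38.7,        # >= 38.5 triggers the fever flag
        "hr_bpm": 124,         # >= 120 triggers tachycardia
        "rr_rpm": 26,          # >= 24 triggers tachypnea
        "spo2_percent": 91,    # <= 92 triggers hypoxia
        "bp_mmhg": "186/112",  # parsed by parse_bp; >= 180 systolic or >= 110 diastolic flags hypertension
    },
    "pmh": {"conditions": "History of MI, hypertension"},
}
# With the working implementation, this input should yield flags for chest pain,
# shortness of breath, fever, tachycardia, tachypnea, hypoxia, hypertensive
# urgency/emergency, and history of MI with current chest pain.
```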
114
 
115
+ # CORRECTED format_patient_data_for_prompt function indentation
 
116
  def format_patient_data_for_prompt(data: dict) -> str:
117
  """Formats the patient dictionary into a readable string for the LLM."""
118
+ if not data: return "No patient data provided."
119
+ prompt_str = ""
120
+ for key, value in data.items():
121
+ section_title = key.replace('_', ' ').title()
122
+ # Check if the value is a dictionary and has content
123
+ if isinstance(value, dict) and value:
124
+ has_content = any(sub_value for sub_value in value.values())
125
+ if has_content:
126
+ prompt_str += f"**{section_title}:**\n"
127
+ for sub_key, sub_value in value.items():
128
+ if sub_value: # Only add if sub-value is truthy
129
+ prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
130
+ # Check if the value is a non-empty list
131
+ elif isinstance(value, list) and value: # <-- Correct indentation
132
+ prompt_str += f"**{section_title}:** {', '.join(map(str, value))}\n"
133
+ # Check if the value is truthy and not a dictionary (handles strings, numbers, etc.)
134
+ elif value and not isinstance(value, dict): # <-- Correct indentation
135
+ prompt_str += f"**{section_title}:** {value}\n"
136
  return prompt_str.strip()
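A short usage sketch of the corrected function above (output reconstructed from the f-strings in the code):

```python
data = {
    "demographics": {"age": 62, "sex": "male"},
    "hpi": {"symptoms": ["chest pain", "diaphoresis"]},
    "current_medications": ["Aspirin 81mg daily"],
    "allergies": [],  # empty sections are skipped
}
print(format_patient_data_for_prompt(data))
# **Demographics:**
#  - Age: 62
#  - Sex: male
# **Hpi:**
#  - Symptoms: ['chest pain', 'diaphoresis']
# **Current Medications:** Aspirin 81mg daily
```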
137
 
138
 
 
150
  print(f"Executing prescribe_medication: {medication_name} {dosage}..."); return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
151
  @tool("check_drug_interactions", args_schema=InteractionCheckInput)
152
  def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
153
+ # ... (Keep the FULL implementation of the NEW check_drug_interactions using API helpers) ...
154
  print(f"\n--- Executing REAL check_drug_interactions ---"); print(f"Checking potential prescription: '{potential_prescription}'"); warnings = []; potential_med_lower = potential_prescription.lower().strip();
155
  current_meds_list = current_medications or []; allergies_list = allergies or []; current_med_names_lower = [];
156
  for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", str(med));
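The hunk is cut off here; as a small aside, the regex on the line above keeps only the leading alphabetic token of each current-medication entry, which is what the interaction check compares against:

```python
import re

for med in ["Lisinopril 10mg daily", "  metformin 500 mg BID", "Aspirin"]:
    match = re.match(r"^\s*([a-zA-Z\-]+)", med)
    print(match.group(1).lower() if match else None)
# lisinopril
# metformin
# aspirin
```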
 
185
  return json.dumps({"status": status, "message": message, "warnings": final_warnings})
186
  @tool("flag_risk", args_schema=FlagRiskInput)
187
  def flag_risk(risk_description: str, urgency: str) -> str:
188
+ print(f"Executing flag_risk: {risk_description}, Urgency: {urgency}"); return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
 
189
  search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
190
  all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
191
 
192
  # --- LangGraph State & Nodes ---
193
  class AgentState(TypedDict): messages: Annotated[list[Any], operator.add]; patient_data: Optional[dict]; summary: Optional[str]; interaction_warnings: Optional[List[str]]
194
+ llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME); model_with_tools = llm.bind_tools(all_tools); tool_executor = ToolExecutor(all_tools)
 
 
195
  def agent_node(state: AgentState):
196
+ # ... (Keep implementation) ...
197
  print("\n---AGENT NODE---"); current_messages = state['messages'];
198
  if not current_messages or not isinstance(current_messages[0], SystemMessage): print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages;
199
  print(f"Invoking LLM with {len(current_messages)} messages.");
200
  try: response = model_with_tools.invoke(current_messages); print(f"Agent Raw Response Type: {type(response)}");
201
  if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}"); else: print("Agent Response: No tool calls.");
202
  except Exception as e: print(f"ERROR in agent_node: {e}"); traceback.print_exc(); error_message = AIMessage(content=f"Error: {e}"); return {"messages": [error_message]};
203
+ return {"messages": [response]}
 
204
  def tool_node(state: AgentState):
205
+ # ... (Keep implementation) ...
206
  print("\n---TOOL NODE---"); tool_messages = []; last_message = state['messages'][-1]; interaction_warnings_found = [];
207
  if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None): print("Warning: Tool node called unexpectedly."); return {"messages": [], "interaction_warnings": None};
208
  tool_calls = last_message.tool_calls; print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}"); prescriptions_requested = {}; interaction_checks_requested = {};
 
236
  return {"messages": tool_messages, "interaction_warnings": interaction_warnings_found or None} # Return messages AND warnings
237
 
238
  def reflection_node(state: AgentState):
239
+ # ... (Keep implementation) ...
240
  print("\n---REFLECTION NODE---")
241
  interaction_warnings = state.get("interaction_warnings")
242
  if not interaction_warnings: print("Warning: Reflection node called without warnings."); return {"messages": [], "interaction_warnings": None};
 
247
  if any(tc['id'] in relevant_tool_call_ids for tc in msg.tool_calls): triggering_ai_message = msg; break;
248
  if not triggering_ai_message: print("Error: Could not find triggering AI message for reflection."); return {"messages": [AIMessage(content="Internal Error: Reflection context missing.")], "interaction_warnings": None};
249
  original_plan_proposal_context = triggering_ai_message.content;
250
+ reflection_prompt_text = f"""You are SynapseAI, performing a critical safety review... [PROMPT OMITTED FOR BREVITY]""" # Use full prompt
 
251
  reflection_messages = [SystemMessage(content="Perform focused safety review based on interaction warnings."), HumanMessage(content=reflection_prompt_text)];
252
  print("Invoking LLM for reflection...");
253
  try: reflection_response = llm.invoke(reflection_messages); print(f"Reflection Response: {reflection_response.content}"); final_ai_message = AIMessage(content=reflection_response.content);
254
  except Exception as e: print(f"ERROR during reflection: {e}"); traceback.print_exc(); final_ai_message = AIMessage(content=f"Error during safety reflection: {e}");
255
  return {"messages": [final_ai_message], "interaction_warnings": None} # Return reflection response, clear warnings
256
 
 
257
  # --- Graph Routing Logic ---
258
  def should_continue(state: AgentState) -> str:
259
+ # ... (Keep implementation) ...
260
  print("\n---ROUTING DECISION (Agent Output)---"); last_message = state['messages'][-1] if state['messages'] else None;
261
  if not isinstance(last_message, AIMessage): return "end_conversation_turn";
262
  if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn";
263
  if getattr(last_message, 'tool_calls', None): return "continue_tools"; else: return "end_conversation_turn";
 
264
  def after_tools_router(state: AgentState) -> str:
265
+ # ... (Keep implementation) ...
266
  print("\n---ROUTING DECISION (After Tools)---");
267
  if state.get("interaction_warnings"): print("Routing: Warnings found -> Reflection"); return "reflect_on_warnings";
268
  else: print("Routing: No warnings -> Agent"); return "continue_to_agent";
 
270
  # --- ClinicalAgent Class ---
271
  class ClinicalAgent:
272
  def __init__(self):
273
+ # ... (Keep graph compilation) ...
274
+ workflow = StateGraph(AgentState); workflow.add_node("agent", agent_node); workflow.add_node("tools", tool_node); workflow.add_node("reflection", reflection_node)
275
+ workflow.set_entry_point("agent"); workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
 
 
 
276
  workflow.add_conditional_edges("tools", after_tools_router, {"reflect_on_warnings": "reflection", "continue_to_agent": "agent"})
277
+ workflow.add_edge("reflection", "agent"); self.graph_app = workflow.compile(); print("ClinicalAgent initialized and LangGraph compiled.")
 
 
 
278
  def invoke_turn(self, state: Dict) -> Dict:
279
+ # ... (Keep implementation) ...
280
+ print(f"Invoking graph with state keys: {state.keys()}");
281
+ try: final_state = self.graph_app.invoke(state, {"recursion_limit": 15}); final_state.setdefault('summary', state.get('summary')); final_state.setdefault('interaction_warnings', None); return final_state
282
+ except Exception as e: print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}"); traceback.print_exc(); error_msg = AIMessage(content=f"Sorry, error occurred: {e}"); return {"messages": state.get('messages', []) + [error_msg], "patient_data": state.get('patient_data'), "summary": state.get('summary'), "interaction_warnings": None}