mgbam commited on
Commit
86911ce
·
verified ·
1 Parent(s): 9cecafe

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +434 -303
agent.py CHANGED
@@ -1,308 +1,439 @@
1
- # app.py
2
- import streamlit as st
3
- import json
4
- import re
5
  import os
 
 
 
6
  import traceback
7
- from dotenv import load_dotenv
8
-
9
- # Import agent logic and message types from agent.py
10
- try:
11
- from agent import ClinicalAgent, AgentState, check_red_flags
12
- from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
13
- except ImportError as e:
14
- st.error(f"Failed to import from agent.py: {e}. Make sure agent.py is in the same directory.")
15
- st.stop()
16
-
17
- # --- Environment Variable Loading & Validation ---
18
- load_dotenv()
19
- # Check keys required by agent.py are present before initializing the agent
20
- UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
21
- GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
22
- TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
23
- missing_keys = []
24
- if not UMLS_API_KEY: missing_keys.append("UMLS_API_KEY")
25
- if not GROQ_API_KEY: missing_keys.append("GROQ_API_KEY")
26
- if not TAVILY_API_KEY: missing_keys.append("TAVILY_API_KEY")
27
- if missing_keys:
28
- st.error(f"Missing required API Key(s): {', '.join(missing_keys)}. Please set them in Hugging Face Space Secrets or environment variables.")
29
- st.stop()
30
-
31
- # --- App Configuration ---
32
- class ClinicalAppSettings:
33
- APP_TITLE = "SynapseAI (UMLS/FDA Integrated)"
34
- PAGE_LAYOUT = "wide"
35
- MODEL_NAME_DISPLAY = "Llama3-70b (via Groq)" # Defined in agent.py
36
-
37
- # --- Streamlit UI ---
38
- def main():
39
- st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
40
- st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
41
- st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME_DISPLAY}")
42
-
43
- # Initialize session state
44
- if "messages" not in st.session_state:
45
- st.session_state.messages = []
46
- if "patient_data" not in st.session_state:
47
- st.session_state.patient_data = None
48
- if "summary" not in st.session_state:
49
- st.session_state.summary = None
50
-
51
- # Initialize the agent instance only once
52
- if "agent" not in st.session_state:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  try:
54
- st.session_state.agent = ClinicalAgent()
55
- print("ClinicalAgent successfully initialized in Streamlit session state.")
 
 
56
  except Exception as e:
57
- st.error(f"Failed to initialize Clinical Agent: {e}. Check API keys and dependencies.")
58
- print(f"ERROR Initializing ClinicalAgent: {e}")
59
- traceback.print_exc()
60
- st.stop()
61
-
62
- # --- Patient Data Input Sidebar ---
63
- with st.sidebar:
64
- st.header("📄 Patient Intake Form")
65
- # Input fields... (Using shorter versions for brevity, assume full fields are here)
66
- st.subheader("Demographics")
67
- age = st.number_input("Age", 0, 120, 55, key="sb_age")
68
- sex = st.selectbox("Sex", ["Male", "Female", "Other"], key="sb_sex")
69
-
70
- st.subheader("HPI")
71
- chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="sb_cc")
72
- hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100, key="sb_hpi")
73
- symptoms = st.multiselect(
74
- "Symptoms",
75
- ["Nausea", "Diaphoresis", "SOB", "Dizziness", "Severe Headache", "Syncope", "Hemoptysis"],
76
- default=["Nausea", "Diaphoresis"],
77
- key="sb_sym"
78
- )
79
-
80
- st.subheader("History")
81
- pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI", key="sb_pmh")
82
- psh = st.text_area("PSH", "Appendectomy", key="sb_psh")
83
-
84
- st.subheader("Meds & Allergies")
85
- current_meds_str = st.text_area(
86
- "Current Meds",
87
- "Lisinopril 10mg daily\nMetformin 1000mg BID\nWarfarin 5mg daily",
88
- key="sb_meds"
89
- )
90
- allergies_str = st.text_area("Allergies", "Penicillin (rash), Aspirin", key="sb_allergies")
91
-
92
- st.subheader("Social/Family")
93
- social_history = st.text_area("SH", "Smoker", key="sb_sh")
94
- family_history = st.text_area("FHx", "Father MI", key="sb_fhx")
95
-
96
- st.subheader("Vitals & Exam")
97
- col1, col2 = st.columns(2)
98
- with col1:
99
- temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f", key="sb_temp")
100
- hr_bpm = st.number_input("HR", 30, 250, 95, key="sb_hr")
101
- rr_rpm = st.number_input("RR", 5, 50, 18, key="sb_rr")
102
- with col2:
103
- bp_mmhg = st.text_input("BP", "155/90", key="sb_bp")
104
- spo2_percent = st.number_input("SpO2", 70, 100, 96, key="sb_spo2")
105
- pain_scale = st.slider("Pain", 0, 10, 8, key="sb_pain")
106
- exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=68, key="sb_exam")
107
-
108
- if st.button("Start/Update Consultation", key="sb_start"):
109
- # Compile data...
110
- current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
111
- current_med_names_only = []
112
- for med in current_meds_list:
113
- match = re.match(r"^\s*([a-zA-Z\-]+)", med)
114
- if match:
115
- current_med_names_only.append(match.group(1).lower())
116
-
117
- allergies_list = []
118
- for a in allergies_str.split(','):
119
- cleaned_allergy = a.strip()
120
- if cleaned_allergy:
121
- match = re.match(r"^\s*([a-zA-Z\-\s/]+)(?:\s*\(.*\))?", cleaned_allergy)
122
- name_part = match.group(1).strip().lower() if match else cleaned_allergy.lower()
123
- allergies_list.append(name_part)
124
-
125
- # Update patient data in session state
126
- st.session_state.patient_data = {
127
- "demographics": {"age": age, "sex": sex},
128
- "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
129
- "pmh": {"conditions": pmh},
130
- "psh": {"procedures": psh},
131
- "medications": {"current": current_meds_list, "names_only": current_med_names_only},
132
- "allergies": allergies_list,
133
- "social_history": {"details": social_history},
134
- "family_history": {"details": family_history},
135
- "vitals": {
136
- "temp_c": temp_c,
137
- "hr_bpm": hr_bpm,
138
- "bp_mmhg": bp_mmhg,
139
- "rr_rpm": rr_rpm,
140
- "spo2_percent": spo2_percent,
141
- "pain_scale": pain_scale
142
- },
143
- "exam_findings": {"notes": exam_notes}
144
  }
145
-
146
- # Call check_red_flags from agent module
147
- red_flags = check_red_flags(st.session_state.patient_data)
148
- st.sidebar.markdown("---")
149
- if red_flags:
150
- st.sidebar.warning("**Initial Red Flags:**")
151
- for flag in red_flags:
152
- st.sidebar.warning(f"- {flag.replace('Red Flag: ', '')}")
153
- else:
154
- st.sidebar.success("No immediate red flags.")
155
-
156
- # Reset conversation and summary on new intake
157
- initial_prompt = "Initiate consultation. Review patient data and begin analysis."
158
- st.session_state.messages = [HumanMessage(content=initial_prompt)]
159
- st.session_state.summary = None # Reset summary
160
- st.success("Patient data loaded/updated.")
161
- # Rerun might be needed if the main area should clear or update based on new data
162
- st.rerun()
163
-
164
- # --- Main Chat Interface Area ---
165
- st.header("💬 Clinical Consultation")
166
- # Display loop
167
- for msg in st.session_state.messages:
168
- if isinstance(msg, HumanMessage):
169
- with st.chat_message("user"):
170
- st.markdown(msg.content)
171
- elif isinstance(msg, AIMessage):
172
- with st.chat_message("assistant"):
173
- ai_content = msg.content
174
- structured_output = None
175
- try:
176
- # JSON Parsing logic...
177
- json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
178
- if json_match:
179
- json_str = json_match.group(1)
180
- prefix = ai_content[:json_match.start()].strip()
181
- suffix = ai_content[json_match.end():].strip()
182
- if prefix:
183
- st.markdown(prefix)
184
- structured_output = json.loads(json_str)
185
- if suffix:
186
- st.markdown(suffix)
187
- elif ai_content.strip().startswith("{") and ai_content.strip().endswith("}"):
188
- structured_output = json.loads(ai_content)
189
- ai_content = ""
190
- else:
191
- st.markdown(ai_content) # Display non-JSON content
192
- except Exception as e:
193
- st.markdown(ai_content)
194
- print(f"Error parsing/displaying AI JSON: {e}")
195
-
196
- if structured_output and isinstance(structured_output, dict):
197
- # Structured JSON display logic...
198
- st.divider()
199
- st.subheader("📊 AI Analysis & Recommendations")
200
- cols = st.columns(2)
201
- with cols[0]:
202
- st.markdown("**Assessment:**")
203
- st.markdown(f"> {structured_output.get('assessment', 'N/A')}")
204
- st.markdown("**Differential Diagnosis:**")
205
- ddx = structured_output.get('differential_diagnosis', [])
206
- if ddx:
207
- for item in ddx:
208
- likelihood = item.get('likelihood', 'Low')
209
- medal = ('🥇' if likelihood.startswith('H') else '🥈' if likelihood.startswith('M') else '🥉')
210
- expander_title = f"{medal} {item.get('diagnosis', 'Unknown')} ({likelihood})"
211
- with st.expander(expander_title):
212
- st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
213
- else:
214
- st.info("No DDx provided.")
215
- st.markdown("**Risk Assessment:**")
216
- risk = structured_output.get('risk_assessment', {})
217
- flags = risk.get('identified_red_flags', [])
218
- concerns = risk.get('immediate_concerns', [])
219
- comps = risk.get('potential_complications', [])
220
- if flags:
221
- st.warning(f"**Flags:** {', '.join(flags)}")
222
- if concerns:
223
- st.warning(f"**Concerns:** {', '.join(concerns)}")
224
- if comps:
225
- st.info(f"**Potential Complications:** {', '.join(comps)}")
226
- if not flags and not concerns:
227
- st.success("No major risks highlighted.")
228
- with cols[1]:
229
- st.markdown("**Recommended Plan:**")
230
- plan = structured_output.get('recommended_plan', {})
231
- for section in ["investigations","therapeutics","consultations","patient_education"]:
232
- st.markdown(f"_{section.replace('_',' ').capitalize()}:_")
233
- items = plan.get(section)
234
- if items and isinstance(items, list):
235
- for it in items:
236
- st.markdown(f"- {it}")
237
- elif items:
238
- st.markdown(f"- {items}")
239
- else:
240
- st.markdown("_None_")
241
- st.markdown("")
242
- st.markdown("**Rationale & Guideline Check:**")
243
- st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
244
- interaction_summary = structured_output.get('interaction_check_summary', "")
245
- if interaction_summary:
246
- st.markdown("**Interaction Check Summary:**")
247
- st.markdown(f"> {interaction_summary}")
248
- st.divider()
249
-
250
- # Tool Call Display
251
- if getattr(msg, 'tool_calls', None):
252
- with st.expander("🛠️ AI requested actions", expanded=False):
253
- if msg.tool_calls:
254
- for tc in msg.tool_calls:
255
- try:
256
- st.code(
257
- f"Action: {tc.get('name', 'Unknown Tool')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}",
258
- language="json"
259
- )
260
- except Exception as display_e:
261
- st.error(f"Could not display tool call args: {display_e}", icon="⚠️")
262
- st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nRaw Args: {tc.get('args')}")
263
- else:
264
- st.caption("_No actions requested._")
265
-
266
- # --- Chat Input Logic ---
267
- if prompt := st.chat_input("Your message or follow-up query..."):
268
- if not st.session_state.patient_data:
269
- st.warning("Please load patient data first.")
270
- st.stop()
271
- if 'agent' not in st.session_state or not st.session_state.agent:
272
- st.error("Agent not initialized. Check logs.")
273
- st.stop()
274
-
275
- # Append user message and display immediately
276
- user_message = HumanMessage(content=prompt)
277
- st.session_state.messages.append(user_message)
278
- with st.chat_message("user"):
279
- st.markdown(prompt)
280
-
281
- # Prepare state for the agent
282
- current_state_dict = {
283
- "messages": st.session_state.messages,
284
- "patient_data": st.session_state.patient_data,
285
- "summary": st.session_state.get("summary"),
286
- "interaction_warnings": None # Start clean
287
- }
288
-
289
- # Invoke the agent's graph for one turn
290
- with st.spinner("SynapseAI is processing..."):
291
- try:
292
- final_state = st.session_state.agent.invoke_turn(current_state_dict)
293
- st.session_state.messages = final_state.get('messages', [])
294
- st.session_state.summary = final_state.get('summary')
295
- except Exception as e:
296
- print(f"CRITICAL ERROR during agent invocation: {type(e).__name__} - {e}")
297
- traceback.print_exc()
298
- st.error(f"An error occurred during processing: {e}", icon="❌")
299
- st.session_state.messages.append(AIMessage(content=f"Error during processing: {e}"))
300
-
301
- st.rerun()
302
-
303
- # Disclaimer
304
- st.markdown("---")
305
- st.warning("**Disclaimer:** SynapseAI is for demonstration...")
306
-
307
- if __name__ == "__main__":
308
- main()
 
1
+ # agent.py
2
+
 
 
3
  import os
4
+ import re
5
+ import json
6
+ import logging
7
  import traceback
8
+ from functools import lru_cache
9
+ from typing import List, Dict, Any, Optional, TypedDict
10
+
11
+ import requests
12
+ from langchain_groq import ChatGroq
13
+ from langchain_community.tools.tavily_search import TavilySearchResults
14
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
15
+ from langchain_core.pydantic_v1 import BaseModel, Field
16
+ from langchain_core.tools import tool
17
+ from langgraph.prebuilt import ToolExecutor
18
+ from langgraph.graph import StateGraph, END
19
+
20
+ # ── Logging Configuration ─────────────────────────────────────────────────────
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+ # ── Environment Variables ─────────────────────────────────────────────────────
25
+ UMLS_API_KEY = os.getenv("UMLS_API_KEY")
26
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
27
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
28
+
29
+ if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
30
+ logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
31
+ raise RuntimeError("Missing required API keys")
32
+
33
+ # ── Agent Configuration ───────────────────────────────────────────────────────
34
+ AGENT_MODEL_NAME = "llama3-70b-8192"
35
+ AGENT_TEMPERATURE = 0.1
36
+ MAX_SEARCH_RESULTS = 3
37
+
38
+ class ClinicalPrompts:
39
+ SYSTEM_PROMPT = """
40
+ You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation...
41
+ [SYSTEM PROMPT CONTENT HERE]
42
+ """
43
+
44
+ # ── Helper Functions ──────────────────────────────────────────────────────────
45
+ UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
46
+ RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
47
+ OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
48
+
49
+ @lru_cache(maxsize=256)
50
+ def get_rxcui(drug_name: str) -> Optional[str]:
51
+ """Lookup RxNorm CUI for a drug name."""
52
+ drug_name = (drug_name or "").strip()
53
+ if not drug_name:
54
+ return None
55
+ logger.info(f"Looking up RxCUI for '{drug_name}'")
56
+ try:
57
+ params = {"name": drug_name, "search": 1}
58
+ r = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
59
+ r.raise_for_status()
60
+ data = r.json().get("idGroup", {})
61
+ if ids := data.get("rxnormId"):
62
+ logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
63
+ return ids[0]
64
+ # fallback to broader search
65
+ r = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
66
+ r.raise_for_status()
67
+ for group in r.json().get("drugGroup", {}).get("conceptGroup", []):
68
+ if props := group.get("conceptProperties"):
69
+ logger.info(f"Found RxCUI {props[0]['rxcui']} via /drugs for '{drug_name}'")
70
+ return props[0]["rxcui"]
71
+ except Exception:
72
+ logger.exception(f"Error fetching RxCUI for '{drug_name}'")
73
+ return None
74
+
75
+ @lru_cache(maxsize=128)
76
+ def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
77
+ """Fetch label data from OpenFDA by RxCUI or drug name."""
78
+ if not (rxcui or drug_name):
79
+ return None
80
+ terms = []
81
+ if rxcui:
82
+ terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
83
+ if drug_name:
84
+ dn = drug_name.lower()
85
+ terms.append(f'(openfda.brand_name:"{dn}" OR openfda.generic_name:"{dn}")')
86
+ query = " OR ".join(terms)
87
+ logger.info(f"Looking up OpenFDA label with query: {query}")
88
+ try:
89
+ r = requests.get(OPENFDA_API_BASE, params={"search": query, "limit": 1}, timeout=15)
90
+ r.raise_for_status()
91
+ results = r.json().get("results", [])
92
+ if results:
93
+ return results[0]
94
+ except Exception:
95
+ logger.exception("Error fetching OpenFDA label")
96
+ return None
97
+
98
+ def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
99
+ """Return snippets from texts containing any of the search terms."""
100
+ snippets = []
101
+ lowers = [t.lower() for t in terms if t]
102
+ for text in texts or []:
103
+ tl = text.lower()
104
+ for term in lowers:
105
+ if term in tl:
106
+ i = tl.find(term)
107
+ start = max(0, i - 50)
108
+ end = min(len(text), i + len(term) + 100)
109
+ snippet = text[start:end]
110
+ snippet = re.sub(f"({re.escape(term)})", r"**\1**", snippet, flags=re.IGNORECASE)
111
+ snippets.append(f"...{snippet}...")
112
+ break
113
+ return snippets
114
+
115
+ def parse_bp(bp: str) -> Optional[tuple[int, int]]:
116
+ """Parse a 'SYS/DIA' blood pressure string into a tuple."""
117
+ if m := re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", (bp or "").strip()):
118
+ return int(m.group(1)), int(m.group(2))
119
+ return None
120
+
121
+ def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
122
+ """Identify immediate red flags in patient data."""
123
+ flags: List[str] = []
124
+ hpi = patient_data.get("hpi", {})
125
+ vitals = patient_data.get("vitals", {})
126
+ symptoms = [s.lower() for s in hpi.get("symptoms", []) if isinstance(s, str)]
127
+ # Symptom-based
128
+ mapping = {
129
+ "chest pain": "Chest pain reported",
130
+ "shortness of breath": "Shortness of breath reported",
131
+ "severe headache": "Severe headache reported",
132
+ "syncope": "Syncope (fainting) reported",
133
+ "hemoptysis": "Hemoptysis (coughing blood) reported"
134
+ }
135
+ for term, desc in mapping.items():
136
+ if term in symptoms:
137
+ flags.append(f"Red Flag: {desc}.")
138
+ # Vitals-based
139
+ temp = vitals.get("temp_c")
140
+ hr = vitals.get("hr_bpm")
141
+ rr = vitals.get("rr_rpm")
142
+ spo2 = vitals.get("spo2_percent")
143
+ bp = parse_bp(vitals.get("bp_mmhg", ""))
144
+ if temp is not None and temp >= 38.5:
145
+ flags.append(f"Red Flag: Fever ({temp}°C).")
146
+ if hr is not None:
147
+ if hr >= 120:
148
+ flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
149
+ if hr <= 50:
150
+ flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
151
+ if rr is not None and rr >= 24:
152
+ flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
153
+ if spo2 is not None and spo2 <= 92:
154
+ flags.append(f"Red Flag: Hypoxia ({spo2}%).")
155
+ if bp:
156
+ sys, dia = bp
157
+ if sys >= 180 or dia >= 110:
158
+ flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
159
+ if sys <= 90 or dia <= 60:
160
+ flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
161
+ return list(dict.fromkeys(flags)) # preserve order, dedupe
162
+
163
+ def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
164
+ """Convert patient_data dict into a markdown-like prompt string."""
165
+ if not data:
166
+ return "No patient data provided."
167
+ lines: List[str] = []
168
+ for section, value in data.items():
169
+ title = section.replace("_", " ").title()
170
+ if isinstance(value, dict) and any(value.values()):
171
+ lines.append(f"**{title}:**")
172
+ for k, v in value.items():
173
+ if v:
174
+ lines.append(f"- {k.replace('_',' ').title()}: {v}")
175
+ elif isinstance(value, list) and value:
176
+ lines.append(f"**{title}:** {', '.join(map(str, value))}")
177
+ elif value:
178
+ lines.append(f"**{title}:** {value}")
179
+ return "\n".join(lines)
180
+
181
+ # ── Tool Input Schemas ────────────────────────────────────────────────────────
182
+ class LabOrderInput(BaseModel):
183
+ test_name: str = Field(...)
184
+ reason: str = Field(...)
185
+ priority: str = Field("Routine")
186
+
187
+ class PrescriptionInput(BaseModel):
188
+ medication_name: str = Field(...)
189
+ dosage: str = Field(...)
190
+ route: str = Field(...)
191
+ frequency: str = Field(...)
192
+ duration: str = Field("As directed")
193
+ reason: str = Field(...)
194
+
195
+ class InteractionCheckInput(BaseModel):
196
+ potential_prescription: str = Field(...)
197
+ current_medications: Optional[List[str]] = Field(None)
198
+ allergies: Optional[List[str]] = Field(None)
199
+
200
+ class FlagRiskInput(BaseModel):
201
+ risk_description: str = Field(...)
202
+ urgency: str = Field("High")
203
+
204
+ # ── Tool Implementations ──────────────────────────────────────────────────────
205
+ @tool("order_lab_test", args_schema=LabOrderInput)
206
+ def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
207
+ logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
208
+ return json.dumps({
209
+ "status": "success",
210
+ "message": f"Lab Ordered: {test_name} ({priority})",
211
+ "details": f"Reason: {reason}"
212
+ })
213
+
214
+ @tool("prescribe_medication", args_schema=PrescriptionInput)
215
+ def prescribe_medication(
216
+ medication_name: str,
217
+ dosage: str,
218
+ route: str,
219
+ frequency: str,
220
+ duration: str,
221
+ reason: str
222
+ ) -> str:
223
+ logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
224
+ return json.dumps({
225
+ "status": "success",
226
+ "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
227
+ "details": f"Duration: {duration}. Reason: {reason}"
228
+ })
229
+
230
+ @tool("check_drug_interactions", args_schema=InteractionCheckInput)
231
+ def check_drug_interactions(
232
+ potential_prescription: str,
233
+ current_medications: Optional[List[str]] = None,
234
+ allergies: Optional[List[str]] = None
235
+ ) -> str:
236
+ logger.info(f"Checking interactions for: {potential_prescription}")
237
+ warnings: List[str] = []
238
+ pm = [m.lower().strip() for m in (current_medications or []) if m]
239
+ al = [a.lower().strip() for a in (allergies or []) if a]
240
+
241
+ # Allergy checks
242
+ if potential_prescription.lower().strip() in al:
243
+ warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")
244
+ # RxNorm/OpenFDA lookups
245
+ rxcui = get_rxcui(potential_prescription)
246
+ label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
247
+ if not (rxcui or label):
248
+ warnings.append(f"INFO: Could not identify '{potential_prescription}'. Checks may be incomplete.")
249
+
250
+ # Contraindications & warnings sections
251
+ for section in ("contraindications", "warnings_and_cautions", "warnings"):
252
+ items = label.get(section) if label else None
253
+ if isinstance(items, list):
254
+ snippets = search_text_list(items, al)
255
+ if snippets:
256
+ warnings.append(f"ALLERGY RISK ({section}): {'; '.join(snippets)}")
257
+
258
+ # Drug–drug interactions
259
+ for med in pm:
260
+ mrxcui = get_rxcui(med)
261
+ mlabel = get_openfda_label(rxcui=mrxcui, drug_name=med)
262
+ # check in both labels
263
+ for sec in ("drug_interactions",):
264
+ for src_label, src_name in ((label, potential_prescription), (mlabel, med)):
265
+ items = src_label.get(sec) if src_label else None
266
+ if isinstance(items, list):
267
+ snippets = search_text_list(items, [med if src_name==potential_prescription else potential_prescription])
268
+ if snippets:
269
+ warnings.append(f"Interaction ({src_name} label): {'; '.join(snippets)}")
270
+
271
+ status = "warning" if warnings else "clear"
272
+ message = (
273
+ f"{len(warnings)} issue(s) found for '{potential_prescription}'."
274
+ if warnings else
275
+ f"No major interactions or allergy issues identified for '{potential_prescription}'."
276
+ )
277
+ return json.dumps({"status": status, "message": message, "warnings": warnings})
278
+
279
+ @tool("flag_risk", args_schema=FlagRiskInput)
280
+ def flag_risk(risk_description: str, urgency: str = "High") -> str:
281
+ logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
282
+ return json.dumps({
283
+ "status": "flagged",
284
+ "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
285
+ })
286
+
287
+ search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
288
+ all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
289
+
290
+ # ── LLM & Tool Executor ──────────────────────────────────────────────────────
291
+ llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
292
+ model_with_tools = llm.bind_tools(all_tools)
293
+ tool_executor = ToolExecutor(all_tools)
294
+
295
+ # ── State Definition ──────────────────────────────────────────────────────────
296
+ class AgentState(TypedDict):
297
+ messages: List[Any]
298
+ patient_data: Optional[Dict[str, Any]]
299
+ summary: Optional[str]
300
+ interaction_warnings: Optional[List[str]]
301
+
302
+ # ── Graph Nodes ───────────────────────────────────────────────────────────────
303
+ def agent_node(state: AgentState) -> Dict[str, Any]:
304
+ msgs = state["messages"]
305
+ if not msgs or not isinstance(msgs[0], SystemMessage):
306
+ msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
307
+ logger.info(f"Invoking LLM with {len(msgs)} messages")
308
+ try:
309
+ response = model_with_tools.invoke(msgs)
310
+ return {"messages": [response]}
311
+ except Exception as e:
312
+ logger.exception("Error in agent_node")
313
+ return {"messages": [AIMessage(content=f"Error: {e}")]}
314
+
315
+ def tool_node(state: AgentState) -> Dict[str, Any]:
316
+ last = state["messages"][-1]
317
+ if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
318
+ logger.warning("tool_node invoked without pending tool_calls")
319
+ return {"messages": [], "interaction_warnings": None}
320
+
321
+ calls = last.tool_calls
322
+ # Safety: require interaction check before prescribing
323
+ blocked_ids = set()
324
+ for call in calls:
325
+ if call["name"] == "prescribe_medication":
326
+ med = call["args"].get("medication_name", "").lower()
327
+ if not any(c["name"] == "check_drug_interactions" and c["args"].get("potential_prescription","").lower() == med for c in calls):
328
+ logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
329
+ blocked_ids.add(call["id"])
330
+
331
+ to_execute = [c for c in calls if c["id"] not in blocked_ids]
332
+ # Augment interaction checks with patient data
333
+ pd = state.get("patient_data", {})
334
+ for call in to_execute:
335
+ if call["name"] == "check_drug_interactions":
336
+ call["args"].setdefault("current_medications", pd.get("medications", {}).get("current", []))
337
+ call["args"].setdefault("allergies", pd.get("allergies", []))
338
+
339
+ messages: List[ToolMessage] = []
340
+ warnings: List[str] = []
341
+ try:
342
+ responses = tool_executor.batch(to_execute, return_exceptions=True)
343
+ for call, resp in zip(to_execute, responses):
344
+ if isinstance(resp, Exception):
345
+ logger.exception(f"Error executing tool {call['name']}")
346
+ content = json.dumps({"status": "error", "message": str(resp)})
347
+ else:
348
+ content = str(resp)
349
+ if call["name"] == "check_drug_interactions":
350
+ data = json.loads(content)
351
+ if data.get("status") == "warning":
352
+ warnings.extend(data.get("warnings", []))
353
+ messages.append(ToolMessage(content=content, tool_call_id=call["id"], name=call["name"]))
354
+ except Exception as e:
355
+ logger.exception("Critical error in tool_node")
356
+ # return an error message for each pending call
357
+ for call in to_execute:
358
+ messages.append(ToolMessage(
359
+ content=json.dumps({"status": "error", "message": str(e)}),
360
+ tool_call_id=call["id"],
361
+ name=call["name"]
362
+ ))
363
+ return {"messages": messages, "interaction_warnings": warnings or None}
364
+
365
+ def reflection_node(state: AgentState) -> Dict[str, Any]:
366
+ warns = state.get("interaction_warnings")
367
+ if not warns:
368
+ logger.warning("reflection_node called without warnings")
369
+ return {"messages": [], "interaction_warnings": None}
370
+
371
+ # Find the AIMessage that triggered the warnings
372
+ triggering: Optional[AIMessage] = None
373
+ for msg in reversed(state["messages"]):
374
+ if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
375
+ triggering = msg
376
+ break
377
+
378
+ if not triggering:
379
+ return {"messages": [AIMessage(content="Internal Error: reflection context missing.")], "interaction_warnings": None}
380
+
381
+ prompt = (
382
+ "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
383
+ f"{triggering.content}\n\n"
384
+ "Highlight any issues based on these warnings:\n" + "\n".join(f"- {w}" for w in warns)
385
+ )
386
+ try:
387
+ resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
388
+ return {"messages": [AIMessage(content=resp.content)], "interaction_warnings": None}
389
+ except Exception as e:
390
+ logger.exception("Error during reflection")
391
+ return {"messages": [AIMessage(content=f"Error during reflection: {e}")], "interaction_warnings": None}
392
+
393
+ # ── Routing Functions ────────────────────────────────────────────────────────
394
+ def should_continue(state: AgentState) -> str:
395
+ last = state["messages"][-1] if state["messages"] else None
396
+ if not isinstance(last, AIMessage):
397
+ return "end_conversation_turn"
398
+ if getattr(last, "tool_calls", None):
399
+ return "continue_tools"
400
+ return "end_conversation_turn"
401
+
402
+ def after_tools_router(state: AgentState) -> str:
403
+ return "reflection" if state.get("interaction_warnings") else "agent"
404
+
405
+ # ── ClinicalAgent ────────────────────────────────────────────────────────────
406
+ class ClinicalAgent:
407
+ def __init__(self):
408
+ logger.info("Building ClinicalAgent workflow")
409
+ wf = StateGraph(AgentState)
410
+ wf.add_node("agent", agent_node)
411
+ wf.add_node("tools", tool_node)
412
+ wf.add_node("reflection", reflection_node)
413
+ wf.set_entry_point("agent")
414
+ wf.add_conditional_edges("agent", should_continue, {
415
+ "continue_tools": "tools",
416
+ "end_conversation_turn": END
417
+ })
418
+ wf.add_conditional_edges("tools", after_tools_router, {
419
+ "reflection": "reflection",
420
+ "agent": "agent"
421
+ })
422
+ wf.add_edge("reflection", "agent")
423
+ self.graph_app = wf.compile()
424
+ logger.info("ClinicalAgent ready")
425
+
426
+ def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
427
  try:
428
+ result = self.graph_app.invoke(state, {"recursion_limit": 15})
429
+ result.setdefault("summary", state.get("summary"))
430
+ result.setdefault("interaction_warnings", None)
431
+ return result
432
  except Exception as e:
433
+ logger.exception("Error during graph invocation")
434
+ return {
435
+ "messages": state.get("messages", []) + [AIMessage(content=f"Error: {e}")],
436
+ "patient_data": state.get("patient_data"),
437
+ "summary": state.get("summary"),
438
+ "interaction_warnings": None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  }