Update app.py
Browse files
app.py
CHANGED
@@ -12,7 +12,6 @@ from dotenv import load_dotenv
|
|
12 |
from langchain_groq import ChatGroq
|
13 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
14 |
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
|
15 |
-
# from langchain_core.prompts import ChatPromptTemplate # Not explicitly used
|
16 |
from langchain_core.pydantic_v1 import BaseModel, Field
|
17 |
from langchain_core.tools import tool
|
18 |
from langgraph.prebuilt import ToolExecutor
|
@@ -38,16 +37,16 @@ class ClinicalPrompts: SYSTEM_PROMPT = """
|
|
38 |
"""
|
39 |
|
40 |
# --- API Helper Functions (get_rxcui, get_openfda_label, search_text_list) ---
|
41 |
-
# ... (Keep these functions exactly as they were
|
42 |
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"; RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"; OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
|
43 |
@lru_cache(maxsize=256)
|
44 |
def get_rxcui(drug_name: str) -> Optional[str]:
|
45 |
if not drug_name or not isinstance(drug_name, str): return None; drug_name = drug_name.strip();
|
46 |
if not drug_name: return None; print(f"RxNorm Lookup for: '{drug_name}'");
|
47 |
-
try:
|
48 |
params = {"name": drug_name, "search": 1}; response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
|
49 |
if data and "idGroup" in data and "rxnormId" in data["idGroup"]: rxcui = data["idGroup"]["rxnormId"][0]; print(f" Found RxCUI: {rxcui} for '{drug_name}'"); return rxcui
|
50 |
-
else:
|
51 |
params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
|
52 |
if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
|
53 |
for group in data["drugGroup"]["conceptGroup"]:
|
@@ -79,30 +78,71 @@ def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) ->
|
|
79 |
for term in search_terms_lower:
|
80 |
if term in text_item_lower:
|
81 |
start_index = text_item_lower.find(term); snippet_start = max(0, start_index - 50); snippet_end = min(len(text_item), start_index + len(term) + 100); snippet = text_item[snippet_start:snippet_end];
|
82 |
-
|
|
|
|
|
|
|
83 |
return found_snippets
|
84 |
|
85 |
|
86 |
-
# --- Other Helper Functions
|
87 |
-
# ... (Keep these functions exactly as they were) ...
|
88 |
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
|
89 |
if not isinstance(bp_string, str): return None; match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip());
|
90 |
if match: return int(match.group(1)), int(match.group(2)); return None
|
|
|
|
|
91 |
def check_red_flags(patient_data: dict) -> List[str]:
|
92 |
-
|
93 |
-
|
94 |
-
if
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
if
|
99 |
-
|
100 |
-
|
101 |
-
if
|
102 |
-
|
103 |
-
if "
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
def format_patient_data_for_prompt(data: dict) -> str:
|
|
|
106 |
if not data: return "No patient data provided."; prompt_str = "";
|
107 |
for key, value in data.items(): section_title = key.replace('_', ' ').title();
|
108 |
if isinstance(value, dict) and value: has_content = any(sub_value for sub_value in value.values());
|
@@ -174,7 +214,7 @@ model = ChatGroq(temperature=ClinicalAppSettings.TEMPERATURE, model=ClinicalAppS
|
|
174 |
model_with_tools = model.bind_tools(tools)
|
175 |
|
176 |
# --- Graph Nodes (agent_node, tool_node) ---
|
177 |
-
# ... (Keep
|
178 |
def agent_node(state: AgentState):
|
179 |
print("\n---AGENT NODE---"); current_messages = state['messages'];
|
180 |
if not current_messages or not isinstance(current_messages[0], SystemMessage): print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages;
|
@@ -233,18 +273,18 @@ def main():
|
|
233 |
# --- Patient Data Input Sidebar ---
|
234 |
with st.sidebar:
|
235 |
st.header("π Patient Intake Form")
|
236 |
-
# Input fields...
|
237 |
-
st.subheader("Demographics"); age = st.number_input("Age", 0, 120, 55); sex = st.selectbox("Sex", ["Male", "Female", "Other"])
|
238 |
-
st.subheader("HPI"); chief_complaint = st.text_input("Chief Complaint", "Chest pain"); hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100); symptoms = st.multiselect("Symptoms", ["Nausea", "Diaphoresis", "SOB", "Dizziness"], default=["Nausea", "Diaphoresis"])
|
239 |
-
st.subheader("History"); pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI"); psh = st.text_area("PSH", "Appendectomy")
|
240 |
-
st.subheader("Meds & Allergies"); current_meds_str = st.text_area("Current Meds", "Lisinopril 10mg daily\nMetformin 1000mg BID"); allergies_str = st.text_area("Allergies", "Penicillin (rash)")
|
241 |
-
st.subheader("Social/Family"); social_history = st.text_area("SH", "Smoker"); family_history = st.text_area("FHx", "Father MI")
|
242 |
st.subheader("Vitals & Exam"); col1, col2 = st.columns(2);
|
243 |
-
with col1: temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f"); hr_bpm = st.number_input("HR", 30, 250, 95); rr_rpm = st.number_input("RR", 5, 50, 18)
|
244 |
-
with col2: bp_mmhg = st.text_input("BP", "155/90"); spo2_percent = st.number_input("SpO2", 70, 100, 96); pain_scale = st.slider("Pain", 0, 10, 8)
|
245 |
-
exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=50)
|
246 |
|
247 |
-
if st.button("Start/Update Consultation"):
|
248 |
current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
|
249 |
current_med_names_only = [];
|
250 |
for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", med);
|
@@ -261,12 +301,12 @@ def main():
|
|
261 |
|
262 |
# --- Main Chat Interface Area ---
|
263 |
st.header("π¬ Clinical Consultation")
|
264 |
-
# Display loop -
|
265 |
for msg in st.session_state.messages:
|
266 |
if isinstance(msg, HumanMessage):
|
267 |
-
with st.chat_message("user"): st.markdown(msg.content)
|
268 |
elif isinstance(msg, AIMessage):
|
269 |
-
with st.chat_message("assistant"):
|
270 |
ai_content = msg.content; structured_output = None
|
271 |
try: # JSON Parsing logic...
|
272 |
json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
|
@@ -292,18 +332,15 @@ def main():
|
|
292 |
# CORRECTED Tool Call Display Block
|
293 |
if getattr(msg, 'tool_calls', None):
|
294 |
with st.expander("π οΈ AI requested actions", expanded=False):
|
295 |
-
if msg.tool_calls:
|
296 |
for tc in msg.tool_calls:
|
297 |
try:
|
298 |
-
# Properly indented try block content
|
299 |
st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
|
300 |
except Exception as display_e:
|
301 |
-
|
302 |
-
st.
|
303 |
-
# Provide a fallback display
|
304 |
-
st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nRaw Args: {tc.get('args')}") # Show raw args if JSON fails
|
305 |
else:
|
306 |
-
st.caption("_No actions requested
|
307 |
elif isinstance(msg, ToolMessage):
|
308 |
tool_name_display = getattr(msg, 'name', 'tool_execution')
|
309 |
with st.chat_message(tool_name_display, avatar="π οΈ"): # No key
|
@@ -312,9 +349,9 @@ def main():
|
|
312 |
if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="β
" if status != "flagged" else "π¨")
|
313 |
elif status == "warning": st.warning(f"{message}", icon="β οΈ");
|
314 |
if warnings and isinstance(warnings, list): st.caption("Details:"); [st.caption(f"- {warn}") for warn in warnings]
|
315 |
-
else: st.error(f"{message}", icon="β")
|
316 |
if details: st.caption(f"Details: {details}")
|
317 |
-
except json.JSONDecodeError: st.info(f"{msg.content}")
|
318 |
except Exception as e: st.error(f"Error displaying tool message: {e}", icon="β"); st.caption(f"Raw content: {msg.content}")
|
319 |
|
320 |
# --- Chat Input Logic ---
|
|
|
12 |
from langchain_groq import ChatGroq
|
13 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
14 |
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
|
|
|
15 |
from langchain_core.pydantic_v1 import BaseModel, Field
|
16 |
from langchain_core.tools import tool
|
17 |
from langgraph.prebuilt import ToolExecutor
|
|
|
37 |
"""
|
38 |
|
39 |
# --- API Helper Functions (get_rxcui, get_openfda_label, search_text_list) ---
|
40 |
+
# ... (Keep these functions exactly as they were) ...
|
41 |
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"; RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"; OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
|
42 |
@lru_cache(maxsize=256)
|
43 |
def get_rxcui(drug_name: str) -> Optional[str]:
|
44 |
if not drug_name or not isinstance(drug_name, str): return None; drug_name = drug_name.strip();
|
45 |
if not drug_name: return None; print(f"RxNorm Lookup for: '{drug_name}'");
|
46 |
+
try: # Try direct lookup first
|
47 |
params = {"name": drug_name, "search": 1}; response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
|
48 |
if data and "idGroup" in data and "rxnormId" in data["idGroup"]: rxcui = data["idGroup"]["rxnormId"][0]; print(f" Found RxCUI: {rxcui} for '{drug_name}'"); return rxcui
|
49 |
+
else: # Fallback to /drugs search
|
50 |
params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
|
51 |
if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
|
52 |
for group in data["drugGroup"]["conceptGroup"]:
|
|
|
78 |
for term in search_terms_lower:
|
79 |
if term in text_item_lower:
|
80 |
start_index = text_item_lower.find(term); snippet_start = max(0, start_index - 50); snippet_end = min(len(text_item), start_index + len(term) + 100); snippet = text_item[snippet_start:snippet_end];
|
81 |
+
# Highlight first match for clarity
|
82 |
+
snippet = re.sub(f"({re.escape(term)})", r"**\1**", snippet, count=1, flags=re.IGNORECASE)
|
83 |
+
found_snippets.append(f"...{snippet}...")
|
84 |
+
break # Only report first match per text item
|
85 |
return found_snippets
|
86 |
|
87 |
|
88 |
+
# --- Other Helper Functions ---
|
|
|
89 |
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
|
90 |
if not isinstance(bp_string, str): return None; match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip());
|
91 |
if match: return int(match.group(1)), int(match.group(2)); return None
|
92 |
+
|
93 |
+
# CORRECTED check_red_flags function
|
94 |
def check_red_flags(patient_data: dict) -> List[str]:
|
95 |
+
"""Checks patient data against predefined red flags."""
|
96 |
+
flags = []
|
97 |
+
if not patient_data: return flags
|
98 |
+
symptoms = patient_data.get("hpi", {}).get("symptoms", [])
|
99 |
+
vitals = patient_data.get("vitals", {})
|
100 |
+
history = patient_data.get("pmh", {}).get("conditions", "")
|
101 |
+
symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)]
|
102 |
+
|
103 |
+
# Symptom Flags (CORRECTED - Separate lines)
|
104 |
+
if "chest pain" in symptoms_lower:
|
105 |
+
flags.append("Red Flag: Chest Pain reported.")
|
106 |
+
if "shortness of breath" in symptoms_lower:
|
107 |
+
flags.append("Red Flag: Shortness of Breath reported.")
|
108 |
+
if "severe headache" in symptoms_lower:
|
109 |
+
flags.append("Red Flag: Severe Headache reported.")
|
110 |
+
if "sudden vision loss" in symptoms_lower:
|
111 |
+
flags.append("Red Flag: Sudden Vision Loss reported.")
|
112 |
+
if "weakness on one side" in symptoms_lower:
|
113 |
+
flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
|
114 |
+
if "hemoptysis" in symptoms_lower:
|
115 |
+
flags.append("Red Flag: Hemoptysis (coughing up blood).")
|
116 |
+
if "syncope" in symptoms_lower:
|
117 |
+
flags.append("Red Flag: Syncope (fainting).")
|
118 |
+
|
119 |
+
# Vital Sign Flags
|
120 |
+
if vitals:
|
121 |
+
temp = vitals.get("temp_c"); hr = vitals.get("hr_bpm"); rr = vitals.get("rr_rpm")
|
122 |
+
spo2 = vitals.get("spo2_percent"); bp_str = vitals.get("bp_mmhg")
|
123 |
+
if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever ({temp}Β°C).")
|
124 |
+
if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
|
125 |
+
if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
|
126 |
+
if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
|
127 |
+
if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia ({spo2}%).")
|
128 |
+
if bp_str:
|
129 |
+
bp = parse_bp(bp_str)
|
130 |
+
if bp:
|
131 |
+
if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
|
132 |
+
if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")
|
133 |
+
|
134 |
+
# History Flags
|
135 |
+
if history and isinstance(history, str):
|
136 |
+
history_lower = history.lower()
|
137 |
+
if "history of mi" in history_lower and "chest pain" in symptoms_lower:
|
138 |
+
flags.append("Red Flag: History of MI with current Chest Pain.")
|
139 |
+
if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower:
|
140 |
+
flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
|
141 |
+
|
142 |
+
return list(set(flags)) # Unique flags
|
143 |
+
|
144 |
def format_patient_data_for_prompt(data: dict) -> str:
|
145 |
+
# ... (Keep this function exactly as it was) ...
|
146 |
if not data: return "No patient data provided."; prompt_str = "";
|
147 |
for key, value in data.items(): section_title = key.replace('_', ' ').title();
|
148 |
if isinstance(value, dict) and value: has_content = any(sub_value for sub_value in value.values());
|
|
|
214 |
model_with_tools = model.bind_tools(tools)
|
215 |
|
216 |
# --- Graph Nodes (agent_node, tool_node) ---
|
217 |
+
# ... (Keep these functions exactly as they were) ...
|
218 |
def agent_node(state: AgentState):
|
219 |
print("\n---AGENT NODE---"); current_messages = state['messages'];
|
220 |
if not current_messages or not isinstance(current_messages[0], SystemMessage): print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages;
|
|
|
273 |
# --- Patient Data Input Sidebar ---
|
274 |
with st.sidebar:
|
275 |
st.header("π Patient Intake Form")
|
276 |
+
# Input fields...
|
277 |
+
st.subheader("Demographics"); age = st.number_input("Age", 0, 120, 55, key="sb_age"); sex = st.selectbox("Sex", ["Male", "Female", "Other"], key="sb_sex")
|
278 |
+
st.subheader("HPI"); chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="sb_cc"); hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100, key="sb_hpi"); symptoms = st.multiselect("Symptoms", ["Nausea", "Diaphoresis", "SOB", "Dizziness", "Severe Headache", "Syncope", "Hemoptysis"], default=["Nausea", "Diaphoresis"], key="sb_sym")
|
279 |
+
st.subheader("History"); pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI", key="sb_pmh"); psh = st.text_area("PSH", "Appendectomy", key="sb_psh")
|
280 |
+
st.subheader("Meds & Allergies"); current_meds_str = st.text_area("Current Meds", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily", key="sb_meds"); allergies_str = st.text_area("Allergies", "Penicillin (rash), Sulfa", key="sb_allergies")
|
281 |
+
st.subheader("Social/Family"); social_history = st.text_area("SH", "Smoker", key="sb_sh"); family_history = st.text_area("FHx", "Father MI", key="sb_fhx")
|
282 |
st.subheader("Vitals & Exam"); col1, col2 = st.columns(2);
|
283 |
+
with col1: temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f", key="sb_temp"); hr_bpm = st.number_input("HR", 30, 250, 95, key="sb_hr"); rr_rpm = st.number_input("RR", 5, 50, 18, key="sb_rr")
|
284 |
+
with col2: bp_mmhg = st.text_input("BP", "155/90", key="sb_bp"); spo2_percent = st.number_input("SpO2", 70, 100, 96, key="sb_spo2"); pain_scale = st.slider("Pain", 0, 10, 8, key="sb_pain")
|
285 |
+
exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=50, key="sb_exam")
|
286 |
|
287 |
+
if st.button("Start/Update Consultation", key="sb_start"):
|
288 |
current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
|
289 |
current_med_names_only = [];
|
290 |
for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", med);
|
|
|
301 |
|
302 |
# --- Main Chat Interface Area ---
|
303 |
st.header("π¬ Clinical Consultation")
|
304 |
+
# Display loop - key= argument REMOVED, Tool Call Display Syntax FIXED
|
305 |
for msg in st.session_state.messages:
|
306 |
if isinstance(msg, HumanMessage):
|
307 |
+
with st.chat_message("user"): st.markdown(msg.content)
|
308 |
elif isinstance(msg, AIMessage):
|
309 |
+
with st.chat_message("assistant"):
|
310 |
ai_content = msg.content; structured_output = None
|
311 |
try: # JSON Parsing logic...
|
312 |
json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
|
|
|
332 |
# CORRECTED Tool Call Display Block
|
333 |
if getattr(msg, 'tool_calls', None):
|
334 |
with st.expander("π οΈ AI requested actions", expanded=False):
|
335 |
+
if msg.tool_calls:
|
336 |
for tc in msg.tool_calls:
|
337 |
try:
|
|
|
338 |
st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
|
339 |
except Exception as display_e:
|
340 |
+
st.error(f"Could not display tool call args: {display_e}", icon="β οΈ")
|
341 |
+
st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nRaw Args: {tc.get('args')}")
|
|
|
|
|
342 |
else:
|
343 |
+
st.caption("_No actions requested._")
|
344 |
elif isinstance(msg, ToolMessage):
|
345 |
tool_name_display = getattr(msg, 'name', 'tool_execution')
|
346 |
with st.chat_message(tool_name_display, avatar="π οΈ"): # No key
|
|
|
349 |
if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="β
" if status != "flagged" else "π¨")
|
350 |
elif status == "warning": st.warning(f"{message}", icon="β οΈ");
|
351 |
if warnings and isinstance(warnings, list): st.caption("Details:"); [st.caption(f"- {warn}") for warn in warnings]
|
352 |
+
else: st.error(f"{message}", icon="β") # Assume error if not success/clear/flagged/warning
|
353 |
if details: st.caption(f"Details: {details}")
|
354 |
+
except json.JSONDecodeError: st.info(f"{msg.content}") # Display raw if not JSON
|
355 |
except Exception as e: st.error(f"Error displaying tool message: {e}", icon="β"); st.caption(f"Raw content: {msg.content}")
|
356 |
|
357 |
# --- Chat Input Logic ---
|