mgbam commited on
Commit
6b2d9f7
Β·
verified Β·
1 Parent(s): 71db5de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -43
app.py CHANGED
@@ -12,7 +12,6 @@ from dotenv import load_dotenv
12
  from langchain_groq import ChatGroq
13
  from langchain_community.tools.tavily_search import TavilySearchResults
14
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
15
- # from langchain_core.prompts import ChatPromptTemplate # Not explicitly used
16
  from langchain_core.pydantic_v1 import BaseModel, Field
17
  from langchain_core.tools import tool
18
  from langgraph.prebuilt import ToolExecutor
@@ -38,16 +37,16 @@ class ClinicalPrompts: SYSTEM_PROMPT = """
38
  """
39
 
40
  # --- API Helper Functions (get_rxcui, get_openfda_label, search_text_list) ---
41
- # ... (Keep these functions exactly as they were in the previous 'full code' response) ...
42
  UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"; RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"; OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
43
  @lru_cache(maxsize=256)
44
  def get_rxcui(drug_name: str) -> Optional[str]:
45
  if not drug_name or not isinstance(drug_name, str): return None; drug_name = drug_name.strip();
46
  if not drug_name: return None; print(f"RxNorm Lookup for: '{drug_name}'");
47
- try:
48
  params = {"name": drug_name, "search": 1}; response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
49
  if data and "idGroup" in data and "rxnormId" in data["idGroup"]: rxcui = data["idGroup"]["rxnormId"][0]; print(f" Found RxCUI: {rxcui} for '{drug_name}'"); return rxcui
50
- else:
51
  params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
52
  if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
53
  for group in data["drugGroup"]["conceptGroup"]:
@@ -79,30 +78,71 @@ def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) ->
79
  for term in search_terms_lower:
80
  if term in text_item_lower:
81
  start_index = text_item_lower.find(term); snippet_start = max(0, start_index - 50); snippet_end = min(len(text_item), start_index + len(term) + 100); snippet = text_item[snippet_start:snippet_end];
82
- snippet = snippet.replace(term, f"**{term}**", 1); found_snippets.append(f"...{snippet}..."); break
 
 
 
83
  return found_snippets
84
 
85
 
86
- # --- Other Helper Functions (parse_bp, check_red_flags, format_patient_data_for_prompt) ---
87
- # ... (Keep these functions exactly as they were) ...
88
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
89
  if not isinstance(bp_string, str): return None; match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip());
90
  if match: return int(match.group(1)), int(match.group(2)); return None
 
 
91
  def check_red_flags(patient_data: dict) -> List[str]:
92
- flags = [];
93
- if not patient_data: return flags; symptoms = patient_data.get("hpi", {}).get("symptoms", []); vitals = patient_data.get("vitals", {}); history = patient_data.get("pmh", {}).get("conditions", ""); symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)];
94
- if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported."); if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported."); if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported."); if "sudden vision loss" in symptoms_lower: flags.append("Red Flag: Sudden Vision Loss reported."); if "weakness on one side" in symptoms_lower: flags.append("Red Flag: Unilateral Weakness reported (potential stroke)."); if "hemoptysis" in symptoms_lower: flags.append("Red Flag: Hemoptysis (coughing up blood)."); if "syncope" in symptoms_lower: flags.append("Red Flag: Syncope (fainting).");
95
- if vitals: temp = vitals.get("temp_c"); hr = vitals.get("hr_bpm"); rr = vitals.get("rr_rpm"); spo2 = vitals.get("spo2_percent"); bp_str = vitals.get("bp_mmhg");
96
- if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever ({temp}Β°C)."); if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia ({hr} bpm)."); if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia ({hr} bpm)."); if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea ({rr} rpm)."); if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia ({spo2}%).");
97
- if bp_str: bp = parse_bp(bp_str);
98
- if bp:
99
- if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).");
100
- if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).");
101
- if history and isinstance(history, str): history_lower = history.lower();
102
- if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.");
103
- if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.");
104
- return list(set(flags))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def format_patient_data_for_prompt(data: dict) -> str:
 
106
  if not data: return "No patient data provided."; prompt_str = "";
107
  for key, value in data.items(): section_title = key.replace('_', ' ').title();
108
  if isinstance(value, dict) and value: has_content = any(sub_value for sub_value in value.values());
@@ -174,7 +214,7 @@ model = ChatGroq(temperature=ClinicalAppSettings.TEMPERATURE, model=ClinicalAppS
174
  model_with_tools = model.bind_tools(tools)
175
 
176
  # --- Graph Nodes (agent_node, tool_node) ---
177
- # ... (Keep agent_node and tool_node functions exactly as they were in the last 'full code' response) ...
178
  def agent_node(state: AgentState):
179
  print("\n---AGENT NODE---"); current_messages = state['messages'];
180
  if not current_messages or not isinstance(current_messages[0], SystemMessage): print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages;
@@ -233,18 +273,18 @@ def main():
233
  # --- Patient Data Input Sidebar ---
234
  with st.sidebar:
235
  st.header("πŸ“„ Patient Intake Form")
236
- # Input fields... (Using shorter versions for brevity, assume full fields are here)
237
- st.subheader("Demographics"); age = st.number_input("Age", 0, 120, 55); sex = st.selectbox("Sex", ["Male", "Female", "Other"])
238
- st.subheader("HPI"); chief_complaint = st.text_input("Chief Complaint", "Chest pain"); hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100); symptoms = st.multiselect("Symptoms", ["Nausea", "Diaphoresis", "SOB", "Dizziness"], default=["Nausea", "Diaphoresis"])
239
- st.subheader("History"); pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI"); psh = st.text_area("PSH", "Appendectomy")
240
- st.subheader("Meds & Allergies"); current_meds_str = st.text_area("Current Meds", "Lisinopril 10mg daily\nMetformin 1000mg BID"); allergies_str = st.text_area("Allergies", "Penicillin (rash)")
241
- st.subheader("Social/Family"); social_history = st.text_area("SH", "Smoker"); family_history = st.text_area("FHx", "Father MI")
242
  st.subheader("Vitals & Exam"); col1, col2 = st.columns(2);
243
- with col1: temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f"); hr_bpm = st.number_input("HR", 30, 250, 95); rr_rpm = st.number_input("RR", 5, 50, 18)
244
- with col2: bp_mmhg = st.text_input("BP", "155/90"); spo2_percent = st.number_input("SpO2", 70, 100, 96); pain_scale = st.slider("Pain", 0, 10, 8)
245
- exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=50)
246
 
247
- if st.button("Start/Update Consultation"):
248
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
249
  current_med_names_only = [];
250
  for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", med);
@@ -261,12 +301,12 @@ def main():
261
 
262
  # --- Main Chat Interface Area ---
263
  st.header("πŸ’¬ Clinical Consultation")
264
- # Display loop - SyntaxError Fixed
265
  for msg in st.session_state.messages:
266
  if isinstance(msg, HumanMessage):
267
- with st.chat_message("user"): st.markdown(msg.content) # No key
268
  elif isinstance(msg, AIMessage):
269
- with st.chat_message("assistant"): # No key
270
  ai_content = msg.content; structured_output = None
271
  try: # JSON Parsing logic...
272
  json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
@@ -292,18 +332,15 @@ def main():
292
  # CORRECTED Tool Call Display Block
293
  if getattr(msg, 'tool_calls', None):
294
  with st.expander("πŸ› οΈ AI requested actions", expanded=False):
295
- if msg.tool_calls: # Check if list is not empty
296
  for tc in msg.tool_calls:
297
  try:
298
- # Properly indented try block content
299
  st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
300
  except Exception as display_e:
301
- # Properly indented except block content
302
- st.error(f"Could not display tool call arguments properly: {display_e}", icon="⚠️")
303
- # Provide a fallback display
304
- st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nRaw Args: {tc.get('args')}") # Show raw args if JSON fails
305
  else:
306
- st.caption("_No actions requested in this turn._")
307
  elif isinstance(msg, ToolMessage):
308
  tool_name_display = getattr(msg, 'name', 'tool_execution')
309
  with st.chat_message(tool_name_display, avatar="πŸ› οΈ"): # No key
@@ -312,9 +349,9 @@ def main():
312
  if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="βœ…" if status != "flagged" else "🚨")
313
  elif status == "warning": st.warning(f"{message}", icon="⚠️");
314
  if warnings and isinstance(warnings, list): st.caption("Details:"); [st.caption(f"- {warn}") for warn in warnings]
315
- else: st.error(f"{message}", icon="❌")
316
  if details: st.caption(f"Details: {details}")
317
- except json.JSONDecodeError: st.info(f"{msg.content}")
318
  except Exception as e: st.error(f"Error displaying tool message: {e}", icon="❌"); st.caption(f"Raw content: {msg.content}")
319
 
320
  # --- Chat Input Logic ---
 
12
  from langchain_groq import ChatGroq
13
  from langchain_community.tools.tavily_search import TavilySearchResults
14
  from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
 
15
  from langchain_core.pydantic_v1 import BaseModel, Field
16
  from langchain_core.tools import tool
17
  from langgraph.prebuilt import ToolExecutor
 
37
  """
38
 
39
  # --- API Helper Functions (get_rxcui, get_openfda_label, search_text_list) ---
40
+ # ... (Keep these functions exactly as they were) ...
41
  UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"; RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"; OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
42
  @lru_cache(maxsize=256)
43
  def get_rxcui(drug_name: str) -> Optional[str]:
44
  if not drug_name or not isinstance(drug_name, str): return None; drug_name = drug_name.strip();
45
  if not drug_name: return None; print(f"RxNorm Lookup for: '{drug_name}'");
46
+ try: # Try direct lookup first
47
  params = {"name": drug_name, "search": 1}; response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
48
  if data and "idGroup" in data and "rxnormId" in data["idGroup"]: rxcui = data["idGroup"]["rxnormId"][0]; print(f" Found RxCUI: {rxcui} for '{drug_name}'"); return rxcui
49
+ else: # Fallback to /drugs search
50
  params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10); response.raise_for_status(); data = response.json();
51
  if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
52
  for group in data["drugGroup"]["conceptGroup"]:
 
78
  for term in search_terms_lower:
79
  if term in text_item_lower:
80
  start_index = text_item_lower.find(term); snippet_start = max(0, start_index - 50); snippet_end = min(len(text_item), start_index + len(term) + 100); snippet = text_item[snippet_start:snippet_end];
81
+ # Highlight first match for clarity
82
+ snippet = re.sub(f"({re.escape(term)})", r"**\1**", snippet, count=1, flags=re.IGNORECASE)
83
+ found_snippets.append(f"...{snippet}...")
84
+ break # Only report first match per text item
85
  return found_snippets
86
 
87
 
88
+ # --- Other Helper Functions ---
 
89
  def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
90
  if not isinstance(bp_string, str): return None; match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip());
91
  if match: return int(match.group(1)), int(match.group(2)); return None
92
+
93
+ # CORRECTED check_red_flags function
94
  def check_red_flags(patient_data: dict) -> List[str]:
95
+ """Checks patient data against predefined red flags."""
96
+ flags = []
97
+ if not patient_data: return flags
98
+ symptoms = patient_data.get("hpi", {}).get("symptoms", [])
99
+ vitals = patient_data.get("vitals", {})
100
+ history = patient_data.get("pmh", {}).get("conditions", "")
101
+ symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)]
102
+
103
+ # Symptom Flags (CORRECTED - Separate lines)
104
+ if "chest pain" in symptoms_lower:
105
+ flags.append("Red Flag: Chest Pain reported.")
106
+ if "shortness of breath" in symptoms_lower:
107
+ flags.append("Red Flag: Shortness of Breath reported.")
108
+ if "severe headache" in symptoms_lower:
109
+ flags.append("Red Flag: Severe Headache reported.")
110
+ if "sudden vision loss" in symptoms_lower:
111
+ flags.append("Red Flag: Sudden Vision Loss reported.")
112
+ if "weakness on one side" in symptoms_lower:
113
+ flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
114
+ if "hemoptysis" in symptoms_lower:
115
+ flags.append("Red Flag: Hemoptysis (coughing up blood).")
116
+ if "syncope" in symptoms_lower:
117
+ flags.append("Red Flag: Syncope (fainting).")
118
+
119
+ # Vital Sign Flags
120
+ if vitals:
121
+ temp = vitals.get("temp_c"); hr = vitals.get("hr_bpm"); rr = vitals.get("rr_rpm")
122
+ spo2 = vitals.get("spo2_percent"); bp_str = vitals.get("bp_mmhg")
123
+ if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever ({temp}Β°C).")
124
+ if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
125
+ if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
126
+ if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
127
+ if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia ({spo2}%).")
128
+ if bp_str:
129
+ bp = parse_bp(bp_str)
130
+ if bp:
131
+ if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
132
+ if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")
133
+
134
+ # History Flags
135
+ if history and isinstance(history, str):
136
+ history_lower = history.lower()
137
+ if "history of mi" in history_lower and "chest pain" in symptoms_lower:
138
+ flags.append("Red Flag: History of MI with current Chest Pain.")
139
+ if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower:
140
+ flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
141
+
142
+ return list(set(flags)) # Unique flags
143
+
144
  def format_patient_data_for_prompt(data: dict) -> str:
145
+ # ... (Keep this function exactly as it was) ...
146
  if not data: return "No patient data provided."; prompt_str = "";
147
  for key, value in data.items(): section_title = key.replace('_', ' ').title();
148
  if isinstance(value, dict) and value: has_content = any(sub_value for sub_value in value.values());
 
214
  model_with_tools = model.bind_tools(tools)
215
 
216
  # --- Graph Nodes (agent_node, tool_node) ---
217
+ # ... (Keep these functions exactly as they were) ...
218
  def agent_node(state: AgentState):
219
  print("\n---AGENT NODE---"); current_messages = state['messages'];
220
  if not current_messages or not isinstance(current_messages[0], SystemMessage): print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages;
 
273
  # --- Patient Data Input Sidebar ---
274
  with st.sidebar:
275
  st.header("πŸ“„ Patient Intake Form")
276
+ # Input fields...
277
+ st.subheader("Demographics"); age = st.number_input("Age", 0, 120, 55, key="sb_age"); sex = st.selectbox("Sex", ["Male", "Female", "Other"], key="sb_sex")
278
+ st.subheader("HPI"); chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="sb_cc"); hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100, key="sb_hpi"); symptoms = st.multiselect("Symptoms", ["Nausea", "Diaphoresis", "SOB", "Dizziness", "Severe Headache", "Syncope", "Hemoptysis"], default=["Nausea", "Diaphoresis"], key="sb_sym")
279
+ st.subheader("History"); pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI", key="sb_pmh"); psh = st.text_area("PSH", "Appendectomy", key="sb_psh")
280
+ st.subheader("Meds & Allergies"); current_meds_str = st.text_area("Current Meds", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily", key="sb_meds"); allergies_str = st.text_area("Allergies", "Penicillin (rash), Sulfa", key="sb_allergies")
281
+ st.subheader("Social/Family"); social_history = st.text_area("SH", "Smoker", key="sb_sh"); family_history = st.text_area("FHx", "Father MI", key="sb_fhx")
282
  st.subheader("Vitals & Exam"); col1, col2 = st.columns(2);
283
+ with col1: temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f", key="sb_temp"); hr_bpm = st.number_input("HR", 30, 250, 95, key="sb_hr"); rr_rpm = st.number_input("RR", 5, 50, 18, key="sb_rr")
284
+ with col2: bp_mmhg = st.text_input("BP", "155/90", key="sb_bp"); spo2_percent = st.number_input("SpO2", 70, 100, 96, key="sb_spo2"); pain_scale = st.slider("Pain", 0, 10, 8, key="sb_pain")
285
+ exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=50, key="sb_exam")
286
 
287
+ if st.button("Start/Update Consultation", key="sb_start"):
288
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
289
  current_med_names_only = [];
290
  for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", med);
 
301
 
302
  # --- Main Chat Interface Area ---
303
  st.header("πŸ’¬ Clinical Consultation")
304
+ # Display loop - key= argument REMOVED, Tool Call Display Syntax FIXED
305
  for msg in st.session_state.messages:
306
  if isinstance(msg, HumanMessage):
307
+ with st.chat_message("user"): st.markdown(msg.content)
308
  elif isinstance(msg, AIMessage):
309
+ with st.chat_message("assistant"):
310
  ai_content = msg.content; structured_output = None
311
  try: # JSON Parsing logic...
312
  json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
 
332
  # CORRECTED Tool Call Display Block
333
  if getattr(msg, 'tool_calls', None):
334
  with st.expander("πŸ› οΈ AI requested actions", expanded=False):
335
+ if msg.tool_calls:
336
  for tc in msg.tool_calls:
337
  try:
 
338
  st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
339
  except Exception as display_e:
340
+ st.error(f"Could not display tool call args: {display_e}", icon="⚠️")
341
+ st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nRaw Args: {tc.get('args')}")
 
 
342
  else:
343
+ st.caption("_No actions requested._")
344
  elif isinstance(msg, ToolMessage):
345
  tool_name_display = getattr(msg, 'name', 'tool_execution')
346
  with st.chat_message(tool_name_display, avatar="πŸ› οΈ"): # No key
 
349
  if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="βœ…" if status != "flagged" else "🚨")
350
  elif status == "warning": st.warning(f"{message}", icon="⚠️");
351
  if warnings and isinstance(warnings, list): st.caption("Details:"); [st.caption(f"- {warn}") for warn in warnings]
352
+ else: st.error(f"{message}", icon="❌") # Assume error if not success/clear/flagged/warning
353
  if details: st.caption(f"Details: {details}")
354
+ except json.JSONDecodeError: st.info(f"{msg.content}") # Display raw if not JSON
355
  except Exception as e: st.error(f"Error displaying tool message: {e}", icon="❌"); st.caption(f"Raw content: {msg.content}")
356
 
357
  # --- Chat Input Logic ---