Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ from dotenv import load_dotenv
|
|
11 |
from langchain_groq import ChatGroq
|
12 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
13 |
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
|
14 |
-
from langchain_core.prompts import ChatPromptTemplate
|
15 |
from langchain_core.pydantic_v1 import BaseModel, Field
|
16 |
from langchain_core.tools import tool
|
17 |
from langgraph.prebuilt import ToolExecutor
|
@@ -33,9 +33,12 @@ if not GROQ_API_KEY: missing_keys.append("GROQ_API_KEY")
|
|
33 |
if not TAVILY_API_KEY: missing_keys.append("TAVILY_API_KEY")
|
34 |
|
35 |
if missing_keys:
|
|
|
36 |
st.error(f"Missing required API Key(s): {', '.join(missing_keys)}. Please set them in Hugging Face Space Secrets or your environment variables.")
|
|
|
37 |
st.stop()
|
38 |
|
|
|
39 |
# --- Configuration & Constants ---
|
40 |
class ClinicalAppSettings:
|
41 |
APP_TITLE = "SynapseAI: Interactive Clinical Decision Support (UMLS/FDA Integrated)"
|
@@ -103,99 +106,57 @@ def get_rxcui(drug_name: str) -> Optional[str]:
|
|
103 |
response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
|
104 |
response.raise_for_status()
|
105 |
data = response.json()
|
106 |
-
# Extract RxCUI - prioritize exact matches or common types
|
107 |
if data and "idGroup" in data and "rxnormId" in data["idGroup"]:
|
108 |
-
# Select the first one, assuming it's the most relevant by default.
|
109 |
-
# More sophisticated logic could check TTYs (Term Types) if needed.
|
110 |
rxcui = data["idGroup"]["rxnormId"][0]
|
111 |
print(f" Found RxCUI: {rxcui} for '{drug_name}'")
|
112 |
return rxcui
|
113 |
-
else:
|
114 |
-
|
115 |
-
|
116 |
-
response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
|
117 |
-
response.raise_for_status()
|
118 |
-
data = response.json()
|
119 |
if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
|
120 |
for group in data["drugGroup"]["conceptGroup"]:
|
121 |
-
# Prioritize Semantic Types like Brand/Clinical Drug/Ingredient
|
122 |
if group.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
|
123 |
if "conceptProperties" in group and group["conceptProperties"]:
|
124 |
rxcui = group["conceptProperties"][0].get("rxcui")
|
125 |
-
if rxcui:
|
126 |
-
print(f" Found RxCUI (via /drugs): {rxcui} for '{drug_name}'")
|
127 |
-
return rxcui
|
128 |
print(f" RxCUI not found for '{drug_name}'.")
|
129 |
return None
|
130 |
-
except requests.exceptions.RequestException as e:
|
131 |
-
|
132 |
-
|
133 |
-
except json.JSONDecodeError as e:
|
134 |
-
print(f" Error decoding RxNorm JSON response for '{drug_name}': {e}")
|
135 |
-
return None
|
136 |
-
except Exception as e: # Catch any other unexpected error
|
137 |
-
print(f" Unexpected error in get_rxcui for '{drug_name}': {e}")
|
138 |
-
return None
|
139 |
|
140 |
@lru_cache(maxsize=128) # Cache OpenFDA lookups
|
141 |
def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
|
142 |
"""Fetches drug label information from OpenFDA using RxCUI or drug name."""
|
143 |
if not rxcui and not drug_name: return None
|
144 |
print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}")
|
145 |
-
|
146 |
search_terms = []
|
147 |
-
|
148 |
-
if
|
149 |
-
|
150 |
-
search_terms.append(f'openfda.rxcui:"{rxcui}"')
|
151 |
-
# Add name search as fallback or supplement
|
152 |
-
if drug_name:
|
153 |
-
search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
|
154 |
-
|
155 |
-
search_query = " OR ".join(search_terms)
|
156 |
-
params = {"search": search_query, "limit": 1} # Get only the most relevant label
|
157 |
-
|
158 |
try:
|
159 |
response = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
|
160 |
-
response.raise_for_status()
|
161 |
-
data
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
return None
|
167 |
-
except requests.exceptions.RequestException as e:
|
168 |
-
print(f" Error fetching OpenFDA label: {e}")
|
169 |
-
return None
|
170 |
-
except json.JSONDecodeError as e:
|
171 |
-
print(f" Error decoding OpenFDA JSON response: {e}")
|
172 |
-
return None
|
173 |
-
except Exception as e:
|
174 |
-
print(f" Unexpected error in get_openfda_label: {e}")
|
175 |
-
return None
|
176 |
|
177 |
def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
|
178 |
""" Case-insensitive search for any search_term within a list of text strings. Returns snippets. """
|
179 |
found_snippets = []
|
180 |
if not text_list or not search_terms: return found_snippets
|
181 |
-
# Ensure search terms are lowercased strings
|
182 |
search_terms_lower = [str(term).lower() for term in search_terms if term]
|
183 |
-
|
184 |
for text_item in text_list:
|
185 |
-
if not isinstance(text_item, str): continue
|
186 |
text_item_lower = text_item.lower()
|
187 |
for term in search_terms_lower:
|
188 |
if term in text_item_lower:
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
snippet_end = min(len(text_item), start_index + len(term) + 100)
|
194 |
-
snippet = text_item[snippet_start:snippet_end]
|
195 |
-
# Add indication of where the match is
|
196 |
-
snippet = snippet.replace(term, f"**{term}**", 1) # Highlight first match
|
197 |
-
found_snippets.append(f"...{snippet}...")
|
198 |
-
break # Move to the next text item once a match is found
|
199 |
return found_snippets
|
200 |
|
201 |
# --- Other Helper Functions ---
|
@@ -208,7 +169,6 @@ def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
|
|
208 |
|
209 |
def check_red_flags(patient_data: dict) -> List[str]:
|
210 |
"""Checks patient data against predefined red flags."""
|
211 |
-
# (Keep the implementation from the previous full code listing)
|
212 |
flags = []
|
213 |
if not patient_data: return flags
|
214 |
symptoms = patient_data.get("hpi", {}).get("symptoms", [])
|
@@ -243,12 +203,10 @@ def check_red_flags(patient_data: dict) -> List[str]:
|
|
243 |
if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
|
244 |
if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
|
245 |
|
246 |
-
return list(set(flags))
|
247 |
-
|
248 |
|
249 |
def format_patient_data_for_prompt(data: dict) -> str:
|
250 |
"""Formats the patient dictionary into a readable string for the LLM."""
|
251 |
-
# (Keep the implementation from the previous full code listing)
|
252 |
if not data: return "No patient data provided."
|
253 |
prompt_str = ""
|
254 |
for key, value in data.items():
|
@@ -268,7 +226,7 @@ def format_patient_data_for_prompt(data: dict) -> str:
|
|
268 |
|
269 |
# --- Tool Definitions ---
|
270 |
|
271 |
-
# Pydantic models
|
272 |
class LabOrderInput(BaseModel):
|
273 |
test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis', 'D-dimer').")
|
274 |
reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS', 'Assess for PE').")
|
@@ -282,8 +240,6 @@ class PrescriptionInput(BaseModel):
|
|
282 |
duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Ongoing', 'Until follow-up').")
|
283 |
reason: str = Field(..., description="Clinical indication for the prescription.")
|
284 |
|
285 |
-
# Updated InteractionCheckInput - Note: current_medications/allergies are Optional here
|
286 |
-
# because they are populated by the tool_node from state *before* execution.
|
287 |
class InteractionCheckInput(BaseModel):
|
288 |
potential_prescription: str = Field(..., description="The name of the NEW medication being considered for prescribing.")
|
289 |
current_medications: Optional[List[str]] = Field(None, description="List of patient's current medication names (populated from state).")
|
@@ -304,10 +260,8 @@ def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> st
|
|
304 |
def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
|
305 |
"""Prescribes a medication with detailed instructions and clinical indication. IMPORTANT: Requires prior interaction check."""
|
306 |
print(f"Executing prescribe_medication: {medication_name} {dosage}...")
|
307 |
-
# Safety check happens in tool_node *before* this is called.
|
308 |
return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
|
309 |
|
310 |
-
# --- NEW Interaction Check Tool using UMLS/RxNorm & OpenFDA ---
|
311 |
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
|
312 |
def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
|
313 |
"""
|
@@ -318,103 +272,57 @@ def check_drug_interactions(potential_prescription: str, current_medications: Op
|
|
318 |
print(f"Checking potential prescription: '{potential_prescription}'")
|
319 |
warnings = []
|
320 |
potential_med_lower = potential_prescription.lower().strip()
|
321 |
-
|
322 |
-
# Use provided lists or default to empty
|
323 |
-
current_meds_list = current_medications or []
|
324 |
-
allergies_list = allergies or []
|
325 |
-
# Clean and lowercase current med names (basic extraction: first word)
|
326 |
current_med_names_lower = []
|
327 |
for med in current_meds_list:
|
328 |
-
match = re.match(r"^\s*([a-zA-Z\-]+)", str(med))
|
329 |
if match: current_med_names_lower.append(match.group(1).lower())
|
330 |
-
# Clean and lowercase allergies
|
331 |
allergies_lower = [str(a).lower().strip() for a in allergies_list if a]
|
|
|
332 |
|
333 |
-
print(f"
|
334 |
-
print(f" Against Allergies: {allergies_lower}")
|
335 |
-
|
336 |
-
# --- Step 1: Normalize potential prescription ---
|
337 |
-
print(f" Step 1: Normalizing '{potential_prescription}'...")
|
338 |
-
potential_rxcui = get_rxcui(potential_prescription)
|
339 |
potential_label = get_openfda_label(rxcui=potential_rxcui, drug_name=potential_prescription)
|
340 |
-
if not potential_rxcui and not potential_label:
|
341 |
-
print(f" Warning: Could not find RxCUI or OpenFDA label for '{potential_prescription}'. Interaction check will be limited.")
|
342 |
-
warnings.append(f"INFO: Could not reliably identify '{potential_prescription}' in standard terminologies/databases. Checks may be incomplete.")
|
343 |
|
344 |
-
|
345 |
-
print(" Step 2: Performing Allergy Check...")
|
346 |
-
# Direct name match against patient's allergy list
|
347 |
for allergy in allergies_lower:
|
348 |
-
if allergy == potential_med_lower:
|
349 |
-
|
350 |
-
|
351 |
-
elif allergy in ["
|
352 |
-
warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Penicillin. High risk with '{potential_prescription}'.")
|
353 |
-
elif allergy == "sulfa" and potential_med_lower in ["sulfamethoxazole", "bactrim", "sulfasalazine"]:
|
354 |
-
warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Sulfa. High risk with '{potential_prescription}'.")
|
355 |
-
elif allergy in ["nsaids", "aspirin"] and potential_med_lower in ["ibuprofen", "naproxen", "ketorolac", "diclofenac"]:
|
356 |
-
warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to NSAIDs/Aspirin. Risk with '{potential_prescription}'.")
|
357 |
-
|
358 |
-
# Check OpenFDA Label for Contraindications/Warnings related to ALLERGIES
|
359 |
if potential_label:
|
360 |
-
contraindications = potential_label.get("contraindications")
|
361 |
-
warnings_section = potential_label.get("warnings_and_cautions") or potential_label.get("warnings")
|
362 |
-
|
363 |
if contraindications:
|
364 |
allergy_mentions_ci = search_text_list(contraindications, allergies_lower)
|
365 |
-
if allergy_mentions_ci:
|
366 |
-
warnings.append(f"ALLERGY RISK (Contraindication Found): Label for '{potential_prescription}' mentions contraindication potentially related to patient allergies: {'; '.join(allergy_mentions_ci)}")
|
367 |
-
|
368 |
if warnings_section:
|
369 |
allergy_mentions_warn = search_text_list(warnings_section, allergies_lower)
|
370 |
-
if allergy_mentions_warn:
|
371 |
-
warnings.append(f"ALLERGY RISK (Warning Found): Label for '{potential_prescription}' mentions warnings potentially related to patient allergies: {'; '.join(allergy_mentions_warn)}")
|
372 |
|
373 |
-
# --- Step 3: Drug-Drug Interaction Check ---
|
374 |
print(" Step 3: Performing Drug-Drug Interaction Check...")
|
375 |
-
if potential_rxcui or potential_label:
|
376 |
for current_med_name in current_med_names_lower:
|
377 |
-
if not current_med_name or current_med_name == potential_med_lower: continue
|
378 |
-
|
379 |
print(f" Checking interaction between '{potential_prescription}' and '{current_med_name}'...")
|
380 |
-
current_rxcui = get_rxcui(current_med_name)
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
if current_rxcui: search_terms_for_current.append(current_rxcui) # Add RxCUI if found
|
386 |
-
|
387 |
-
search_terms_for_potential = [potential_med_lower]
|
388 |
-
if potential_rxcui: search_terms_for_potential.append(potential_rxcui) # Add RxCUI if found
|
389 |
-
|
390 |
interaction_found_flag = False
|
391 |
-
# Check Potential Drug's Label ('drug_interactions' section) for mentions of Current Drug
|
392 |
if potential_label and potential_label.get("drug_interactions"):
|
393 |
interaction_mentions = search_text_list(potential_label.get("drug_interactions"), search_terms_for_current)
|
394 |
-
if interaction_mentions:
|
395 |
-
|
396 |
-
interaction_found_flag = True
|
397 |
-
|
398 |
-
# Check Current Drug's Label ('drug_interactions' section) for mentions of Potential Drug
|
399 |
-
if current_label and current_label.get("drug_interactions") and not interaction_found_flag: # Avoid duplicate warnings if already found
|
400 |
interaction_mentions = search_text_list(current_label.get("drug_interactions"), search_terms_for_potential)
|
401 |
-
if interaction_mentions:
|
402 |
-
|
403 |
-
|
404 |
-
else: # Case where potential drug wasn't identified
|
405 |
-
warnings.append(f"INFO: Drug-drug interaction check skipped for '{potential_prescription}' as it could not be identified via RxNorm/OpenFDA.")
|
406 |
-
|
407 |
-
|
408 |
-
# --- Step 4: Format Output ---
|
409 |
-
final_warnings = list(set(warnings)) # Remove duplicates
|
410 |
-
status = "warning" if any("CRITICAL" in w or "Interaction" in w or "RISK" in w for w in final_warnings) else "clear"
|
411 |
-
if not final_warnings: status = "clear" # Ensure clear if no warnings remain
|
412 |
|
|
|
|
|
413 |
message = f"Interaction/Allergy check for '{potential_prescription}': {len(final_warnings)} potential issue(s) identified using RxNorm/OpenFDA." if final_warnings else f"No major interactions or allergy issues identified for '{potential_prescription}' based on RxNorm/OpenFDA lookup."
|
414 |
print(f"--- Interaction Check Complete for '{potential_prescription}' ---")
|
415 |
-
|
416 |
return json.dumps({"status": status, "message": message, "warnings": final_warnings})
|
417 |
-
# --- End of NEW Interaction Check Tool ---
|
418 |
|
419 |
@tool("flag_risk", args_schema=FlagRiskInput)
|
420 |
def flag_risk(risk_description: str, urgency: str) -> str:
|
@@ -424,45 +332,22 @@ def flag_risk(risk_description: str, urgency: str) -> str:
|
|
424 |
return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
|
425 |
|
426 |
# Initialize Search Tool
|
427 |
-
search_tool = TavilySearchResults(
|
428 |
-
max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS,
|
429 |
-
name="tavily_search_results"
|
430 |
-
)
|
431 |
|
432 |
# --- LangGraph Setup ---
|
433 |
-
|
434 |
-
# Define the state structure
|
435 |
class AgentState(TypedDict):
|
436 |
-
messages: Annotated[list[Any], operator.add]
|
437 |
-
|
438 |
-
|
439 |
-
# Define Tools and Tool Executor
|
440 |
-
tools = [
|
441 |
-
order_lab_test,
|
442 |
-
prescribe_medication,
|
443 |
-
check_drug_interactions, # Using the new implementation
|
444 |
-
flag_risk,
|
445 |
-
search_tool
|
446 |
-
]
|
447 |
tool_executor = ToolExecutor(tools)
|
448 |
-
|
449 |
-
# Define the Agent Model
|
450 |
-
model = ChatGroq(
|
451 |
-
temperature=ClinicalAppSettings.TEMPERATURE,
|
452 |
-
model=ClinicalAppSettings.MODEL_NAME,
|
453 |
-
)
|
454 |
model_with_tools = model.bind_tools(tools)
|
455 |
|
456 |
-
# --- Graph Nodes
|
457 |
-
|
458 |
-
# 1. Agent Node: Calls the LLM (No change needed from previous version)
|
459 |
def agent_node(state: AgentState):
|
460 |
-
"""Invokes the LLM to decide the next action or response."""
|
461 |
print("\n---AGENT NODE---")
|
462 |
current_messages = state['messages']
|
463 |
if not current_messages or not isinstance(current_messages[0], SystemMessage):
|
464 |
-
print("Prepending System Prompt.")
|
465 |
-
current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
|
466 |
print(f"Invoking LLM with {len(current_messages)} messages.")
|
467 |
try:
|
468 |
response = model_with_tools.invoke(current_messages)
|
@@ -470,64 +355,36 @@ def agent_node(state: AgentState):
|
|
470 |
if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}")
|
471 |
else: print("Agent Response: No tool calls.")
|
472 |
except Exception as e:
|
473 |
-
print(f"ERROR in agent_node during LLM invocation: {type(e).__name__} - {e}")
|
474 |
-
traceback.print_exc()
|
475 |
error_message = AIMessage(content=f"Sorry, an internal error occurred while processing the request: {type(e).__name__}")
|
476 |
return {"messages": [error_message]}
|
477 |
return {"messages": [response]}
|
478 |
|
479 |
-
# 2. Tool Node: Executes tools (Mostly the same, ensures context injection)
|
480 |
def tool_node(state: AgentState):
|
481 |
-
"""Executes tools called by the LLM and returns results."""
|
482 |
print("\n---TOOL NODE---")
|
483 |
-
tool_messages = []
|
484 |
-
last_message = state['messages'][-1]
|
485 |
-
|
486 |
if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None):
|
487 |
-
print("Warning: Tool node called unexpectedly without tool calls.")
|
488 |
-
|
489 |
-
|
490 |
-
tool_calls = last_message.tool_calls
|
491 |
-
print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}")
|
492 |
-
|
493 |
-
# Safety Check Logic (No change needed from previous version)
|
494 |
-
prescriptions_requested = {}
|
495 |
-
interaction_checks_requested = {}
|
496 |
for call in tool_calls:
|
497 |
tool_name = call.get('name'); tool_args = call.get('args', {})
|
498 |
-
if tool_name == 'prescribe_medication':
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
if potential_med: interaction_checks_requested[potential_med] = call
|
504 |
-
|
505 |
-
valid_tool_calls_for_execution = []
|
506 |
-
blocked_ids = set()
|
507 |
for med_name, prescribe_call in prescriptions_requested.items():
|
508 |
if med_name not in interaction_checks_requested:
|
509 |
st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked.")
|
510 |
error_msg = ToolMessage(content=json.dumps({"status": "error", "message": f"Interaction check for '{med_name}' must be requested *before or alongside* the prescription call."}), tool_call_id=prescribe_call['id'], name=prescribe_call['name'])
|
511 |
-
tool_messages.append(error_msg)
|
512 |
-
blocked_ids.add(prescribe_call['id'])
|
513 |
-
|
514 |
valid_tool_calls_for_execution = [call for call in tool_calls if call['id'] not in blocked_ids]
|
515 |
-
|
516 |
-
# Augment interaction checks with patient data (Crucial part - no change needed here)
|
517 |
-
patient_data = state.get("patient_data", {})
|
518 |
-
patient_meds_full = patient_data.get("medications", {}).get("current", []) # Pass full med list if needed by tool
|
519 |
-
patient_allergies = patient_data.get("allergies", [])
|
520 |
-
|
521 |
for call in valid_tool_calls_for_execution:
|
522 |
if call['name'] == 'check_drug_interactions':
|
523 |
if 'args' not in call: call['args'] = {}
|
524 |
-
|
525 |
-
# The tool function expects 'current_medications' (list of names) and 'allergies'
|
526 |
-
call['args']['current_medications'] = patient_meds_full # Pass the full strings
|
527 |
-
call['args']['allergies'] = patient_allergies
|
528 |
-
print(f"Augmented interaction check args for call ID {call['id']}") # Removed args content for brevity
|
529 |
-
|
530 |
-
# Execute valid tool calls (No change needed from previous version)
|
531 |
if valid_tool_calls_for_execution:
|
532 |
print(f"Attempting to execute {len(valid_tool_calls_for_execution)} tools: {[c['name'] for c in valid_tool_calls_for_execution]}")
|
533 |
try:
|
@@ -535,151 +392,79 @@ def tool_node(state: AgentState):
|
|
535 |
for call, resp in zip(valid_tool_calls_for_execution, responses):
|
536 |
tool_call_id = call['id']; tool_name = call['name']
|
537 |
if isinstance(resp, Exception):
|
538 |
-
error_type = type(resp).__name__; error_str = str(resp)
|
539 |
-
|
540 |
-
traceback.print_exc()
|
541 |
-
st.error(f"Error executing action '{tool_name}': {error_type}")
|
542 |
-
error_content = json.dumps({"status": "error", "message": f"Failed to execute '{tool_name}': {error_type} - {error_str}"})
|
543 |
tool_messages.append(ToolMessage(content=error_content, tool_call_id=tool_call_id, name=tool_name))
|
544 |
-
if isinstance(resp, AttributeError) and "'dict' object has no attribute 'tool'" in error_str:
|
545 |
-
print("\n *** DETECTED SPECIFIC ATTRIBUTE ERROR ('dict' object has no attribute 'tool') *** \n")
|
546 |
else:
|
547 |
-
print(f"Tool '{tool_name}' (ID: {tool_call_id}) executed successfully.")
|
548 |
-
content_str = str(resp)
|
549 |
-
tool_messages.append(ToolMessage(content=content_str, tool_call_id=tool_call_id, name=tool_name))
|
550 |
except Exception as e:
|
551 |
-
print(f"CRITICAL UNEXPECTED ERROR within tool_node logic: {type(e).__name__} - {e}")
|
552 |
-
|
553 |
-
error_content = json.dumps({"status": "error", "message": f"Internal error processing tools: {e}"})
|
554 |
-
processed_ids = {msg.tool_call_id for msg in tool_messages}
|
555 |
for call in valid_tool_calls_for_execution:
|
556 |
if call['id'] not in processed_ids: tool_messages.append(ToolMessage(content=error_content, tool_call_id=call['id'], name=call['name']))
|
|
|
557 |
|
558 |
-
|
559 |
-
return {"messages": tool_messages}
|
560 |
-
|
561 |
-
|
562 |
-
# --- Graph Edges (Routing Logic) --- (No change needed)
|
563 |
def should_continue(state: AgentState) -> str:
|
564 |
-
""
|
565 |
-
print("\n---ROUTING DECISION---")
|
566 |
-
last_message = state['messages'][-1] if state['messages'] else None
|
567 |
if not isinstance(last_message, AIMessage): return "end_conversation_turn"
|
568 |
if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn"
|
569 |
if getattr(last_message, 'tool_calls', None): return "continue_tools"
|
570 |
else: return "end_conversation_turn"
|
571 |
|
572 |
-
# --- Graph Definition & Compilation ---
|
573 |
-
workflow = StateGraph(AgentState)
|
574 |
-
workflow.
|
575 |
-
workflow.
|
576 |
-
workflow.set_entry_point("agent")
|
577 |
-
workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
|
578 |
-
workflow.add_edge("tools", "agent")
|
579 |
-
app = workflow.compile()
|
580 |
-
print("LangGraph compiled successfully.")
|
581 |
|
582 |
# --- Streamlit UI ---
|
583 |
def main():
|
584 |
st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
|
585 |
st.title(f"π©Ί {ClinicalAppSettings.APP_TITLE}")
|
586 |
st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME}")
|
587 |
-
|
588 |
-
# Initialize session state (No change needed)
|
589 |
if "messages" not in st.session_state: st.session_state.messages = []
|
590 |
if "patient_data" not in st.session_state: st.session_state.patient_data = None
|
591 |
if "graph_app" not in st.session_state: st.session_state.graph_app = app
|
592 |
|
593 |
-
# --- Patient Data Input Sidebar ---
|
594 |
with st.sidebar:
|
595 |
st.header("π Patient Intake Form")
|
596 |
-
# Demographics, HPI, History, Social/Family, Vitals/Exam
|
597 |
-
|
598 |
-
st.subheader("
|
599 |
-
|
600 |
-
|
601 |
-
st.subheader("
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
st.
|
606 |
-
|
607 |
-
|
608 |
-
st.subheader("Medications & Allergies")
|
609 |
-
current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
|
610 |
-
allergies_str = st.text_area("Allergies (comma separated, specify reaction if known)", "Penicillin (rash), Sulfa (hives)", key="allergy_input")
|
611 |
-
st.subheader("Social/Family History")
|
612 |
-
social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
|
613 |
-
family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
|
614 |
-
st.subheader("Vitals & Exam Findings")
|
615 |
-
col1, col2 = st.columns(2)
|
616 |
-
with col1:
|
617 |
-
temp_c = st.number_input("Temp (Β°C)", 35.0, 42.0, 36.8, format="%.1f", key="temp_input")
|
618 |
-
hr_bpm = st.number_input("HR (bpm)", 30, 250, 95, key="hr_input")
|
619 |
-
rr_rpm = st.number_input("RR (rpm)", 5, 50, 18, key="rr_input")
|
620 |
-
with col2:
|
621 |
-
bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
|
622 |
-
spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
|
623 |
-
pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
|
624 |
-
exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3. Mild distress. Lungs clear bilaterally...", key="exam_input", height=100)
|
625 |
-
|
626 |
-
|
627 |
-
# Compile Patient Data Dictionary (Refined Extraction for Tool Use)
|
628 |
-
if st.button("Start/Update Consultation", key="start_button"):
|
629 |
-
# Store full medication strings for display/context
|
630 |
current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
match = re.match(r"^\s*([a-zA-Z\-]+)", med)
|
635 |
-
if match: current_med_names_only.append(match.group(1).lower())
|
636 |
-
|
637 |
-
# Extract allergy names (simplified, before parenthesis)
|
638 |
allergies_list = []
|
639 |
for a in allergies_str.split(','):
|
640 |
-
cleaned_allergy = a.strip()
|
641 |
-
if cleaned_allergy:
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
# Store both full list and names_only list
|
651 |
-
"medications": {"current": current_meds_list, "names_only": current_med_names_only},
|
652 |
-
"allergies": allergies_list, # Store cleaned list
|
653 |
-
"social_history": {"details": social_history}, "family_history": {"details": family_history},
|
654 |
-
"vitals": { "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale},
|
655 |
-
"exam_findings": {"notes": exam_notes}
|
656 |
-
}
|
657 |
-
|
658 |
-
# Initial Red Flag Check
|
659 |
-
red_flags = check_red_flags(st.session_state.patient_data)
|
660 |
-
st.sidebar.markdown("---")
|
661 |
-
if red_flags:
|
662 |
-
st.sidebar.warning("**Initial Red Flags Detected:**")
|
663 |
-
for flag in red_flags: st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}")
|
664 |
-
else: st.sidebar.success("No immediate red flags detected.")
|
665 |
-
|
666 |
-
# Prepare initial message & reset history
|
667 |
-
initial_prompt = "Initiate consultation for the patient described in the intake form. Review the data and begin analysis."
|
668 |
-
st.session_state.messages = [HumanMessage(content=initial_prompt)]
|
669 |
-
st.success("Patient data loaded/updated. Ready for analysis.")
|
670 |
-
|
671 |
-
|
672 |
-
# --- Main Chat Interface Area --- (No change needed in display logic)
|
673 |
st.header("π¬ Clinical Consultation")
|
674 |
-
|
675 |
-
|
676 |
-
# (Copy the message display loop from the previous full code version)
|
677 |
-
for msg_index, msg in enumerate(st.session_state.messages):
|
678 |
-
unique_key = f"msg_{msg_index}"
|
679 |
if isinstance(msg, HumanMessage):
|
680 |
-
with st.chat_message("user"
|
681 |
elif isinstance(msg, AIMessage):
|
682 |
-
with st.chat_message("assistant"
|
683 |
ai_content = msg.content; structured_output = None
|
684 |
try:
|
685 |
json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
|
@@ -692,90 +477,58 @@ def main():
|
|
692 |
structured_output = json.loads(ai_content); ai_content = ""
|
693 |
else: st.markdown(ai_content)
|
694 |
except Exception as e: st.markdown(ai_content); print(f"Error parsing/displaying AI JSON: {e}")
|
695 |
-
|
696 |
if structured_output and isinstance(structured_output, dict):
|
697 |
-
|
698 |
-
st.divider(); st.subheader("π AI Analysis & Recommendations")
|
699 |
cols = st.columns(2)
|
700 |
-
with cols[0]:
|
701 |
-
st.markdown(
|
702 |
-
st.markdown(
|
703 |
-
ddx
|
704 |
-
if ddx:
|
705 |
-
for item in ddx:
|
706 |
-
likelihood = item.get('likelihood', '?').capitalize(); icon = "π₯" if likelihood=="High" else ("π₯" if likelihood=="Medium" else "π₯")
|
707 |
-
with st.expander(f"{icon} {item.get('diagnosis', 'Unknown')} ({likelihood})"): st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
|
708 |
else: st.info("No DDx provided.")
|
709 |
-
st.markdown(
|
710 |
-
flags = risk.get('identified_red_flags', []); concerns = risk.get("immediate_concerns", []); comps = risk.get("potential_complications", [])
|
711 |
if flags: st.warning(f"**Flags:** {', '.join(flags)}")
|
712 |
if concerns: st.warning(f"**Concerns:** {', '.join(concerns)}")
|
713 |
if comps: st.info(f"**Potential Complications:** {', '.join(comps)}")
|
714 |
if not flags and not concerns: st.success("No major risks highlighted.")
|
715 |
-
with cols[1]:
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
if items: [st.markdown(f"- {item}") for item in items] if isinstance(items, list) else st.markdown(f"- {items}")
|
720 |
-
else: st.markdown("_None suggested._")
|
721 |
-
st.markdown("") # Space
|
722 |
-
st.markdown(f"**Rationale & Guideline Check:**"); st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
|
723 |
interaction_summary = structured_output.get("interaction_check_summary", "")
|
724 |
-
if interaction_summary: st.markdown(
|
725 |
st.divider()
|
726 |
-
|
727 |
if getattr(msg, 'tool_calls', None):
|
728 |
-
with st.expander("π οΈ AI requested actions", expanded=False):
|
729 |
-
|
730 |
-
|
731 |
-
except Exception as display_e: st.error(f"Could not display tool call: {display_e}"); st.code(str(tc))
|
732 |
-
|
733 |
elif isinstance(msg, ToolMessage):
|
734 |
tool_name_display = getattr(msg, 'name', 'tool_execution')
|
735 |
-
with st.chat_message(tool_name_display, avatar="π οΈ"
|
736 |
-
#
|
737 |
-
|
738 |
-
tool_data = json.loads(msg.content); status = tool_data.get("status", "info"); message = tool_data.get("message", msg.content)
|
739 |
-
details = tool_data.get("details"); warnings = tool_data.get("warnings")
|
740 |
if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="β
" if status != "flagged" else "π¨")
|
741 |
-
elif status == "warning":
|
742 |
-
|
743 |
-
if warnings and isinstance(warnings, list):
|
744 |
-
st.caption("Details:")
|
745 |
-
for warn in warnings: st.caption(f"- {warn}") # Display warnings from the tool output JSON
|
746 |
else: st.error(f"{message}", icon="β")
|
747 |
if details: st.caption(f"Details: {details}")
|
748 |
-
except json.JSONDecodeError: st.info(f"{msg.content}")
|
749 |
except Exception as e: st.error(f"Error displaying tool message: {e}", icon="β"); st.caption(f"Raw content: {msg.content}")
|
750 |
|
751 |
-
|
752 |
-
# --- Chat Input Logic --- (No change needed)
|
753 |
if prompt := st.chat_input("Your message or follow-up query..."):
|
754 |
-
if not st.session_state.patient_data:
|
755 |
-
|
756 |
-
|
757 |
-
user_message = HumanMessage(content=prompt)
|
758 |
-
st.session_state.messages.append(user_message)
|
759 |
-
with st.chat_message("user"): st.markdown(prompt)
|
760 |
-
|
761 |
current_state = AgentState(messages=st.session_state.messages, patient_data=st.session_state.patient_data)
|
762 |
-
|
763 |
with st.spinner("SynapseAI is thinking..."):
|
764 |
try:
|
765 |
final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
|
766 |
-
st.session_state.messages = final_state['messages']
|
767 |
-
except Exception as e:
|
768 |
-
print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}"); traceback.print_exc()
|
769 |
-
st.error(f"An error occurred during the conversation turn: {e}", icon="β")
|
770 |
-
# Optionally add error to history for user visibility
|
771 |
-
# error_ai_msg = AIMessage(content=f"Sorry, a critical error occurred: {type(e).__name__}. Please check logs or try again.")
|
772 |
-
# st.session_state.messages.append(error_ai_msg)
|
773 |
-
|
774 |
st.rerun() # Refresh display
|
775 |
|
776 |
-
# Disclaimer
|
777 |
-
st.markdown("---")
|
778 |
-
st.warning("""**Disclaimer:** SynapseAI is an AI assistant... (Verify all outputs)""")
|
779 |
|
780 |
if __name__ == "__main__":
|
781 |
main()
|
|
|
11 |
from langchain_groq import ChatGroq
|
12 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
13 |
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
|
14 |
+
# from langchain_core.prompts import ChatPromptTemplate # Not explicitly used in this version
|
15 |
from langchain_core.pydantic_v1 import BaseModel, Field
|
16 |
from langchain_core.tools import tool
|
17 |
from langgraph.prebuilt import ToolExecutor
|
|
|
33 |
if not TAVILY_API_KEY: missing_keys.append("TAVILY_API_KEY")
|
34 |
|
35 |
if missing_keys:
|
36 |
+
# Use st.error which stops execution in recent Streamlit versions
|
37 |
st.error(f"Missing required API Key(s): {', '.join(missing_keys)}. Please set them in Hugging Face Space Secrets or your environment variables.")
|
38 |
+
# Ensure execution stops if st.error doesn't automatically do it in the environment
|
39 |
st.stop()
|
40 |
|
41 |
+
|
42 |
# --- Configuration & Constants ---
|
43 |
class ClinicalAppSettings:
|
44 |
APP_TITLE = "SynapseAI: Interactive Clinical Decision Support (UMLS/FDA Integrated)"
|
|
|
106 |
response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
|
107 |
response.raise_for_status()
|
108 |
data = response.json()
|
|
|
109 |
if data and "idGroup" in data and "rxnormId" in data["idGroup"]:
|
|
|
|
|
110 |
rxcui = data["idGroup"]["rxnormId"][0]
|
111 |
print(f" Found RxCUI: {rxcui} for '{drug_name}'")
|
112 |
return rxcui
|
113 |
+
else: # Fallback search
|
114 |
+
params = {"name": drug_name}; response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
|
115 |
+
response.raise_for_status(); data = response.json()
|
|
|
|
|
|
|
116 |
if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
|
117 |
for group in data["drugGroup"]["conceptGroup"]:
|
|
|
118 |
if group.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
|
119 |
if "conceptProperties" in group and group["conceptProperties"]:
|
120 |
rxcui = group["conceptProperties"][0].get("rxcui")
|
121 |
+
if rxcui: print(f" Found RxCUI (via /drugs): {rxcui} for '{drug_name}'"); return rxcui
|
|
|
|
|
122 |
print(f" RxCUI not found for '{drug_name}'.")
|
123 |
return None
|
124 |
+
except requests.exceptions.RequestException as e: print(f" Error fetching RxCUI for '{drug_name}': {e}"); return None
|
125 |
+
except json.JSONDecodeError as e: print(f" Error decoding RxNorm JSON response for '{drug_name}': {e}"); return None
|
126 |
+
except Exception as e: print(f" Unexpected error in get_rxcui for '{drug_name}': {e}"); return None
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
@lru_cache(maxsize=128) # Cache OpenFDA lookups
|
129 |
def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
|
130 |
"""Fetches drug label information from OpenFDA using RxCUI or drug name."""
|
131 |
if not rxcui and not drug_name: return None
|
132 |
print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}")
|
|
|
133 |
search_terms = []
|
134 |
+
if rxcui: search_terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
|
135 |
+
if drug_name: search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
|
136 |
+
search_query = " OR ".join(search_terms); params = {"search": search_query, "limit": 1}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
try:
|
138 |
response = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
|
139 |
+
response.raise_for_status(); data = response.json()
|
140 |
+
if data and "results" in data and data["results"]: print(f" Found OpenFDA label for query: {search_query}"); return data["results"][0]
|
141 |
+
print(f" No OpenFDA label found for query: {search_query}"); return None
|
142 |
+
except requests.exceptions.RequestException as e: print(f" Error fetching OpenFDA label: {e}"); return None
|
143 |
+
except json.JSONDecodeError as e: print(f" Error decoding OpenFDA JSON response: {e}"); return None
|
144 |
+
except Exception as e: print(f" Unexpected error in get_openfda_label: {e}"); return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
|
147 |
""" Case-insensitive search for any search_term within a list of text strings. Returns snippets. """
|
148 |
found_snippets = []
|
149 |
if not text_list or not search_terms: return found_snippets
|
|
|
150 |
search_terms_lower = [str(term).lower() for term in search_terms if term]
|
|
|
151 |
for text_item in text_list:
|
152 |
+
if not isinstance(text_item, str): continue
|
153 |
text_item_lower = text_item.lower()
|
154 |
for term in search_terms_lower:
|
155 |
if term in text_item_lower:
|
156 |
+
start_index = text_item_lower.find(term); snippet_start = max(0, start_index - 50)
|
157 |
+
snippet_end = min(len(text_item), start_index + len(term) + 100); snippet = text_item[snippet_start:snippet_end]
|
158 |
+
snippet = snippet.replace(term, f"**{term}**", 1); found_snippets.append(f"...{snippet}...")
|
159 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
return found_snippets
|
161 |
|
162 |
# --- Other Helper Functions ---
|
|
|
169 |
|
170 |
def check_red_flags(patient_data: dict) -> List[str]:
|
171 |
"""Checks patient data against predefined red flags."""
|
|
|
172 |
flags = []
|
173 |
if not patient_data: return flags
|
174 |
symptoms = patient_data.get("hpi", {}).get("symptoms", [])
|
|
|
203 |
if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
|
204 |
if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
|
205 |
|
206 |
+
return list(set(flags))
|
|
|
207 |
|
208 |
def format_patient_data_for_prompt(data: dict) -> str:
|
209 |
"""Formats the patient dictionary into a readable string for the LLM."""
|
|
|
210 |
if not data: return "No patient data provided."
|
211 |
prompt_str = ""
|
212 |
for key, value in data.items():
|
|
|
226 |
|
227 |
# --- Tool Definitions ---
|
228 |
|
229 |
+
# Pydantic models
|
230 |
class LabOrderInput(BaseModel):
|
231 |
test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis', 'D-dimer').")
|
232 |
reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS', 'Assess for PE').")
|
|
|
240 |
duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Ongoing', 'Until follow-up').")
|
241 |
reason: str = Field(..., description="Clinical indication for the prescription.")
|
242 |
|
|
|
|
|
243 |
class InteractionCheckInput(BaseModel):
|
244 |
potential_prescription: str = Field(..., description="The name of the NEW medication being considered for prescribing.")
|
245 |
current_medications: Optional[List[str]] = Field(None, description="List of patient's current medication names (populated from state).")
|
|
|
260 |
def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
|
261 |
"""Prescribes a medication with detailed instructions and clinical indication. IMPORTANT: Requires prior interaction check."""
|
262 |
print(f"Executing prescribe_medication: {medication_name} {dosage}...")
|
|
|
263 |
return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})
|
264 |
|
|
|
265 |
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
|
266 |
def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
|
267 |
"""
|
|
|
272 |
print(f"Checking potential prescription: '{potential_prescription}'")
|
273 |
warnings = []
|
274 |
potential_med_lower = potential_prescription.lower().strip()
|
275 |
+
current_meds_list = current_medications or []; allergies_list = allergies or []
|
|
|
|
|
|
|
|
|
276 |
current_med_names_lower = []
|
277 |
for med in current_meds_list:
|
278 |
+
match = re.match(r"^\s*([a-zA-Z\-]+)", str(med));
|
279 |
if match: current_med_names_lower.append(match.group(1).lower())
|
|
|
280 |
allergies_lower = [str(a).lower().strip() for a in allergies_list if a]
|
281 |
+
print(f" Against Current Meds (names): {current_med_names_lower}"); print(f" Against Allergies: {allergies_lower}")
|
282 |
|
283 |
+
print(f" Step 1: Normalizing '{potential_prescription}'..."); potential_rxcui = get_rxcui(potential_prescription)
|
|
|
|
|
|
|
|
|
|
|
284 |
potential_label = get_openfda_label(rxcui=potential_rxcui, drug_name=potential_prescription)
|
285 |
+
if not potential_rxcui and not potential_label: warnings.append(f"INFO: Could not reliably identify '{potential_prescription}'. Checks may be incomplete.")
|
|
|
|
|
286 |
|
287 |
+
print(" Step 2: Performing Allergy Check...");
|
|
|
|
|
288 |
for allergy in allergies_lower:
|
289 |
+
if allergy == potential_med_lower: warnings.append(f"CRITICAL ALLERGY (Name Match): Patient allergic to '{allergy}'. Potential prescription is '{potential_prescription}'.")
|
290 |
+
elif allergy in ["penicillin", "pcns"] and potential_med_lower in ["amoxicillin", "ampicillin", "augmentin", "piperacillin"]: warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Penicillin. High risk with '{potential_prescription}'.")
|
291 |
+
elif allergy == "sulfa" and potential_med_lower in ["sulfamethoxazole", "bactrim", "sulfasalazine"]: warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Sulfa. High risk with '{potential_prescription}'.")
|
292 |
+
elif allergy in ["nsaids", "aspirin"] and potential_med_lower in ["ibuprofen", "naproxen", "ketorolac", "diclofenac"]: warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to NSAIDs/Aspirin. Risk with '{potential_prescription}'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
if potential_label:
|
294 |
+
contraindications = potential_label.get("contraindications"); warnings_section = potential_label.get("warnings_and_cautions") or potential_label.get("warnings")
|
|
|
|
|
295 |
if contraindications:
|
296 |
allergy_mentions_ci = search_text_list(contraindications, allergies_lower)
|
297 |
+
if allergy_mentions_ci: warnings.append(f"ALLERGY RISK (Contraindication Found): Label for '{potential_prescription}' mentions contraindication potentially related to patient allergies: {'; '.join(allergy_mentions_ci)}")
|
|
|
|
|
298 |
if warnings_section:
|
299 |
allergy_mentions_warn = search_text_list(warnings_section, allergies_lower)
|
300 |
+
if allergy_mentions_warn: warnings.append(f"ALLERGY RISK (Warning Found): Label for '{potential_prescription}' mentions warnings potentially related to patient allergies: {'; '.join(allergy_mentions_warn)}")
|
|
|
301 |
|
|
|
302 |
print(" Step 3: Performing Drug-Drug Interaction Check...")
|
303 |
+
if potential_rxcui or potential_label:
|
304 |
for current_med_name in current_med_names_lower:
|
305 |
+
if not current_med_name or current_med_name == potential_med_lower: continue
|
|
|
306 |
print(f" Checking interaction between '{potential_prescription}' and '{current_med_name}'...")
|
307 |
+
current_rxcui = get_rxcui(current_med_name); current_label = get_openfda_label(rxcui=current_rxcui, drug_name=current_med_name)
|
308 |
+
search_terms_for_current = [current_med_name];
|
309 |
+
if current_rxcui: search_terms_for_current.append(current_rxcui)
|
310 |
+
search_terms_for_potential = [potential_med_lower];
|
311 |
+
if potential_rxcui: search_terms_for_potential.append(potential_rxcui)
|
|
|
|
|
|
|
|
|
|
|
312 |
interaction_found_flag = False
|
|
|
313 |
if potential_label and potential_label.get("drug_interactions"):
|
314 |
interaction_mentions = search_text_list(potential_label.get("drug_interactions"), search_terms_for_current)
|
315 |
+
if interaction_mentions: warnings.append(f"Potential Interaction ({potential_prescription.capitalize()} Label): Mentions '{current_med_name.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}"); interaction_found_flag = True
|
316 |
+
if current_label and current_label.get("drug_interactions") and not interaction_found_flag:
|
|
|
|
|
|
|
|
|
317 |
interaction_mentions = search_text_list(current_label.get("drug_interactions"), search_terms_for_potential)
|
318 |
+
if interaction_mentions: warnings.append(f"Potential Interaction ({current_med_name.capitalize()} Label): Mentions '{potential_prescription.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")
|
319 |
+
else: warnings.append(f"INFO: Drug-drug interaction check skipped for '{potential_prescription}' as it could not be identified via RxNorm/OpenFDA.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
+
final_warnings = list(set(warnings)); status = "warning" if any("CRITICAL" in w or "Interaction" in w or "RISK" in w for w in final_warnings) else "clear"
|
322 |
+
if not final_warnings: status = "clear"
|
323 |
message = f"Interaction/Allergy check for '{potential_prescription}': {len(final_warnings)} potential issue(s) identified using RxNorm/OpenFDA." if final_warnings else f"No major interactions or allergy issues identified for '{potential_prescription}' based on RxNorm/OpenFDA lookup."
|
324 |
print(f"--- Interaction Check Complete for '{potential_prescription}' ---")
|
|
|
325 |
return json.dumps({"status": status, "message": message, "warnings": final_warnings})
|
|
|
326 |
|
327 |
@tool("flag_risk", args_schema=FlagRiskInput)
|
328 |
def flag_risk(risk_description: str, urgency: str) -> str:
|
|
|
332 |
return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})
|
333 |
|
334 |
# Initialize Search Tool
|
335 |
+
search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS, name="tavily_search_results")
|
|
|
|
|
|
|
336 |
|
337 |
# --- LangGraph Setup ---
|
|
|
|
|
338 |
class AgentState(TypedDict):
|
339 |
+
messages: Annotated[list[Any], operator.add]; patient_data: Optional[dict]
|
340 |
+
tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
tool_executor = ToolExecutor(tools)
|
342 |
+
model = ChatGroq(temperature=ClinicalAppSettings.TEMPERATURE, model=ClinicalAppSettings.MODEL_NAME)
|
|
|
|
|
|
|
|
|
|
|
343 |
model_with_tools = model.bind_tools(tools)
|
344 |
|
345 |
+
# --- Graph Nodes ---
|
|
|
|
|
346 |
def agent_node(state: AgentState):
|
|
|
347 |
print("\n---AGENT NODE---")
|
348 |
current_messages = state['messages']
|
349 |
if not current_messages or not isinstance(current_messages[0], SystemMessage):
|
350 |
+
print("Prepending System Prompt."); current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
|
|
|
351 |
print(f"Invoking LLM with {len(current_messages)} messages.")
|
352 |
try:
|
353 |
response = model_with_tools.invoke(current_messages)
|
|
|
355 |
if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}")
|
356 |
else: print("Agent Response: No tool calls.")
|
357 |
except Exception as e:
|
358 |
+
print(f"ERROR in agent_node during LLM invocation: {type(e).__name__} - {e}"); traceback.print_exc()
|
|
|
359 |
error_message = AIMessage(content=f"Sorry, an internal error occurred while processing the request: {type(e).__name__}")
|
360 |
return {"messages": [error_message]}
|
361 |
return {"messages": [response]}
|
362 |
|
|
|
363 |
def tool_node(state: AgentState):
|
|
|
364 |
print("\n---TOOL NODE---")
|
365 |
+
tool_messages = []; last_message = state['messages'][-1]
|
|
|
|
|
366 |
if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None):
|
367 |
+
print("Warning: Tool node called unexpectedly without tool calls."); return {"messages": []}
|
368 |
+
tool_calls = last_message.tool_calls; print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}")
|
369 |
+
prescriptions_requested = {}; interaction_checks_requested = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
for call in tool_calls:
|
371 |
tool_name = call.get('name'); tool_args = call.get('args', {})
|
372 |
+
if tool_name == 'prescribe_medication': med_name = tool_args.get('medication_name', '').lower();
|
373 |
+
if med_name: prescriptions_requested[med_name] = call
|
374 |
+
elif tool_name == 'check_drug_interactions': potential_med = tool_args.get('potential_prescription', '').lower()
|
375 |
+
if potential_med: interaction_checks_requested[potential_med] = call
|
376 |
+
valid_tool_calls_for_execution = []; blocked_ids = set()
|
|
|
|
|
|
|
|
|
377 |
for med_name, prescribe_call in prescriptions_requested.items():
|
378 |
if med_name not in interaction_checks_requested:
|
379 |
st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked.")
|
380 |
error_msg = ToolMessage(content=json.dumps({"status": "error", "message": f"Interaction check for '{med_name}' must be requested *before or alongside* the prescription call."}), tool_call_id=prescribe_call['id'], name=prescribe_call['name'])
|
381 |
+
tool_messages.append(error_msg); blocked_ids.add(prescribe_call['id'])
|
|
|
|
|
382 |
valid_tool_calls_for_execution = [call for call in tool_calls if call['id'] not in blocked_ids]
|
383 |
+
patient_data = state.get("patient_data", {}); patient_meds_full = patient_data.get("medications", {}).get("current", []); patient_allergies = patient_data.get("allergies", [])
|
|
|
|
|
|
|
|
|
|
|
384 |
for call in valid_tool_calls_for_execution:
|
385 |
if call['name'] == 'check_drug_interactions':
|
386 |
if 'args' not in call: call['args'] = {}
|
387 |
+
call['args']['current_medications'] = patient_meds_full; call['args']['allergies'] = patient_allergies; print(f"Augmented interaction check args for call ID {call['id']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
if valid_tool_calls_for_execution:
|
389 |
print(f"Attempting to execute {len(valid_tool_calls_for_execution)} tools: {[c['name'] for c in valid_tool_calls_for_execution]}")
|
390 |
try:
|
|
|
392 |
for call, resp in zip(valid_tool_calls_for_execution, responses):
|
393 |
tool_call_id = call['id']; tool_name = call['name']
|
394 |
if isinstance(resp, Exception):
|
395 |
+
error_type = type(resp).__name__; error_str = str(resp); print(f"ERROR executing tool '{tool_name}' (ID: {tool_call_id}): {error_type} - {error_str}"); traceback.print_exc()
|
396 |
+
st.error(f"Error executing action '{tool_name}': {error_type}"); error_content = json.dumps({"status": "error", "message": f"Failed to execute '{tool_name}': {error_type} - {error_str}"})
|
|
|
|
|
|
|
397 |
tool_messages.append(ToolMessage(content=error_content, tool_call_id=tool_call_id, name=tool_name))
|
398 |
+
if isinstance(resp, AttributeError) and "'dict' object has no attribute 'tool'" in error_str: print("\n *** DETECTED SPECIFIC ATTRIBUTE ERROR ('dict' object has no attribute 'tool') *** \n")
|
|
|
399 |
else:
|
400 |
+
print(f"Tool '{tool_name}' (ID: {tool_call_id}) executed successfully."); content_str = str(resp); tool_messages.append(ToolMessage(content=content_str, tool_call_id=tool_call_id, name=tool_name))
|
|
|
|
|
401 |
except Exception as e:
|
402 |
+
print(f"CRITICAL UNEXPECTED ERROR within tool_node logic: {type(e).__name__} - {e}"); traceback.print_exc(); st.error(f"Critical internal error processing actions: {e}")
|
403 |
+
error_content = json.dumps({"status": "error", "message": f"Internal error processing tools: {e}"}); processed_ids = {msg.tool_call_id for msg in tool_messages}
|
|
|
|
|
404 |
for call in valid_tool_calls_for_execution:
|
405 |
if call['id'] not in processed_ids: tool_messages.append(ToolMessage(content=error_content, tool_call_id=call['id'], name=call['name']))
|
406 |
+
print(f"Returning {len(tool_messages)} tool messages."); return {"messages": tool_messages}
|
407 |
|
408 |
+
# --- Graph Edges (Routing Logic) ---
|
|
|
|
|
|
|
|
|
409 |
def should_continue(state: AgentState) -> str:
|
410 |
+
print("\n---ROUTING DECISION---"); last_message = state['messages'][-1] if state['messages'] else None
|
|
|
|
|
411 |
if not isinstance(last_message, AIMessage): return "end_conversation_turn"
|
412 |
if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn"
|
413 |
if getattr(last_message, 'tool_calls', None): return "continue_tools"
|
414 |
else: return "end_conversation_turn"
|
415 |
|
416 |
+
# --- Graph Definition & Compilation ---
|
417 |
+
workflow = StateGraph(AgentState); workflow.add_node("agent", agent_node); workflow.add_node("tools", tool_node)
|
418 |
+
workflow.set_entry_point("agent"); workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
|
419 |
+
workflow.add_edge("tools", "agent"); app = workflow.compile(); print("LangGraph compiled successfully.")
|
|
|
|
|
|
|
|
|
|
|
420 |
|
421 |
# --- Streamlit UI ---
|
422 |
def main():
|
423 |
st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
|
424 |
st.title(f"π©Ί {ClinicalAppSettings.APP_TITLE}")
|
425 |
st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME}")
|
|
|
|
|
426 |
if "messages" not in st.session_state: st.session_state.messages = []
|
427 |
if "patient_data" not in st.session_state: st.session_state.patient_data = None
|
428 |
if "graph_app" not in st.session_state: st.session_state.graph_app = app
|
429 |
|
430 |
+
# --- Patient Data Input Sidebar ---
|
431 |
with st.sidebar:
|
432 |
st.header("π Patient Intake Form")
|
433 |
+
# Input fields (Demographics, HPI, History, Meds/Allergies, Social/Family, Vitals/Exam)
|
434 |
+
st.subheader("Demographics"); age = st.number_input("Age", 0, 120, 55); sex = st.selectbox("Sex", ["Male", "Female", "Other"])
|
435 |
+
st.subheader("HPI"); chief_complaint = st.text_input("Chief Complaint", "Chest pain"); hpi_details = st.text_area("HPI Details", "55 y/o male...", height=150); symptoms = st.multiselect("Symptoms", ["Nausea", "Diaphoresis", "SOB", "Dizziness"], default=["Nausea", "Diaphoresis"])
|
436 |
+
st.subheader("History"); pmh = st.text_area("PMH", "HTN, HLD, DM2, MI"); psh = st.text_area("PSH", "Appendectomy")
|
437 |
+
st.subheader("Meds & Allergies"); current_meds_str = st.text_area("Current Meds", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily"); allergies_str = st.text_area("Allergies", "Penicillin (rash), Sulfa")
|
438 |
+
st.subheader("Social/Family"); social_history = st.text_area("SH", "Smoker"); family_history = st.text_area("FHx", "Father MI")
|
439 |
+
st.subheader("Vitals & Exam"); col1, col2 = st.columns(2);
|
440 |
+
with col1: temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f"); hr_bpm = st.number_input("HR", 30, 250, 95); rr_rpm = st.number_input("RR", 5, 50, 18)
|
441 |
+
with col2: bp_mmhg = st.text_input("BP", "155/90"); spo2_percent = st.number_input("SpO2", 70, 100, 96); pain_scale = st.slider("Pain", 0, 10, 8)
|
442 |
+
exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=100)
|
443 |
+
|
444 |
+
if st.button("Start/Update Consultation"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
|
446 |
+
current_med_names_only = [];
|
447 |
+
for med in current_meds_list: match = re.match(r"^\s*([a-zA-Z\-]+)", med);
|
448 |
+
if match: current_med_names_only.append(match.group(1).lower())
|
|
|
|
|
|
|
|
|
449 |
allergies_list = []
|
450 |
for a in allergies_str.split(','):
|
451 |
+
cleaned_allergy = a.strip();
|
452 |
+
if cleaned_allergy: match = re.match(r"^\s*([a-zA-Z\-\s/]+)(?:\s*\(.*\))?", cleaned_allergy); name_part = match.group(1).strip().lower() if match else cleaned_allergy.lower(); allergies_list.append(name_part)
|
453 |
+
st.session_state.patient_data = { "demographics": {"age": age, "sex": sex}, "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms}, "pmh": {"conditions": pmh}, "psh": {"procedures": psh}, "medications": {"current": current_meds_list, "names_only": current_med_names_only}, "allergies": allergies_list, "social_history": {"details": social_history}, "family_history": {"details": family_history}, "vitals": { "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale}, "exam_findings": {"notes": exam_notes} }
|
454 |
+
red_flags = check_red_flags(st.session_state.patient_data); st.sidebar.markdown("---")
|
455 |
+
if red_flags: st.sidebar.warning("**Initial Red Flags:**"); [st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}") for flag in red_flags]
|
456 |
+
else: st.sidebar.success("No immediate red flags.")
|
457 |
+
initial_prompt = "Initiate consultation for the patient described in the intake form. Review data and begin analysis."
|
458 |
+
st.session_state.messages = [HumanMessage(content=initial_prompt)]; st.success("Patient data loaded/updated.")
|
459 |
+
|
460 |
+
# --- Main Chat Interface Area ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
461 |
st.header("π¬ Clinical Consultation")
|
462 |
+
# Display loop - REMOVED key= ARGUMENT
|
463 |
+
for msg in st.session_state.messages:
|
|
|
|
|
|
|
464 |
if isinstance(msg, HumanMessage):
|
465 |
+
with st.chat_message("user"): st.markdown(msg.content) # No key
|
466 |
elif isinstance(msg, AIMessage):
|
467 |
+
with st.chat_message("assistant"): # No key
|
468 |
ai_content = msg.content; structured_output = None
|
469 |
try:
|
470 |
json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
|
|
|
477 |
structured_output = json.loads(ai_content); ai_content = ""
|
478 |
else: st.markdown(ai_content)
|
479 |
except Exception as e: st.markdown(ai_content); print(f"Error parsing/displaying AI JSON: {e}")
|
|
|
480 |
if structured_output and isinstance(structured_output, dict):
|
481 |
+
st.divider(); st.subheader("π AI Analysis & Recommendations") # Display logic for JSON...
|
|
|
482 |
cols = st.columns(2)
|
483 |
+
with cols[0]: # Assessment, DDx, Risk
|
484 |
+
st.markdown("**Assessment:**"); st.markdown(f"> {structured_output.get('assessment', 'N/A')}")
|
485 |
+
st.markdown("**Differential Diagnosis:**"); ddx = structured_output.get('differential_diagnosis', []);
|
486 |
+
if ddx: [st.expander(f"{'π₯π₯π₯'[('High','Medium','Low').index(item.get('likelihood','Low')[0])] if item.get('likelihood','?')[0] in 'HML' else '?'} {item.get('diagnosis', 'Unknown')} ({item.get('likelihood','?')})").write(f"**Rationale:** {item.get('rationale', 'N/A')}") for item in ddx]
|
|
|
|
|
|
|
|
|
487 |
else: st.info("No DDx provided.")
|
488 |
+
st.markdown("**Risk Assessment:**"); risk = structured_output.get('risk_assessment', {}); flags=risk.get('identified_red_flags',[]); concerns=risk.get("immediate_concerns",[]); comps=risk.get("potential_complications",[])
|
|
|
489 |
if flags: st.warning(f"**Flags:** {', '.join(flags)}")
|
490 |
if concerns: st.warning(f"**Concerns:** {', '.join(concerns)}")
|
491 |
if comps: st.info(f"**Potential Complications:** {', '.join(comps)}")
|
492 |
if not flags and not concerns: st.success("No major risks highlighted.")
|
493 |
+
with cols[1]: # Plan
|
494 |
+
st.markdown("**Recommended Plan:**"); plan = structured_output.get('recommended_plan', {})
|
495 |
+
for section in ["investigations","therapeutics","consultations","patient_education"]: st.markdown(f"_{section.replace('_',' ').capitalize()}:_"); items = plan.get(section); [st.markdown(f"- {item}") for item in items] if items and isinstance(items, list) else (st.markdown(f"- {items}") if items else st.markdown("_None_")); st.markdown("")
|
496 |
+
st.markdown("**Rationale & Guideline Check:**"); st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
|
|
|
|
|
|
|
|
|
497 |
interaction_summary = structured_output.get("interaction_check_summary", "")
|
498 |
+
if interaction_summary: st.markdown("**Interaction Check Summary:**"); st.markdown(f"> {interaction_summary}")
|
499 |
st.divider()
|
|
|
500 |
if getattr(msg, 'tool_calls', None):
|
501 |
+
with st.expander("π οΈ AI requested actions", expanded=False): # Tool call display logic...
|
502 |
+
for tc in msg.tool_calls: try: st.code(f"Action: {tc.get('name', 'Unknown')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
|
503 |
+
except Exception as display_e: st.error(f"Could not display tool call: {display_e}"); st.code(str(tc))
|
|
|
|
|
504 |
elif isinstance(msg, ToolMessage):
|
505 |
tool_name_display = getattr(msg, 'name', 'tool_execution')
|
506 |
+
with st.chat_message(tool_name_display, avatar="π οΈ"): # No key
|
507 |
+
try: # Tool message display logic...
|
508 |
+
tool_data = json.loads(msg.content); status = tool_data.get("status", "info"); message = tool_data.get("message", msg.content); details = tool_data.get("details"); warnings = tool_data.get("warnings")
|
|
|
|
|
509 |
if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="β
" if status != "flagged" else "π¨")
|
510 |
+
elif status == "warning": st.warning(f"{message}", icon="β οΈ");
|
511 |
+
if warnings and isinstance(warnings, list): st.caption("Details:"); [st.caption(f"- {warn}") for warn in warnings]
|
|
|
|
|
|
|
512 |
else: st.error(f"{message}", icon="β")
|
513 |
if details: st.caption(f"Details: {details}")
|
514 |
+
except json.JSONDecodeError: st.info(f"{msg.content}")
|
515 |
except Exception as e: st.error(f"Error displaying tool message: {e}", icon="β"); st.caption(f"Raw content: {msg.content}")
|
516 |
|
517 |
+
# --- Chat Input Logic ---
|
|
|
518 |
if prompt := st.chat_input("Your message or follow-up query..."):
|
519 |
+
if not st.session_state.patient_data: st.warning("Please load patient data first."); st.stop()
|
520 |
+
user_message = HumanMessage(content=prompt); st.session_state.messages.append(user_message)
|
521 |
+
with st.chat_message("user"): st.markdown(prompt) # Display user msg immediately
|
|
|
|
|
|
|
|
|
522 |
current_state = AgentState(messages=st.session_state.messages, patient_data=st.session_state.patient_data)
|
|
|
523 |
with st.spinner("SynapseAI is thinking..."):
|
524 |
try:
|
525 |
final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
|
526 |
+
st.session_state.messages = final_state['messages'] # Update state with results
|
527 |
+
except Exception as e: print(f"CRITICAL ERROR: {e}"); traceback.print_exc(); st.error(f"Error: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
528 |
st.rerun() # Refresh display
|
529 |
|
530 |
+
# Disclaimer
|
531 |
+
st.markdown("---"); st.warning("**Disclaimer:** SynapseAI is for demonstration...")
|
|
|
532 |
|
533 |
if __name__ == "__main__":
|
534 |
main()
|