mgbam commited on
Commit
b731976
Β·
verified Β·
1 Parent(s): 23d48f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -540
app.py CHANGED
@@ -1,30 +1,26 @@
1
- # -*- coding: utf-8 -*-
2
  import streamlit as st
3
- import requests
4
  import json
5
  import re
6
  import os
7
- import operator
8
  import traceback
9
- from functools import lru_cache
10
  from dotenv import load_dotenv
11
 
12
- from langchain_groq import ChatGroq
13
- from langchain_community.tools.tavily_search import TavilySearchResults
14
- from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
15
- # from langchain_core.prompts import ChatPromptTemplate # Not explicitly used
16
- from langchain_core.pydantic_v1 import BaseModel, Field
17
- from langchain_core.tools import tool
18
- from langgraph.prebuilt import ToolExecutor
19
- from langgraph.graph import StateGraph, END
20
-
21
- from typing import Optional, List, Dict, Any, TypedDict, Annotated
22
 
23
  # --- Environment Variable Loading & Validation ---
24
  load_dotenv()
 
25
  UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
26
  GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
27
  TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
 
28
  missing_keys = []
29
  if not UMLS_API_KEY:
30
  missing_keys.append("UMLS_API_KEY")
@@ -32,547 +28,97 @@ if not GROQ_API_KEY:
32
  missing_keys.append("GROQ_API_KEY")
33
  if not TAVILY_API_KEY:
34
  missing_keys.append("TAVILY_API_KEY")
 
35
  if missing_keys:
36
- st.error(f"Missing API Key(s): {', '.join(missing_keys)}.")
37
  st.stop()
38
 
39
- # --- Configuration & Constants ---
40
  class ClinicalAppSettings:
41
  APP_TITLE = "SynapseAI (UMLS/FDA Integrated)"
42
  PAGE_LAYOUT = "wide"
43
- MODEL_NAME = "llama3-70b-8192"
44
- TEMPERATURE = 0.1
45
- MAX_SEARCH_RESULTS = 3
46
-
47
- class ClinicalPrompts:
48
- SYSTEM_PROMPT = """
49
- You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation... [SYSTEM PROMPT REMAINS THE SAME - OMITTED FOR BREVITY]
50
- """
51
-
52
- # --- API Helper Functions (get_rxcui, get_openfda_label, search_text_list) ---
53
- UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
54
- RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
55
- OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
56
-
57
- @lru_cache(maxsize=256)
58
- def get_rxcui(drug_name: str) -> Optional[str]:
59
- if not drug_name or not isinstance(drug_name, str):
60
- return None
61
- drug_name = drug_name.strip()
62
- if not drug_name:
63
- return None
64
- print(f"RxNorm Lookup for: '{drug_name}'")
65
- try:
66
- params = {"name": drug_name, "search": 1}
67
- response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
68
- response.raise_for_status()
69
- data = response.json()
70
- if data and "idGroup" in data and "rxnormId" in data["idGroup"]:
71
- rxcui = data["idGroup"]["rxnormId"][0]
72
- print(f" Found RxCUI: {rxcui} for '{drug_name}'")
73
- return rxcui
74
- else:
75
- params = {"name": drug_name}
76
- response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
77
- response.raise_for_status()
78
- data = response.json()
79
- if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
80
- for group in data["drugGroup"]["conceptGroup"]:
81
- if group.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
82
- if "conceptProperties" in group and group["conceptProperties"]:
83
- rxcui = group["conceptProperties"][0].get("rxcui")
84
- if rxcui:
85
- print(f" Found RxCUI (via /drugs): {rxcui} for '{drug_name}'")
86
- return rxcui
87
- print(f" RxCUI not found for '{drug_name}'.")
88
- return None
89
- except requests.exceptions.RequestException as e:
90
- print(f" Error fetching RxCUI for '{drug_name}': {e}")
91
- return None
92
- except json.JSONDecodeError as e:
93
- print(f" Error decoding RxNorm JSON response for '{drug_name}': {e}")
94
- return None
95
- except Exception as e:
96
- print(f" Unexpected error in get_rxcui for '{drug_name}': {e}")
97
- return None
98
-
99
- @lru_cache(maxsize=128)
100
- def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
101
- if not rxcui and not drug_name:
102
- return None
103
- print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}")
104
- search_terms = []
105
- if rxcui:
106
- search_terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
107
- if drug_name:
108
- search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')
109
- search_query = " OR ".join(search_terms)
110
- params = {"search": search_query, "limit": 1}
111
- try:
112
- response = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
113
- response.raise_for_status()
114
- data = response.json()
115
- if data and "results" in data and data["results"]:
116
- print(f" Found OpenFDA label for query: {search_query}")
117
- return data["results"][0]
118
- print(f" No OpenFDA label found for query: {search_query}")
119
- return None
120
- except requests.exceptions.RequestException as e:
121
- print(f" Error fetching OpenFDA label: {e}")
122
- return None
123
- except json.JSONDecodeError as e:
124
- print(f" Error decoding OpenFDA JSON response: {e}")
125
- return None
126
- except Exception as e:
127
- print(f" Unexpected error in get_openfda_label: {e}")
128
- return None
129
-
130
- def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
131
- found_snippets = []
132
- if not text_list or not search_terms:
133
- return found_snippets
134
- search_terms_lower = [str(term).lower() for term in search_terms if term]
135
- for text_item in text_list:
136
- if not isinstance(text_item, str):
137
- continue
138
- text_item_lower = text_item.lower()
139
- for term in search_terms_lower:
140
- if term in text_item_lower:
141
- start_index = text_item_lower.find(term)
142
- snippet_start = max(0, start_index - 50)
143
- snippet_end = min(len(text_item), start_index + len(term) + 100)
144
- snippet = text_item[snippet_start:snippet_end]
145
- snippet = snippet.replace(term, f"**{term}**", 1)
146
- found_snippets.append(f"...{snippet}...")
147
- break
148
- return found_snippets
149
-
150
- # --- Other Helper Functions (parse_bp, check_red_flags, format_patient_data_for_prompt) ---
151
- def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
152
- if not isinstance(bp_string, str):
153
- return None
154
- match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
155
- if match:
156
- return int(match.group(1)), int(match.group(2))
157
- return None
158
-
159
- def check_red_flags(patient_data: dict) -> List[str]:
160
- flags = []
161
- if not patient_data:
162
- return flags
163
- symptoms = patient_data.get("hpi", {}).get("symptoms", [])
164
- vitals = patient_data.get("vitals", {})
165
- history = patient_data.get("pmh", {}).get("conditions", "")
166
- symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)]
167
-
168
- if "chest pain" in symptoms_lower:
169
- flags.append("Red Flag: Chest Pain reported.")
170
- if "shortness of breath" in symptoms_lower:
171
- flags.append("Red Flag: Shortness of Breath reported.")
172
- if "severe headache" in symptoms_lower:
173
- flags.append("Red Flag: Severe Headache reported.")
174
- if "sudden vision loss" in symptoms_lower:
175
- flags.append("Red Flag: Sudden Vision Loss reported.")
176
- if "weakness on one side" in symptoms_lower:
177
- flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
178
- if "hemoptysis" in symptoms_lower:
179
- flags.append("Red Flag: Hemoptysis (coughing up blood).")
180
- if "syncope" in symptoms_lower:
181
- flags.append("Red Flag: Syncope (fainting).")
182
-
183
- if vitals:
184
- temp = vitals.get("temp_c")
185
- hr = vitals.get("hr_bpm")
186
- rr = vitals.get("rr_rpm")
187
- spo2 = vitals.get("spo2_percent")
188
- bp_str = vitals.get("bp_mmhg")
189
-
190
- if temp is not None and temp >= 38.5:
191
- flags.append(f"Red Flag: Fever ({temp}Β°C).")
192
- if hr is not None and hr >= 120:
193
- flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
194
- if hr is not None and hr <= 50:
195
- flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
196
- if rr is not None and rr >= 24:
197
- flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
198
- if spo2 is not None and spo2 <= 92:
199
- flags.append(f"Red Flag: Hypoxia ({spo2}%).")
200
-
201
- if bp_str:
202
- bp = parse_bp(bp_str)
203
- if bp:
204
- if bp[0] >= 180 or bp[1] >= 110:
205
- flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
206
- if bp[0] <= 90 or bp[1] <= 60:
207
- flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")
208
-
209
- if history and isinstance(history, str):
210
- history_lower = history.lower()
211
- if "history of mi" in history_lower and "chest pain" in symptoms_lower:
212
- flags.append("Red Flag: History of MI with current Chest Pain.")
213
- if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower:
214
- flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")
215
-
216
- return list(set(flags))
217
-
218
- def format_patient_data_for_prompt(data: dict) -> str:
219
- if not data:
220
- return "No patient data provided."
221
- prompt_str = ""
222
- for key, value in data.items():
223
- section_title = key.replace('_', ' ').title()
224
- if isinstance(value, dict) and value:
225
- has_content = any(sub_value for sub_value in value.values())
226
- if has_content:
227
- prompt_str += f"**{section_title}:**\n"
228
- for sub_key, sub_value in value.items():
229
- if sub_value:
230
- prompt_str += f" - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
231
- elif isinstance(value, list) and value:
232
- prompt_str += f"**{section_title}:** {', '.join(map(str, value))}\n"
233
- elif value and not isinstance(value, dict):
234
- prompt_str += f"**{section_title}:** {value}\n"
235
- return prompt_str.strip()
236
-
237
- # --- Tool Definitions ---
238
- class LabOrderInput(BaseModel):
239
- test_name: str = Field(...)
240
- reason: str = Field(...)
241
- priority: str = Field("Routine")
242
-
243
- class PrescriptionInput(BaseModel):
244
- medication_name: str = Field(...)
245
- dosage: str = Field(...)
246
- route: str = Field(...)
247
- frequency: str = Field(...)
248
- duration: str = Field("As directed")
249
- reason: str = Field(...)
250
-
251
- class InteractionCheckInput(BaseModel):
252
- potential_prescription: str = Field(...)
253
- current_medications: Optional[List[str]] = Field(None)
254
- allergies: Optional[List[str]] = Field(None)
255
-
256
- class FlagRiskInput(BaseModel):
257
- risk_description: str = Field(...)
258
- urgency: str = Field("High")
259
-
260
- @tool("order_lab_test", args_schema=LabOrderInput)
261
- def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
262
- """
263
- Orders a lab test with the specified test name, reason, and priority.
264
- """
265
- print(f"Executing order_lab_test: {test_name}, Reason: {reason}, Priority: {priority}")
266
- return json.dumps({
267
- "status": "success",
268
- "message": f"Lab Ordered: {test_name} ({priority})",
269
- "details": f"Reason: {reason}"
270
- })
271
-
272
- @tool("prescribe_medication", args_schema=PrescriptionInput)
273
- def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
274
- """
275
- Prepares a prescription for the specified medication including dosage, route, frequency, and duration.
276
- """
277
- print(f"Executing prescribe_medication: {medication_name} {dosage}...")
278
- return json.dumps({
279
- "status": "success",
280
- "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
281
- "details": f"Duration: {duration}. Reason: {reason}"
282
- })
283
-
284
- @tool("check_drug_interactions", args_schema=InteractionCheckInput)
285
- def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
286
- """
287
- Checks for potential drug interactions and allergy risks for the given prescription.
288
- """
289
- print(f"\n--- Executing REAL check_drug_interactions ---")
290
- print(f"Checking potential prescription: '{potential_prescription}'")
291
- warnings = []
292
- potential_med_lower = potential_prescription.lower().strip()
293
- current_meds_list = current_medications or []
294
- allergies_list = allergies or []
295
- current_med_names_lower = []
296
- for med in current_meds_list:
297
- match = re.match(r"^\s*([a-zA-Z\-]+)", str(med))
298
- if match:
299
- current_med_names_lower.append(match.group(1).lower())
300
- allergies_lower = [str(a).lower().strip() for a in allergies_list if a]
301
- print(f" Against Current Meds (names): {current_med_names_lower}")
302
- print(f" Against Allergies: {allergies_lower}")
303
- print(f" Step 1: Normalizing '{potential_prescription}'...")
304
- potential_rxcui = get_rxcui(potential_prescription)
305
- potential_label = get_openfda_label(rxcui=potential_rxcui, drug_name=potential_prescription)
306
- if not potential_rxcui and not potential_label:
307
- warnings.append(f"INFO: Could not reliably identify '{potential_prescription}'. Checks may be incomplete.")
308
- print(" Step 2: Performing Allergy Check...")
309
- for allergy in allergies_lower:
310
- if allergy == potential_med_lower:
311
- warnings.append(f"CRITICAL ALLERGY (Name Match): Patient allergic to '{allergy}'. Potential prescription is '{potential_prescription}'.")
312
- elif allergy in ["penicillin", "pcns"] and potential_med_lower in ["amoxicillin", "ampicillin", "augmentin", "piperacillin"]:
313
- warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Penicillin. High risk with '{potential_prescription}'.")
314
- elif allergy == "sulfa" and potential_med_lower in ["sulfamethoxazole", "bactrim", "sulfasalazine"]:
315
- warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Sulfa. High risk with '{potential_prescription}'.")
316
- elif allergy in ["nsaids", "aspirin"] and potential_med_lower in ["ibuprofen", "naproxen", "ketorolac", "diclofenac"]:
317
- warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to NSAIDs/Aspirin. Risk with '{potential_prescription}'.")
318
- if potential_label:
319
- contraindications = potential_label.get("contraindications")
320
- warnings_section = potential_label.get("warnings_and_cautions") or potential_label.get("warnings")
321
- if contraindications:
322
- allergy_mentions_ci = search_text_list(contraindications, allergies_lower)
323
- if allergy_mentions_ci:
324
- warnings.append(f"ALLERGY RISK (Contraindication Found): Label for '{potential_prescription}' mentions contraindication potentially related to patient allergies: {'; '.join(allergy_mentions_ci)}")
325
- if warnings_section:
326
- allergy_mentions_warn = search_text_list(warnings_section, allergies_lower)
327
- if allergy_mentions_warn:
328
- warnings.append(f"ALLERGY RISK (Warning Found): Label for '{potential_prescription}' mentions warnings potentially related to patient allergies: {'; '.join(allergy_mentions_warn)}")
329
- print(" Step 3: Performing Drug-Drug Interaction Check...")
330
- if potential_rxcui or potential_label:
331
- for current_med_name in current_med_names_lower:
332
- if not current_med_name or current_med_name == potential_med_lower:
333
- continue
334
- print(f" Checking interaction between '{potential_prescription}' and '{current_med_name}'...")
335
- current_rxcui = get_rxcui(current_med_name)
336
- current_label = get_openfda_label(rxcui=current_rxcui, drug_name=current_med_name)
337
- search_terms_for_current = [current_med_name]
338
- if current_rxcui:
339
- search_terms_for_current.append(current_rxcui)
340
- search_terms_for_potential = [potential_med_lower]
341
- if potential_rxcui:
342
- search_terms_for_potential.append(potential_rxcui)
343
- interaction_found_flag = False
344
- if potential_label and potential_label.get("drug_interactions"):
345
- interaction_mentions = search_text_list(potential_label.get("drug_interactions"), search_terms_for_current)
346
- if interaction_mentions:
347
- warnings.append(f"Potential Interaction ({potential_prescription.capitalize()} Label): Mentions '{current_med_name.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")
348
- interaction_found_flag = True
349
- if current_label and current_label.get("drug_interactions") and not interaction_found_flag:
350
- interaction_mentions = search_text_list(current_label.get("drug_interactions"), search_terms_for_potential)
351
- if interaction_mentions:
352
- warnings.append(f"Potential Interaction ({current_med_name.capitalize()} Label): Mentions '{potential_prescription.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")
353
- else:
354
- warnings.append(f"INFO: Drug-drug interaction check skipped for '{potential_prescription}' as it could not be identified via RxNorm/OpenFDA.")
355
- final_warnings = list(set(warnings))
356
- status = "warning" if any("CRITICAL" in w or "Interaction" in w or "RISK" in w for w in final_warnings) else "clear"
357
- if not final_warnings:
358
- status = "clear"
359
- message = f"Interaction/Allergy check for '{potential_prescription}': {len(final_warnings)} potential issue(s) identified using RxNorm/OpenFDA." if final_warnings else f"No major interactions or allergy issues identified for '{potential_prescription}' based on RxNorm/OpenFDA lookup."
360
- print(f"--- Interaction Check Complete for '{potential_prescription}' ---")
361
- return json.dumps({
362
- "status": status,
363
- "message": message,
364
- "warnings": final_warnings
365
- })
366
-
367
- @tool("flag_risk", args_schema=FlagRiskInput)
368
- def flag_risk(risk_description: str, urgency: str) -> str:
369
- """
370
- Flags a clinical risk with the provided description and urgency.
371
- """
372
- print(f"Executing flag_risk: {risk_description}, Urgency: {urgency}")
373
- st.error(f"🚨 **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="🚨")
374
- return json.dumps({
375
- "status": "flagged",
376
- "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
377
- })
378
-
379
- search_tool = TavilySearchResults(max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS, name="tavily_search_results")
380
-
381
- # --- LangGraph Setup ---
382
- class AgentState(TypedDict):
383
- messages: Annotated[list[Any], operator.add]
384
- patient_data: Optional[dict]
385
-
386
- tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
387
- tool_executor = ToolExecutor(tools)
388
- model = ChatGroq(temperature=ClinicalAppSettings.TEMPERATURE, model=ClinicalAppSettings.MODEL_NAME)
389
- model_with_tools = model.bind_tools(tools)
390
-
391
- # --- Graph Nodes (agent_node, tool_node) ---
392
- def agent_node(state: AgentState):
393
- print("\n---AGENT NODE---")
394
- current_messages = state['messages']
395
- if not current_messages or not isinstance(current_messages[0], SystemMessage):
396
- print("Prepending System Prompt.")
397
- current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
398
- print(f"Invoking LLM with {len(current_messages)} messages.")
399
- try:
400
- response = model_with_tools.invoke(current_messages)
401
- print(f"Agent Raw Response Type: {type(response)}")
402
- if hasattr(response, 'tool_calls') and response.tool_calls:
403
- print(f"Agent Response Tool Calls: {response.tool_calls}")
404
- else:
405
- print("Agent Response: No tool calls.")
406
- except Exception as e:
407
- print(f"ERROR in agent_node: {e}")
408
- traceback.print_exc()
409
- error_message = AIMessage(content=f"Error: {e}")
410
- return {"messages": [error_message]}
411
- return {"messages": [response]}
412
-
413
- def tool_node(state: AgentState):
414
- print("\n---TOOL NODE---")
415
- tool_messages = []
416
- last_message = state['messages'][-1]
417
- if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None):
418
- print("Warning: Tool node called unexpectedly.")
419
- return {"messages": []}
420
- tool_calls = last_message.tool_calls
421
- print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}")
422
- prescriptions_requested = {}
423
- interaction_checks_requested = {}
424
- for call in tool_calls:
425
- tool_name = call.get('name')
426
- tool_args = call.get('args', {})
427
- if tool_name == 'prescribe_medication':
428
- med_name = tool_args.get('medication_name', '').lower()
429
- if med_name:
430
- prescriptions_requested[med_name] = call
431
- elif tool_name == 'check_drug_interactions':
432
- potential_med = tool_args.get('potential_prescription', '').lower()
433
- if potential_med:
434
- interaction_checks_requested[potential_med] = call
435
- valid_tool_calls_for_execution = []
436
- blocked_ids = set()
437
- for med_name, prescribe_call in prescriptions_requested.items():
438
- if med_name not in interaction_checks_requested:
439
- st.error(f"**Safety Violation:** AI tried to prescribe '{med_name}' without check.")
440
- error_msg = ToolMessage(
441
- content=json.dumps({
442
- "status": "error",
443
- "message": f"Interaction check needed for '{med_name}'."
444
- }),
445
- tool_call_id=prescribe_call['id'],
446
- name=prescribe_call['name']
447
- )
448
- tool_messages.append(error_msg)
449
- blocked_ids.add(prescribe_call['id'])
450
- valid_tool_calls_for_execution = [call for call in tool_calls if call['id'] not in blocked_ids]
451
- patient_data = state.get("patient_data", {})
452
- patient_meds_full = patient_data.get("medications", {}).get("current", [])
453
- patient_allergies = patient_data.get("allergies", [])
454
- for call in valid_tool_calls_for_execution:
455
- if call['name'] == 'check_drug_interactions':
456
- if 'args' not in call:
457
- call['args'] = {}
458
- call['args']['current_medications'] = patient_meds_full
459
- call['args']['allergies'] = patient_allergies
460
- print(f"Augmented interaction check args for call ID {call['id']}")
461
- if valid_tool_calls_for_execution:
462
- print(f"Attempting execution: {[c['name'] for c in valid_tool_calls_for_execution]}")
463
- try:
464
- responses = tool_executor.batch(valid_tool_calls_for_execution, return_exceptions=True)
465
- for call, resp in zip(valid_tool_calls_for_execution, responses):
466
- tool_call_id = call['id']
467
- tool_name = call['name']
468
- if isinstance(resp, Exception):
469
- error_type = type(resp).__name__
470
- error_str = str(resp)
471
- print(f"ERROR executing tool '{tool_name}': {error_type} - {error_str}")
472
- traceback.print_exc()
473
- st.error(f"Error: {error_type}")
474
- error_content = json.dumps({
475
- "status": "error",
476
- "message": f"Failed: {error_type} - {error_str}"
477
- })
478
- tool_messages.append(ToolMessage(content=error_content, tool_call_id=tool_call_id, name=tool_name))
479
- else:
480
- print(f"Tool '{tool_name}' executed.")
481
- # If response is a dict, convert it to a JSON string.
482
- if isinstance(resp, dict):
483
- content_str = json.dumps(resp)
484
- else:
485
- content_str = str(resp)
486
- tool_messages.append(ToolMessage(content=content_str, tool_call_id=tool_call_id, name=tool_name))
487
- except Exception as e:
488
- print(f"CRITICAL TOOL NODE ERROR: {e}")
489
- traceback.print_exc()
490
- st.error(f"Critical error: {e}")
491
- error_content = json.dumps({
492
- "status": "error",
493
- "message": f"Internal error: {e}"
494
- })
495
- processed_ids = {msg.tool_call_id for msg in tool_messages}
496
- [tool_messages.append(ToolMessage(content=error_content, tool_call_id=call['id'], name=call['name']))
497
- for call in valid_tool_calls_for_execution if call['id'] not in processed_ids]
498
- print(f"Returning {len(tool_messages)} tool messages.")
499
- return {"messages": tool_messages}
500
-
501
- # --- Graph Edges (Routing Logic) ---
502
- def should_continue(state: AgentState) -> str:
503
- print("\n---ROUTING DECISION---")
504
- last_message = state['messages'][-1] if state['messages'] else None
505
- if not isinstance(last_message, AIMessage):
506
- return "end_conversation_turn"
507
- if "Sorry, an internal error occurred" in last_message.content:
508
- return "end_conversation_turn"
509
- if getattr(last_message, 'tool_calls', None):
510
- return "continue_tools"
511
- else:
512
- return "end_conversation_turn"
513
-
514
- # --- Graph Definition & Compilation ---
515
- workflow = StateGraph(AgentState)
516
- workflow.add_node("agent", agent_node)
517
- workflow.add_node("tools", tool_node)
518
- workflow.set_entry_point("agent")
519
- workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
520
- workflow.add_edge("tools", "agent")
521
- app = workflow.compile()
522
- print("LangGraph compiled successfully.")
523
 
524
  # --- Streamlit UI ---
525
  def main():
526
  st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
527
  st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
528
- st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME}")
 
 
529
  if "messages" not in st.session_state:
530
  st.session_state.messages = []
531
  if "patient_data" not in st.session_state:
532
  st.session_state.patient_data = None
533
- if "graph_app" not in st.session_state:
534
- st.session_state.graph_app = app
 
 
 
 
 
 
 
 
 
 
 
535
 
536
  # --- Patient Data Input Sidebar ---
537
  with st.sidebar:
538
  st.header("πŸ“„ Patient Intake Form")
539
  # Input fields... (Using shorter versions for brevity, assume full fields are here)
540
  st.subheader("Demographics")
541
- age = st.number_input("Age", 0, 120, 55)
542
- sex = st.selectbox("Sex", ["Male", "Female", "Other"])
 
543
  st.subheader("HPI")
544
- chief_complaint = st.text_input("Chief Complaint", "Chest pain")
545
- hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100)
546
- symptoms = st.multiselect("Symptoms", ["Nausea", "Diaphoresis", "SOB", "Dizziness"], default=["Nausea", "Diaphoresis"])
 
 
 
 
 
 
547
  st.subheader("History")
548
- pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI")
549
- psh = st.text_area("PSH", "Appendectomy")
 
550
  st.subheader("Meds & Allergies")
551
- current_meds_str = st.text_area("Current Meds", "Lisinopril 10mg daily\nMetformin 1000mg BID")
552
- allergies_str = st.text_area("Allergies", "Penicillin (rash)")
 
 
 
 
 
553
  st.subheader("Social/Family")
554
- social_history = st.text_area("SH", "Smoker")
555
- family_history = st.text_area("FHx", "Father MI")
 
556
  st.subheader("Vitals & Exam")
557
  col1, col2 = st.columns(2)
558
  with col1:
559
- temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f")
560
- hr_bpm = st.number_input("HR", 30, 250, 95)
561
- rr_rpm = st.number_input("RR", 5, 50, 18)
562
  with col2:
563
- bp_mmhg = st.text_input("BP", "155/90")
564
- spo2_percent = st.number_input("SpO2", 70, 100, 96)
565
- pain_scale = st.slider("Pain", 0, 10, 8)
566
- # Updated height to at least 68 pixels as required by Streamlit
567
- exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=68)
568
 
569
- if st.button("Start/Update Consultation"):
 
570
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
571
  current_med_names_only = []
572
  for med in current_meds_list:
573
  match = re.match(r"^\s*([a-zA-Z\-]+)", med)
574
  if match:
575
  current_med_names_only.append(match.group(1).lower())
 
576
  allergies_list = []
577
  for a in allergies_str.split(','):
578
  cleaned_allergy = a.strip()
@@ -580,6 +126,8 @@ def main():
580
  match = re.match(r"^\s*([a-zA-Z\-\s/]+)(?:\s*\(.*\))?", cleaned_allergy)
581
  name_part = match.group(1).strip().lower() if match else cleaned_allergy.lower()
582
  allergies_list.append(name_part)
 
 
583
  st.session_state.patient_data = {
584
  "demographics": {"age": age, "sex": sex},
585
  "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
@@ -599,16 +147,24 @@ def main():
599
  },
600
  "exam_findings": {"notes": exam_notes}
601
  }
 
 
602
  red_flags = check_red_flags(st.session_state.patient_data)
603
  st.sidebar.markdown("---")
604
  if red_flags:
605
  st.sidebar.warning("**Initial Red Flags:**")
606
- [st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}") for flag in red_flags]
 
607
  else:
608
  st.sidebar.success("No immediate red flags.")
 
 
609
  initial_prompt = "Initiate consultation. Review patient data and begin analysis."
610
  st.session_state.messages = [HumanMessage(content=initial_prompt)]
 
611
  st.success("Patient data loaded/updated.")
 
 
612
 
613
  # --- Main Chat Interface Area ---
614
  st.header("πŸ’¬ Clinical Consultation")
@@ -616,12 +172,13 @@ def main():
616
  for msg in st.session_state.messages:
617
  if isinstance(msg, HumanMessage):
618
  with st.chat_message("user"):
619
- st.markdown(msg.content) # No key
620
  elif isinstance(msg, AIMessage):
621
  with st.chat_message("assistant"):
622
  ai_content = msg.content
623
  structured_output = None
624
  try:
 
625
  json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
626
  if json_match:
627
  json_str = json_match.group(1)
@@ -636,11 +193,13 @@ def main():
636
  structured_output = json.loads(ai_content)
637
  ai_content = ""
638
  else:
639
- st.markdown(ai_content)
640
  except Exception as e:
641
  st.markdown(ai_content)
642
  print(f"Error parsing/displaying AI JSON: {e}")
 
643
  if structured_output and isinstance(structured_output, dict):
 
644
  st.divider()
645
  st.subheader("πŸ“Š AI Analysis & Recommendations")
646
  cols = st.columns(2)
@@ -650,7 +209,15 @@ def main():
650
  st.markdown("**Differential Diagnosis:**")
651
  ddx = structured_output.get('differential_diagnosis', [])
652
  if ddx:
653
- [st.expander(f"{'πŸ₯‡πŸ₯ˆπŸ₯‰'[('High','Medium','Low').index(item.get('likelihood','Low')[0])] if item.get('likelihood','?')[0] in 'HML' else '?'} {item.get('diagnosis', 'Unknown')} ({item.get('likelihood','?')})").write(f"**Rationale:** {item.get('rationale', 'N/A')}") for item in ddx]
 
 
 
 
 
 
 
 
654
  else:
655
  st.info("No DDx provided.")
656
  st.markdown("**Risk Assessment:**")
@@ -669,11 +236,12 @@ def main():
669
  with cols[1]:
670
  st.markdown("**Recommended Plan:**")
671
  plan = structured_output.get('recommended_plan', {})
672
- for section in ["investigations","therapeutics","consultations","patient_education"]:
673
- st.markdown(f"_{section.replace('_',' ').capitalize()}:_")
674
  items = plan.get(section)
675
  if items and isinstance(items, list):
676
- [st.markdown(f"- {item}") for item in items]
 
677
  elif items:
678
  st.markdown(f"- {items}")
679
  else:
@@ -687,35 +255,44 @@ def main():
687
  st.markdown(f"> {interaction_summary}")
688
  st.divider()
689
 
 
690
  if getattr(msg, 'tool_calls', None):
691
  with st.expander("πŸ› οΈ AI requested actions", expanded=False):
692
  if msg.tool_calls:
693
  for tc in msg.tool_calls:
694
  try:
695
- st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
 
 
 
696
  except Exception as display_e:
697
- st.error(f"Could not display tool call arguments properly: {display_e}", icon="⚠️")
698
  st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nRaw Args: {tc.get('args')}")
699
  else:
700
- st.caption("_No actions requested in this turn._")
701
  elif isinstance(msg, ToolMessage):
702
  tool_name_display = getattr(msg, 'name', 'tool_execution')
703
  with st.chat_message(tool_name_display, avatar="πŸ› οΈ"):
704
  try:
 
705
  tool_data = json.loads(msg.content)
706
  status = tool_data.get("status", "info")
707
  message = tool_data.get("message", msg.content)
708
  details = tool_data.get("details")
709
  warnings = tool_data.get("warnings")
710
- if status in ["success", "clear", "flagged"]:
711
- st.success(f"{message}", icon="βœ…" if status != "flagged" else "🚨")
 
 
 
712
  elif status == "warning":
713
  st.warning(f"{message}", icon="⚠️")
714
- if warnings and isinstance(warnings, list):
715
- st.caption("Details:")
716
- [st.caption(f"- {warn}") for warn in warnings]
717
  else:
718
  st.error(f"{message}", icon="❌")
 
 
 
 
719
  if details:
720
  st.caption(f"Details: {details}")
721
  except json.JSONDecodeError:
@@ -729,19 +306,42 @@ def main():
729
  if not st.session_state.patient_data:
730
  st.warning("Please load patient data first.")
731
  st.stop()
 
 
 
 
 
732
  user_message = HumanMessage(content=prompt)
733
  st.session_state.messages.append(user_message)
734
  with st.chat_message("user"):
735
  st.markdown(prompt)
736
- current_state = AgentState(messages=st.session_state.messages, patient_data=st.session_state.patient_data)
737
- with st.spinner("SynapseAI is thinking..."):
 
 
 
 
 
 
 
 
 
738
  try:
739
- final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
740
- st.session_state.messages = final_state['messages']
 
 
 
 
 
741
  except Exception as e:
742
- print(f"CRITICAL ERROR: {e}")
743
  traceback.print_exc()
744
- st.error(f"Error: {e}")
 
 
 
 
745
  st.rerun()
746
 
747
  # Disclaimer
 
1
+ # app.py
2
  import streamlit as st
 
3
  import json
4
  import re
5
  import os
 
6
  import traceback
 
7
  from dotenv import load_dotenv
8
 
9
+ # Import agent logic and message types from agent.py
10
+ try:
11
+ from agent import ClinicalAgent, AgentState, check_red_flags
12
+ from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
13
+ except ImportError as e:
14
+ st.error(f"Failed to import from agent.py: {e}. Make sure agent.py is in the same directory.")
15
+ st.stop()
 
 
 
16
 
17
  # --- Environment Variable Loading & Validation ---
18
  load_dotenv()
19
+ # Check keys required by agent.py are present before initializing the agent
20
  UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
21
  GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
22
  TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
23
+
24
  missing_keys = []
25
  if not UMLS_API_KEY:
26
  missing_keys.append("UMLS_API_KEY")
 
28
  missing_keys.append("GROQ_API_KEY")
29
  if not TAVILY_API_KEY:
30
  missing_keys.append("TAVILY_API_KEY")
31
+
32
  if missing_keys:
33
+ st.error(f"Missing required API Key(s): {', '.join(missing_keys)}. Please set them in Hugging Face Space Secrets or environment variables.")
34
  st.stop()
35
 
36
+ # --- App Configuration ---
37
  class ClinicalAppSettings:
38
  APP_TITLE = "SynapseAI (UMLS/FDA Integrated)"
39
  PAGE_LAYOUT = "wide"
40
+ MODEL_NAME_DISPLAY = "Llama3-70b (via Groq)" # Defined in agent.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # --- Streamlit UI ---
43
  def main():
44
  st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
45
  st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
46
+ st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME_DISPLAY}")
47
+
48
+ # Initialize session state
49
  if "messages" not in st.session_state:
50
  st.session_state.messages = []
51
  if "patient_data" not in st.session_state:
52
  st.session_state.patient_data = None
53
+ if "summary" not in st.session_state:
54
+ st.session_state.summary = None
55
+
56
+ # Initialize the agent instance only once
57
+ if "agent" not in st.session_state:
58
+ try:
59
+ st.session_state.agent = ClinicalAgent()
60
+ print("ClinicalAgent successfully initialized in Streamlit session state.")
61
+ except Exception as e:
62
+ st.error(f"Failed to initialize Clinical Agent: {e}. Check API keys and dependencies.")
63
+ print(f"ERROR Initializing ClinicalAgent: {e}")
64
+ traceback.print_exc()
65
+ st.stop()
66
 
67
  # --- Patient Data Input Sidebar ---
68
  with st.sidebar:
69
  st.header("πŸ“„ Patient Intake Form")
70
  # Input fields... (Using shorter versions for brevity, assume full fields are here)
71
  st.subheader("Demographics")
72
+ age = st.number_input("Age", 0, 120, 55, key="sb_age")
73
+ sex = st.selectbox("Sex", ["Male", "Female", "Other"], key="sb_sex")
74
+
75
  st.subheader("HPI")
76
+ chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="sb_cc")
77
+ hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100, key="sb_hpi")
78
+ symptoms = st.multiselect(
79
+ "Symptoms",
80
+ ["Nausea", "Diaphoresis", "SOB", "Dizziness", "Severe Headache", "Syncope", "Hemoptysis"],
81
+ default=["Nausea", "Diaphoresis"],
82
+ key="sb_sym"
83
+ )
84
+
85
  st.subheader("History")
86
+ pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI", key="sb_pmh")
87
+ psh = st.text_area("PSH", "Appendectomy", key="sb_psh")
88
+
89
  st.subheader("Meds & Allergies")
90
+ current_meds_str = st.text_area(
91
+ "Current Meds",
92
+ "Lisinopril 10mg daily\nMetformin 1000mg BID\nWarfarin 5mg daily",
93
+ key="sb_meds"
94
+ )
95
+ allergies_str = st.text_area("Allergies", "Penicillin (rash), Aspirin", key="sb_allergies") # Added Warfarin/Aspirin for testing
96
+
97
  st.subheader("Social/Family")
98
+ social_history = st.text_area("SH", "Smoker", key="sb_sh")
99
+ family_history = st.text_area("FHx", "Father MI", key="sb_fhx")
100
+
101
  st.subheader("Vitals & Exam")
102
  col1, col2 = st.columns(2)
103
  with col1:
104
+ temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f", key="sb_temp")
105
+ hr_bpm = st.number_input("HR", 30, 250, 95, key="sb_hr")
106
+ rr_rpm = st.number_input("RR", 5, 50, 18, key="sb_rr")
107
  with col2:
108
+ bp_mmhg = st.text_input("BP", "155/90", key="sb_bp")
109
+ spo2_percent = st.number_input("SpO2", 70, 100, 96, key="sb_spo2")
110
+ pain_scale = st.slider("Pain", 0, 10, 8, key="sb_pain")
111
+ exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=50, key="sb_exam")
 
112
 
113
+ if st.button("Start/Update Consultation", key="sb_start"):
114
+ # Compile data...
115
  current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
116
  current_med_names_only = []
117
  for med in current_meds_list:
118
  match = re.match(r"^\s*([a-zA-Z\-]+)", med)
119
  if match:
120
  current_med_names_only.append(match.group(1).lower())
121
+
122
  allergies_list = []
123
  for a in allergies_str.split(','):
124
  cleaned_allergy = a.strip()
 
126
  match = re.match(r"^\s*([a-zA-Z\-\s/]+)(?:\s*\(.*\))?", cleaned_allergy)
127
  name_part = match.group(1).strip().lower() if match else cleaned_allergy.lower()
128
  allergies_list.append(name_part)
129
+
130
+ # Update patient data in session state
131
  st.session_state.patient_data = {
132
  "demographics": {"age": age, "sex": sex},
133
  "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
 
147
  },
148
  "exam_findings": {"notes": exam_notes}
149
  }
150
+
151
+ # Call check_red_flags from agent module
152
  red_flags = check_red_flags(st.session_state.patient_data)
153
  st.sidebar.markdown("---")
154
  if red_flags:
155
  st.sidebar.warning("**Initial Red Flags:**")
156
+ for flag in red_flags:
157
+ st.sidebar.warning(f"- {flag.replace('Red Flag: ', '')}")
158
  else:
159
  st.sidebar.success("No immediate red flags.")
160
+
161
+ # Reset conversation and summary on new intake
162
  initial_prompt = "Initiate consultation. Review patient data and begin analysis."
163
  st.session_state.messages = [HumanMessage(content=initial_prompt)]
164
+ st.session_state.summary = None # Reset summary
165
  st.success("Patient data loaded/updated.")
166
+ # Rerun might be needed if the main area should clear or update based on new data
167
+ st.rerun()
168
 
169
  # --- Main Chat Interface Area ---
170
  st.header("πŸ’¬ Clinical Consultation")
 
172
  for msg in st.session_state.messages:
173
  if isinstance(msg, HumanMessage):
174
  with st.chat_message("user"):
175
+ st.markdown(msg.content)
176
  elif isinstance(msg, AIMessage):
177
  with st.chat_message("assistant"):
178
  ai_content = msg.content
179
  structured_output = None
180
  try:
181
+ # JSON Parsing logic...
182
  json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
183
  if json_match:
184
  json_str = json_match.group(1)
 
193
  structured_output = json.loads(ai_content)
194
  ai_content = ""
195
  else:
196
+ st.markdown(ai_content) # Display non-JSON content
197
  except Exception as e:
198
  st.markdown(ai_content)
199
  print(f"Error parsing/displaying AI JSON: {e}")
200
+
201
  if structured_output and isinstance(structured_output, dict):
202
+ # Structured JSON display logic...
203
  st.divider()
204
  st.subheader("πŸ“Š AI Analysis & Recommendations")
205
  cols = st.columns(2)
 
209
  st.markdown("**Differential Diagnosis:**")
210
  ddx = structured_output.get('differential_diagnosis', [])
211
  if ddx:
212
+ for item in ddx:
213
+ likelihood = item.get('likelihood', 'Low')
214
+ if likelihood and likelihood[0] in 'HML':
215
+ medal = "πŸ₯‡" if likelihood[0] == 'H' else "πŸ₯ˆ" if likelihood[0] == 'M' else "πŸ₯‰"
216
+ else:
217
+ medal = "?"
218
+ expander_title = f"{medal} {item.get('diagnosis', 'Unknown')} ({likelihood})"
219
+ with st.expander(expander_title):
220
+ st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
221
  else:
222
  st.info("No DDx provided.")
223
  st.markdown("**Risk Assessment:**")
 
236
  with cols[1]:
237
  st.markdown("**Recommended Plan:**")
238
  plan = structured_output.get('recommended_plan', {})
239
+ for section in ["investigations", "therapeutics", "consultations", "patient_education"]:
240
+ st.markdown(f"_{section.replace('_', ' ').capitalize()}:_")
241
  items = plan.get(section)
242
  if items and isinstance(items, list):
243
+ for item in items:
244
+ st.markdown(f"- {item}")
245
  elif items:
246
  st.markdown(f"- {items}")
247
  else:
 
255
  st.markdown(f"> {interaction_summary}")
256
  st.divider()
257
 
258
+ # Tool Call Display
259
  if getattr(msg, 'tool_calls', None):
260
  with st.expander("πŸ› οΈ AI requested actions", expanded=False):
261
  if msg.tool_calls:
262
  for tc in msg.tool_calls:
263
  try:
264
+ st.code(
265
+ f"Action: {tc.get('name', 'Unknown Tool')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}",
266
+ language="json"
267
+ )
268
  except Exception as display_e:
269
+ st.error(f"Could not display tool call args: {display_e}", icon="⚠️")
270
  st.code(f"Action: {tc.get('name', 'Unknown Tool')}\nRaw Args: {tc.get('args')}")
271
  else:
272
+ st.caption("_No actions requested._")
273
  elif isinstance(msg, ToolMessage):
274
  tool_name_display = getattr(msg, 'name', 'tool_execution')
275
  with st.chat_message(tool_name_display, avatar="πŸ› οΈ"):
276
  try:
277
+ # Tool message display logic...
278
  tool_data = json.loads(msg.content)
279
  status = tool_data.get("status", "info")
280
  message = tool_data.get("message", msg.content)
281
  details = tool_data.get("details")
282
  warnings = tool_data.get("warnings")
283
+ # Display flagged risks immediately if the tool signals it
284
+ if tool_name_display == "flag_risk" and status == "flagged":
285
+ st.error(f"🚨 **RISK FLAGGED:** {message}", icon="🚨")
286
+ elif status in ["success", "clear"]:
287
+ st.success(f"{message}", icon="βœ…")
288
  elif status == "warning":
289
  st.warning(f"{message}", icon="⚠️")
 
 
 
290
  else:
291
  st.error(f"{message}", icon="❌")
292
+ if warnings and isinstance(warnings, list):
293
+ st.caption("Details:")
294
+ for warn in warnings:
295
+ st.caption(f"- {warn}")
296
  if details:
297
  st.caption(f"Details: {details}")
298
  except json.JSONDecodeError:
 
306
  if not st.session_state.patient_data:
307
  st.warning("Please load patient data first.")
308
  st.stop()
309
+ if 'agent' not in st.session_state or not st.session_state.agent:
310
+ st.error("Agent not initialized. Check logs.")
311
+ st.stop()
312
+
313
+ # Append user message and display immediately
314
  user_message = HumanMessage(content=prompt)
315
  st.session_state.messages.append(user_message)
316
  with st.chat_message("user"):
317
  st.markdown(prompt)
318
+
319
+ # Prepare state for the agent
320
+ current_state_dict = {
321
+ "messages": st.session_state.messages,
322
+ "patient_data": st.session_state.patient_data,
323
+ "summary": st.session_state.get("summary"),
324
+ "interaction_warnings": None # Start clean
325
+ }
326
+
327
+ # Invoke the agent's graph for one turn
328
+ with st.spinner("SynapseAI is processing..."):
329
  try:
330
+ # Call the agent instance's method
331
+ final_state = st.session_state.agent.invoke_turn(current_state_dict)
332
+
333
+ # Update Streamlit session state from the returned agent state
334
+ st.session_state.messages = final_state.get('messages', [])
335
+ st.session_state.summary = final_state.get('summary')
336
+
337
  except Exception as e:
338
+ print(f"CRITICAL ERROR during agent invocation: {type(e).__name__} - {e}")
339
  traceback.print_exc()
340
+ st.error(f"An error occurred during processing: {e}", icon="❌")
341
+ # Append error to messages for user visibility
342
+ st.session_state.messages.append(AIMessage(content=f"Error during processing: {e}"))
343
+
344
+ # Rerun Streamlit script to update the chat display
345
  st.rerun()
346
 
347
  # Disclaimer