File size: 13,896 Bytes
b731976
788074d
b34efbf
 
 
 
 
 
4b23857
b731976
4b23857
b731976
 
 
 
b34efbf
6b6515d
b34efbf
71db5de
b731976
b34efbf
 
 
 
6b6515d
 
 
99a7bc0
b731976
99a7bc0
9988477
b731976
99a7bc0
a2143ab
99a7bc0
6b6515d
 
4258926
31ea2bf
4258926
 
 
b731976
 
 
6b6515d
 
 
b731976
 
 
 
 
 
4b23857
6b6515d
9988477
4258926
 
4b23857
6b6515d
 
 
4b23857
6b6515d
 
 
 
b731976
99a7bc0
b731976
 
4258926
6b6515d
 
 
31ea2bf
6b6515d
 
b731976
6b6515d
b731976
6b6515d
 
 
b731976
71db5de
99a7bc0
6b6515d
99a7bc0
b731976
9988477
 
4258926
a2143ab
9988477
4258926
6b6515d
4258926
6b2d9f7
6b6515d
 
31ea2bf
6b6515d
 
 
 
4b23857
6b6515d
4b23857
6b6515d
 
 
4b23857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a7bc0
b731976
31ea2bf
6b6515d
 
71db5de
6b6515d
 
 
4258926
b34efbf
99a7bc0
6b6515d
 
4b23857
6b6515d
 
 
 
 
 
 
4258926
9988477
4258926
6b6515d
 
b731976
4b23857
6b6515d
4b23857
b731976
31ea2bf
b731976
 
 
4b23857
71db5de
31ea2bf
9988477
6b6515d
b564942
 
6b6515d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# app.py
import streamlit as st
import json
import re
import os
import traceback
from dotenv import load_dotenv

# Import agent logic and message types
try:
    from agent import ClinicalAgent, AgentState, check_red_flags # Import necessary components
    from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
except ImportError as e:
    st.error(f"Failed to import from agent.py: {e}. Make sure agent.py is in the same directory.")
    st.stop()


# --- Environment Variable Loading & Validation ---
# Run python-dotenv's loader first so the lookups below see whatever it provides.
load_dotenv()

# The agent module depends on these keys; verify them before any agent is created.
UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Names of every key that is unset or empty, reported in a fixed order.
missing_keys = [
    key_name
    for key_name, key_value in (
        ("UMLS_API_KEY", UMLS_API_KEY),
        ("GROQ_API_KEY", GROQ_API_KEY),
        ("TAVILY_API_KEY", TAVILY_API_KEY),
    )
    if not key_value
]

if missing_keys:
    # Surface the problem in the UI and halt the script run.
    st.error(f"Missing required API Key(s): {', '.join(missing_keys)}. Please set them in Hugging Face Space Secrets or environment variables.")
    st.stop()

# --- App Configuration ---
class ClinicalAppSettings:
    """Static display settings read by the Streamlit UI."""

    # Text used for the browser page title and the on-page heading.
    APP_TITLE: str = "SynapseAI (UMLS/FDA Integrated)"

    # Value handed to the ``layout`` argument of ``st.set_page_config``.
    PAGE_LAYOUT: str = "wide"

    # Label shown in the caption only; the model itself is defined in agent.py.
    MODEL_NAME_DISPLAY: str = "Llama3-70b (via Groq)"


# --- Streamlit UI ---
# Leading alphabetic token of a medication line ("Lisinopril" in "Lisinopril 10mg daily").
_MED_NAME_RE = re.compile(r"^\s*([a-zA-Z\-]+)")
# Allergen name, optionally followed by a parenthesised note ("Penicillin (rash)").
_ALLERGY_NAME_RE = re.compile(r"^\s*([a-zA-Z\-\s/]+)(?:\s*\(.*\))?")
# A ```json fenced block wrapping a single JSON object.
_JSON_FENCE_RE = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL | re.IGNORECASE)
# First letter of a likelihood label ("High"/"Medium"/"Low") -> badge shown in the DDx list.
_LIKELIHOOD_MEDALS = {"H": "🥇", "M": "🥈", "L": "🥉"}
# Keys of ``recommended_plan`` rendered in the plan column, in display order.
_PLAN_SECTIONS = ("investigations", "therapeutics", "consultations", "patient_education")


def _extract_medication_names(medications):
    """Return the lower-cased leading name of every medication line that has one.

    Args:
        medications: Iterable of medication strings such as ``"Warfarin 5mg daily"``.

    Returns:
        A list of lower-cased names, one per line that starts with letters or hyphens.
    """
    # The match test has to happen per item; testing once after the loop would
    # keep only the last medication and fail outright on an empty list.
    return [
        found.group(1).lower()
        for med in medications
        if (found := _MED_NAME_RE.match(med))
    ]


def _extract_allergy_names(allergies_text):
    """Split a comma-separated allergy string into lower-cased allergen names.

    Args:
        allergies_text: Free text such as ``"Penicillin (rash), Aspirin"``.

    Returns:
        A list with one lower-cased entry per non-empty comma-separated item. A
        trailing parenthesised note is dropped when the item matches the name
        pattern; otherwise the whole item is kept.
    """
    names = []
    for piece in allergies_text.split(","):
        cleaned = piece.strip()
        if not cleaned:
            continue
        found = _ALLERGY_NAME_RE.match(cleaned)
        names.append(found.group(1).strip().lower() if found else cleaned.lower())
    return names


def _split_structured_output(ai_content):
    """Separate an AI reply into free text and an embedded JSON object.

    Args:
        ai_content: The reply text.

    Returns:
        A ``(prefix, structured, suffix)`` tuple. ``structured`` is a non-empty
        dict when the reply holds a ```json fenced object or is itself a JSON
        object; ``prefix``/``suffix`` are the text around the fence. When no
        usable object is found the result is ``(ai_content, None, "")`` so the
        caller can render the reply verbatim.
    """
    fenced = _JSON_FENCE_RE.search(ai_content)
    stripped = ai_content.strip()
    if fenced:
        candidate = fenced.group(1)
        prefix = ai_content[:fenced.start()].strip()
        suffix = ai_content[fenced.end():].strip()
    elif stripped.startswith("{") and stripped.endswith("}"):
        candidate, prefix, suffix = stripped, "", ""
    else:
        return ai_content, None, ""

    # Parse before anything is rendered so a malformed payload falls back to the
    # raw text exactly once instead of showing the prefix twice.
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError as e:
        print(f"Error parsing/displaying AI JSON: {e}")
        return ai_content, None, ""
    if not parsed or not isinstance(parsed, dict):
        return ai_content, None, ""
    return prefix, parsed, suffix


def _likelihood_medal(likelihood):
    """Return the badge for a likelihood label, or ``"?"`` when it is not recognised."""
    return _LIKELIHOOD_MEDALS.get(str(likelihood)[:1].upper(), "?")


def _as_text_list(value):
    """Coerce a JSON value into a list of strings that is safe to ``', '.join``."""
    if not value:
        return []
    if isinstance(value, list):
        return [str(entry) for entry in value]
    return [str(value)]


def _render_intake_sidebar():
    """Draw the patient intake form and, on submit, store the data and restart the chat.

    When the button is pressed this writes ``patient_data``, ``messages`` and
    ``summary`` into ``st.session_state``, shows the result of
    ``check_red_flags`` in the sidebar and triggers ``st.rerun()``.
    """
    with st.sidebar:
        st.header("📄 Patient Intake Form")

        st.subheader("Demographics")
        age = st.number_input("Age", 0, 120, 55, key="sb_age")
        sex = st.selectbox("Sex", ["Male", "Female", "Other"], key="sb_sex")

        st.subheader("HPI")
        chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="sb_cc")
        hpi_details = st.text_area("HPI Details", "55 y/o male...", height=100, key="sb_hpi")
        symptoms = st.multiselect(
            "Symptoms",
            ["Nausea", "Diaphoresis", "SOB", "Dizziness", "Severe Headache", "Syncope", "Hemoptysis"],
            default=["Nausea", "Diaphoresis"],
            key="sb_sym",
        )

        st.subheader("History")
        pmh = st.text_area("PMH", "HTN, HLD, DM2, History of MI", key="sb_pmh")
        psh = st.text_area("PSH", "Appendectomy", key="sb_psh")

        st.subheader("Meds & Allergies")
        current_meds_str = st.text_area(
            "Current Meds",
            "Lisinopril 10mg daily\nMetformin 1000mg BID\nWarfarin 5mg daily",
            key="sb_meds",
        )
        allergies_str = st.text_area("Allergies", "Penicillin (rash), Aspirin", key="sb_allergies")

        st.subheader("Social/Family")
        social_history = st.text_area("SH", "Smoker", key="sb_sh")
        family_history = st.text_area("FHx", "Father MI", key="sb_fhx")

        st.subheader("Vitals & Exam")
        col1, col2 = st.columns(2)
        with col1:
            temp_c = st.number_input("Temp C", 35.0, 42.0, 36.8, format="%.1f", key="sb_temp")
            hr_bpm = st.number_input("HR", 30, 250, 95, key="sb_hr")
            rr_rpm = st.number_input("RR", 5, 50, 18, key="sb_rr")
        with col2:
            bp_mmhg = st.text_input("BP", "155/90", key="sb_bp")
            spo2_percent = st.number_input("SpO2", 70, 100, 96, key="sb_spo2")
            pain_scale = st.slider("Pain", 0, 10, 8, key="sb_pain")
        exam_notes = st.text_area("Exam Notes", "Awake, alert...", height=50, key="sb_exam")

        if not st.button("Start/Update Consultation", key="sb_start"):
            return

        # One medication per non-blank line of the text area.
        current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
        current_med_names_only = _extract_medication_names(current_meds_list)
        allergies_list = _extract_allergy_names(allergies_str)

        st.session_state.patient_data = {
            "demographics": {"age": age, "sex": sex},
            "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
            "pmh": {"conditions": pmh},
            "psh": {"procedures": psh},
            "medications": {"current": current_meds_list, "names_only": current_med_names_only},
            "allergies": allergies_list,
            "social_history": {"details": social_history},
            "family_history": {"details": family_history},
            "vitals": {
                "temp_c": temp_c,
                "hr_bpm": hr_bpm,
                "bp_mmhg": bp_mmhg,
                "rr_rpm": rr_rpm,
                "spo2_percent": spo2_percent,
                "pain_scale": pain_scale,
            },
            "exam_findings": {"notes": exam_notes},
        }

        # Red-flag screening comes from the agent module.
        red_flags = check_red_flags(st.session_state.patient_data)
        st.sidebar.markdown("---")
        if red_flags:
            st.sidebar.warning("**Initial Red Flags:**")
            for flag in red_flags:
                st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}")
        else:
            st.sidebar.success("No immediate red flags.")

        # A new intake starts a fresh conversation and discards any old summary.
        initial_prompt = "Initiate consultation. Review patient data and begin analysis."
        st.session_state.messages = [HumanMessage(content=initial_prompt)]
        st.session_state.summary = None
        st.success("Patient data loaded/updated.")
        st.rerun()


def _render_structured_output(structured_output):
    """Render the parsed AI analysis dict as assessment, DDx, risk and plan panels.

    Args:
        structured_output: Dict parsed from the AI reply. Missing keys fall back
            to placeholders; values of an unexpected type are shown as text
            rather than raising.
    """
    st.divider()
    st.subheader("📊 AI Analysis & Recommendations")
    left_col, right_col = st.columns(2)

    with left_col:  # Assessment, DDx, Risk
        st.markdown("**Assessment:**")
        st.markdown(f"> {structured_output.get('assessment', 'N/A')}")

        st.markdown("**Differential Diagnosis:**")
        ddx = structured_output.get('differential_diagnosis', [])
        if ddx and isinstance(ddx, list):
            for item in ddx:
                if not isinstance(item, dict):
                    st.markdown(f"- {item}")
                    continue
                likelihood = item.get('likelihood', '?')
                label = f"{_likelihood_medal(likelihood)} {item.get('diagnosis', 'Unknown')} ({likelihood})"
                st.expander(label).write(f"**Rationale:** {item.get('rationale', 'N/A')}")
        else:
            st.info("No DDx provided.")

        st.markdown("**Risk Assessment:**")
        risk = structured_output.get('risk_assessment')
        if not isinstance(risk, dict):
            risk = {}
        flags = _as_text_list(risk.get('identified_red_flags'))
        concerns = _as_text_list(risk.get("immediate_concerns"))
        comps = _as_text_list(risk.get("potential_complications"))
        if flags:
            st.warning(f"**Flags:** {', '.join(flags)}")
        if concerns:
            st.warning(f"**Concerns:** {', '.join(concerns)}")
        if comps:
            st.info(f"**Potential Complications:** {', '.join(comps)}")
        if not (flags or concerns or comps):
            st.success("No specific risks highlighted in this AI assessment.")

    with right_col:  # Plan
        st.markdown("**Recommended Plan:**")
        plan = structured_output.get('recommended_plan')
        if not isinstance(plan, dict):
            plan = {}
        for section in _PLAN_SECTIONS:
            st.markdown(f"_{section.replace('_',' ').capitalize()}:_")
            items = plan.get(section)
            if items and isinstance(items, list):
                for item in items:
                    st.markdown(f"- {item}")
            elif items:
                st.markdown(f"- {items}")
            else:
                st.markdown("_None_")
            st.markdown("")

    st.markdown("**Rationale & Guideline Check:**")
    st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
    interaction_summary = structured_output.get("interaction_check_summary", "")
    if interaction_summary:
        st.markdown("**Interaction Check Summary:**")
        st.markdown(f"> {interaction_summary}")
    st.divider()


def _render_ai_message(msg):
    """Render one AIMessage: free text, structured analysis and requested tool calls."""
    raw_content = msg.content
    ai_content = raw_content if isinstance(raw_content, str) else str(raw_content)

    prefix, structured_output, suffix = _split_structured_output(ai_content)
    if prefix:
        st.markdown(prefix)
    if suffix:
        st.markdown(suffix)
    if structured_output:
        _render_structured_output(structured_output)

    tool_calls = getattr(msg, 'tool_calls', None)
    if not tool_calls:
        return
    with st.expander("🛠️ AI requested actions", expanded=False):
        for tc in tool_calls:
            tool_name = tc.get('name', 'Unknown Tool')
            tool_args = tc.get('args', {})
            try:
                args_json = json.dumps(tool_args, indent=2)
            except (TypeError, ValueError) as display_e:
                # Arguments that cannot be serialised are shown in their raw form.
                st.error(f"Could not display tool call args: {display_e}", icon="⚠️")
                st.code(f"Action: {tool_name}\nRaw Args: {tool_args}")
            else:
                st.code(f"Action: {tool_name}\nArgs: {args_json}", language="json")


def _render_tool_message(msg):
    """Render one ToolMessage, styling it by the ``status`` field of its JSON payload."""
    # ``name`` may exist but be unset, so fall back on any falsy value.
    tool_name_display = getattr(msg, 'name', None) or 'tool_execution'
    with st.chat_message(tool_name_display, avatar="🛠️"):
        try:
            tool_data = json.loads(msg.content)
        except (json.JSONDecodeError, TypeError):
            st.info(f"{msg.content}")  # Display raw if not JSON
            return
        if not isinstance(tool_data, dict):
            st.info(f"{msg.content}")
            return

        try:
            status = tool_data.get("status", "info")
            message = tool_data.get("message", msg.content)
            details = tool_data.get("details")
            warnings = tool_data.get("warnings")

            if tool_name_display == "flag_risk" and status == "flagged":
                st.error(f"🚨 **RISK FLAGGED:** {message}", icon="🚨")
            elif status in ("success", "clear"):
                st.success(f"{message}", icon="✅")
            elif status == "warning":
                st.warning(f"{message}", icon="⚠️")
                # The itemised warnings belong to the warning status only.
                if warnings and isinstance(warnings, list):
                    st.caption("Details:")
                    for warn in warnings:
                        st.caption(f"- {warn}")
            else:
                st.error(f"{message}", icon="❌")  # Assume error if not known status
            if details:
                st.caption(f"Details: {details}")
        except Exception as e:
            # Display boundary: one bad tool payload must not take down the whole page.
            st.error(f"Error displaying tool message: {e}", icon="❌")
            st.caption(f"Raw content: {msg.content}")


def _handle_user_prompt(prompt):
    """Append the user's prompt, run one agent turn and rerun the script.

    Args:
        prompt: Text entered in the chat input.
    """
    if not st.session_state.patient_data:
        st.warning("Please load patient data first.")
        st.stop()
    if 'agent' not in st.session_state or not st.session_state.agent:
        st.error("Agent not initialized. Check logs.")
        st.stop()

    st.session_state.messages.append(HumanMessage(content=prompt))
    with st.chat_message("user"):
        st.markdown(prompt)

    current_state_dict = {
        "messages": st.session_state.messages,
        "patient_data": st.session_state.patient_data,
        "summary": st.session_state.get("summary"),
        "interaction_warnings": None,
    }
    with st.spinner("SynapseAI is processing..."):
        try:
            final_state = st.session_state.agent.invoke_turn(current_state_dict)
            st.session_state.messages = final_state.get('messages', [])
            st.session_state.summary = final_state.get('summary')
        except Exception as e:
            # Keep the failure visible in the transcript instead of losing the turn.
            print(f"CRITICAL ERROR: {e}")
            traceback.print_exc()
            st.error(f"Error: {e}")
            st.session_state.messages.append(AIMessage(content=f"Error processing request: {e}"))
    st.rerun()


def main():
    """Build the Streamlit page: header, intake sidebar, chat history and chat input."""
    st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
    st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
    st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME_DISPLAY}")

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "patient_data" not in st.session_state:
        st.session_state.patient_data = None
    if "summary" not in st.session_state:
        st.session_state.summary = None
    if "agent" not in st.session_state:
        try:
            st.session_state.agent = ClinicalAgent()
            print("ClinicalAgent successfully initialized in Streamlit session state.")
        except Exception as e:
            st.error(f"Failed to initialize Clinical Agent: {e}. Check API keys and dependencies.")
            print(f"ERROR Initializing ClinicalAgent: {e}")
            traceback.print_exc()
            st.stop()

    # --- Patient Data Input Sidebar ---
    _render_intake_sidebar()

    # --- Main Chat Interface Area ---
    st.header("💬 Clinical Consultation")
    for msg in st.session_state.messages:
        if isinstance(msg, HumanMessage):
            with st.chat_message("user"):
                st.markdown(msg.content)
        elif isinstance(msg, AIMessage):
            with st.chat_message("assistant"):
                _render_ai_message(msg)
        elif isinstance(msg, ToolMessage):
            _render_tool_message(msg)

    # --- Chat Input Logic ---
    if prompt := st.chat_input("Your message or follow-up query..."):
        _handle_user_prompt(prompt)

    # Disclaimer
    st.markdown("---")
    st.warning("**Disclaimer:** SynapseAI is for demonstration...")

# Build the UI only when this file is executed as the main script, not when it is imported.
if __name__ == "__main__":
    main()