File size: 23,030 Bytes
38d6cb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6b7fcf
38d6cb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
from chat_utils import *
from model_utils import *

import json
import shutil
from faker import Faker


# Relative paths of the checkpoint files that main() loads with torch.load.
PATH_DIAG = "model/diagnosis_prediction/best.ckpt"
PATH_MED = "model/medication_recommendation/best.ckpt"
# Cache cleanup is currently disabled; kept for reference.
# shutil.rmtree(".cache/", ignore_errors=True)


def _init_session_state():
    """Create the session-state keys the dashboard reads, defaulting each to ``None``."""
    for key in ("patient", "name", "lastname", "gender_sign"):
        if key not in st.session_state:
            st.session_state[key] = None


def _fake_identity(fake, gender):
    """Return ``(first_name, last_name, gender_sign)`` for the demo patient name.

    ``gender_sign`` is an emoji shortcode (without the surrounding colons).
    Any gender other than ``"M"``/``"F"`` gets placeholder names and the
    ``question`` shortcode, so the caller always receives three values.
    """
    if gender == "M":
        return fake.first_name_male(), fake.last_name_male(), "male_sign"
    if gender == "F":
        return fake.first_name_female(), fake.last_name_female(), "female_sign"
    return "Name", "Unknown", "question"


def _strip_reply_preamble(reply, marker="Here is the JSON file:\n\n"):
    """Return the text following ``marker`` in ``reply``, or ``reply`` itself if the marker is absent."""
    parts = reply.split(marker)
    return parts[1] if len(parts) > 1 else reply


def main():
    """Render the PIE-Med Streamlit dashboard.

    The page is built top to bottom on every Streamlit rerun:

    1. page configuration, session-state keys, models/dataset loading;
    2. disclaimer gate, then task and patient selection in the sidebar
       (``st.stop()`` ends the run until each choice is made);
    3. medical history of the selected patient and the model's
       recommendations for the patient's last visit;
    4. explanation graph for one selected recommendation;
    5. LLM "medical agents" section: doctor recruitment, per-doctor
       analysis, collaborative discussion and a downloadable PDF report.
    """

    # ---- SETTINGS PAGE ----
    st.set_page_config(page_title="PIE-Med - Dashboard", page_icon="🩺", layout="wide")

    # with open('css/style.css') as f:
    #     hide_streamlit_style = f.read()
    # st.markdown(hide_streamlit_style, unsafe_allow_html=True)

    # ---- SESSION STATE ----
    _init_session_state()

    # ---- SIDE BAR ----
    # st.sidebar.image(".\streamlit_images\logo_icon.png")
    # st.sidebar.divider()

    # ---- MAIN PAGE ----
    st.title(":rainbow[PIE-Med]")
    st.markdown("Welcome to PIE-Med 🩺!")

    desc = st.empty()
    desc1 = st.empty()
    desc.caption("**PIE-Med** 🩺, a cutting-edge system designed to enhance medical decision-making through \
                    the integration of **Graph Neural Networks (GNNs)** ⚙️, **eXplainable AI (XAI)** ❓ techniques, \
                    and **Large Language Models (LLMs)** 🧠.")
    desc1.caption("**⏳ WAIT MINUTES FOR THE LOADING OF THE MODELS AND THE DATASET**")

    model_med_ig, model_med_gnn, model_diag_ig, model_diag_gnn, \
        dataset, mimic3sample_med, mimic3sample_diag = load_gnn()
    # NOTE(review): torch.load is called without map_location; confirm the
    # checkpoints load on the deployment hardware.
    checkpoint_MED = torch.load(PATH_MED)
    checkpoint_DIAG = torch.load(PATH_DIAG)

    desc1.empty()

    fake = Faker()

    # ---- DISCLAIMER + TASK / PATIENT SELECTION ----
    placeholder2 = st.empty()
    with placeholder2.expander("⚠️ **Before using the framework, read the disclaimer for the use of Framework**"):
        disclaimer = """

            The use of our Healthcare framework based on MIMIC III (https://physionet.org/content/mimiciii/1.4/) is subject to the terms and warnings as follows:

            **Research and Decision Support Purpose:** Our framework has been developed primarily for research and decision support in the healthcare context. The information and recommendations generated should not replace the professional judgment of qualified healthcare practitioners but may be utilized as support for the final decision by the doctor or the directly involved party.

            **Data Origin:** The processed healthcare data originates from the MIMIC III database and undergoes enrichment and modeling through the application of Heterogeneous Graph Neural Network. It is important to note that the original data may contain variations and limitations, and the accuracy of the processed information depends on the quality of the input data.

            **Medical Recommendations:** The drug and diagnosis recommendations generated by the framework are hypothetical and based on Graph Neural Network learning models. These should not be considered definitive prescriptions, and the final decision regarding patient treatment should be made by a qualified medical professional.

            **Human Readable Explanations:** The embedded explainability system in the framework utilizes graph explainability models and Large Language Models (LLM) to generate understandable explanations for end-users, such as physicians. However, these explanations are interpretations of the model results and may not fully reflect the complexity of medical reasoning.

            **Framework Limitations:** Our framework has intrinsic limitations, including those related to the quality of input data, the characteristics of the machine learning model, and the dynamics of the healthcare context. Users are encouraged to exercise caution in interpreting the provided information.

            **User Responsibility:** Users accessing and utilizing our framework are responsible for the accurate interpretation of the provided information and for making appropriate decisions based on their clinical judgment. The creators assume no responsibility for any consequences arising from improper use or misinterpretation of the information generated by the framework.

            By using our healthcare data processing framework, the user agrees to comply with these conditions. The continuous evolution of the fields of medicine and technology may necessitate periodic updates to this disclaimer.

            """

        st.subheader("Disclaimer")
        st.info(disclaimer)
        agree = st.checkbox("I accept and have read the disclaimer!")
        placeholder1 = st.empty()
        placeholder1.warning("You must accept the disclaimer to use the framework!", icon="⚠️")

        # Nothing below is rendered until the disclaimer has been accepted.
        if not agree:
            st.stop()

        placeholder1.empty()
        placeholder2.info("You can now use the framework! 🎉 Please select the task and select a patient! 🩺")
        task = st.sidebar.selectbox(label='Select __task__: ', index=None, placeholder="Select type of task", options=['medications', 'diagnosis'])

        if task is None:
            st.stop()
        elif task == "medications":
            mimic3sample = mimic3sample_med
        elif task == "diagnosis":
            mimic3sample = mimic3sample_diag

        mimic_df = pd.DataFrame(mimic3sample.samples)

        selected_patient = st.sidebar.selectbox(label='Select __patient__ n°: ', index=None, placeholder="Select a patient", options=mimic_df['patient_id'].unique())
        if selected_patient is None:
            st.stop()

    desc.empty()
    placeholder2.empty()

    patient_dict = dataset.patients
    patient_info = patient_dict[selected_patient]
    gender = patient_info.gender

    # Generate a new fake identity only when the selected patient changes, so
    # the displayed name stays stable across reruns for the same patient.
    if selected_patient != st.session_state.patient:
        first_name, last_name, gender_sign = _fake_identity(fake, gender)

        st.session_state.patient = selected_patient
        st.session_state.name = ":blue[" + first_name + "]"
        st.session_state.lastname = last_name
        st.session_state.gender_sign = gender_sign

    patient = st.session_state.patient
    name = st.session_state.name
    lastname = st.session_state.lastname
    gender_sign = st.session_state.gender_sign

    mimic_df_patient = mimic_df[mimic_df['patient_id'] == selected_patient] # select all the rows with the selected patient

    # The last row of the patient's samples, kept as a one-row DataFrame.
    last_visit = mimic_df_patient.iloc[[-1]]
    last_visit_id = last_visit['visit_id'].item()

    # ---- Patient info ----
    # st.subheader(":blue[DASHBOARD OF] ")
    st.warning("🚨 **NOTE** 🚨: The patient's name, shown below, was randomly generated for demonstration purposes.")
    st.title(f"{name} {lastname} :{gender_sign}:")
    st.caption("Patient n°: {}  -  Gender: {}  -  Ethnicity: {}".format(patient, patient_info.gender, patient_info.ethnicity))

    l1, r1 = st.columns([0.44, 0.56])

    with l1:
        st.subheader("📋 Medical history")
        # st.caption("The following table shows the *complete* medical history of the patient n°: **{}**.".format(patient))

        visit = st.selectbox(label='🏥 __Hospital admission__ n°: ', options=mimic_df_patient['visit_id'].unique())
        if visit:
            mimic_df_patient_visit = mimic_df_patient[mimic_df_patient['visit_id'] == visit] # select all the rows with the selected visit
            if task == "medications":
                mimic_df_patient_visit_filtered = mimic_df_patient_visit.drop(columns=['visit_id', 'patient_id', 'drugs_hist'])
            elif task == "diagnosis":
                mimic_df_patient_visit_filtered = mimic_df_patient_visit.drop(columns=['visit_id', 'patient_id'])

            atc = InnerMap.load("ATC")
            icd9 = InnerMap.load("ICD9CM")
            icd9_proc = InnerMap.load("ICD9PROC")

            for column in mimic_df_patient_visit_filtered.columns:
                with st.expander("{}".format(column)):
                    # Best effort: a column whose codes cannot be exploded or
                    # looked up shows a placeholder message instead of failing.
                    try:
                        if column == "medications":
                            if task == "medications":
                                med_history = [[med, atc.lookup(med)] for med in mimic_df_patient_visit_filtered[column].explode() if med]
                            elif task == "diagnosis":
                                med_history = [[med, atc.lookup(med)] for med in (mimic_df_patient_visit_filtered[column].explode()).explode() if med]
                            st.dataframe(med_history, hide_index=True, column_config={"0": "ATC", "1": "Description"})
                        elif column == "diagnosis":
                            if task == "medications":
                                col_history = [[idx, icd9.lookup(idx)] for idx in (mimic_df_patient_visit_filtered[column].explode()).explode() if idx]
                            elif task == "diagnosis":
                                col_history = [[idx+'0', icd9.lookup(idx+'0')] if idx.startswith('E') else [idx, icd9.lookup(idx)] for idx in mimic_df_patient_visit_filtered[column].explode() if idx]
                            st.dataframe(col_history, hide_index=True, column_config={"0": "ICD9", "1": "Description"})
                        elif column == "symptoms":
                            col_history = [[idx, icd9.lookup(idx)] for idx in (mimic_df_patient_visit_filtered[column].explode()).explode() if idx]
                            st.dataframe(col_history, hide_index=True, column_config={"0": "ICD9", "1": "Description"})
                        elif column == "procedures":
                            col_history = [[idx, icd9_proc.lookup(idx)] for idx in (mimic_df_patient_visit_filtered[column].explode()).explode() if idx]
                            st.dataframe(col_history, hide_index=True, column_config={"0": "ICD9", "1": "Description"})
                    except Exception:
                        st.write("No data available for this column.")

        st.subheader(f"🧾 Recommended _{task}_")
        st.caption(f"""The following {task} are recommended for the patient during the **hospital admission n°: \
                   {last_visit_id}**. \n The recommendations are based on the \
                    output probabilities generated by the **GNN (_Graph Neural Network_)** model.""")

        if task == "medications":
            model_med_ig.load_state_dict(checkpoint_MED)
            model_med_gnn.load_state_dict(checkpoint_MED)
            model = model_med_ig
        elif task == "diagnosis":
            model_diag_ig.load_state_dict(checkpoint_DIAG)
            model_diag_gnn.load_state_dict(checkpoint_DIAG)
            model = model_diag_ig

        # ---- Output model ----
        model.eval()
        output = model(last_visit['patient_id'],
                    last_visit['visit_id'],
                    last_visit['diagnosis'],
                    last_visit['procedures'],
                    last_visit['symptoms'],
                    last_visit['medications'])

        list_output, list_indices = get_list_output(output['y_prob'], last_visit, task, mimic3sample)
        # Pair every recommendation with its index, dropping empty entries.
        list_output = [[idx, item] for idx, item in zip(*list_indices, *list_output) if item]
        st.dataframe(list_output, column_config={"0": "ID", "1": f"Recommended {task}"}, height=None, width=None)

    with r1:
        st.subheader(f"""🗣 *Why* did the model recommend these {task}?""")
        r1l1, r1c1, r1r1 = st.columns(3)
        with r1l1:
            visualization = st.radio("Visualization", options=["Explainable", "Interpretable"], horizontal=True)
        with r1c1:
            algorithm = st.radio("Algorithm", options=["IG", "GNNExplainer"], horizontal=True)
        with r1r1:
            threshold = st.slider("Threshold", min_value=10, max_value=50, value=15, step=5, format=None, key=None)

        # Pick the model instance matching the chosen task and explanation
        # algorithm; any other combination keeps the model selected above.
        explainer_models = {
            ("medications", "IG"): model_med_ig,
            ("medications", "GNNExplainer"): model_med_gnn,
            ("diagnosis", "IG"): model_diag_ig,
            ("diagnosis", "GNNExplainer"): model_diag_gnn,
        }
        model = explainer_models.get((task, algorithm), model)

        st.caption(f"""The graph shown as follows provides an interpretation of the model's decision making process on the recommended \
                    *{task}* for the patient during the **hospital admission n°: {last_visit_id}**. \
                    \n\n The interpretability is based on the **{algorithm} (_{task}_)** algorithm.""")
        options = [item[1] for item in list_output if item]
        selected_label = st.selectbox(f'Select the {task} to explain', index=None, 
                                        placeholder=f"Choice a {task} from Recommended {task} ranking to explain", 
                                        options=options)

        if selected_label is None:
            st.stop()

        selected_idx = [item[0] for item in list_output if item[1] == selected_label]

        st.caption("Legend of the graph:")
        col1, col2, col3, col4, col5, col6, col7, col8 = st.columns([0.1, 0.3, 0.1, 0.3, 0.1, 0.3, 0.1, 0.3])

        with col1:
            st.markdown(
                """
                <style>
                #square1 {
                    width: 20px;
                    height: 20px;
                    background: #20b2aa;
                    border-radius: 3px;
                }
                </style>
                <div id="square1"></div>
                """,
                unsafe_allow_html=True,
            )

            st.markdown(
                """
                <style>
                #square2 {
                    width: 20px;
                    height: 20px;
                    background: #fa8072;
                    border-radius: 3px;
                    margin-top: 20px;
                }
                </style>
                <div id="square2"></div>
                """,
                unsafe_allow_html=True,
            )

        with col2:
            st.caption("Patient")

            st.caption("Visit")

        with col3:
            st.markdown(
                """
                <style>
                #square3 {
                    width: 20px;
                    height: 20px;
                    background: #cd853f;
                    border-radius: 3px;
                }
                </style>
                <div id="square3"></div>
                """,
                unsafe_allow_html=True,
            )
            st.markdown(
                """
                <style>
                #square4 {
                    width: 20px;
                    height: 20px;
                    background: #da70d6;
                    border-radius: 3px;
                    margin-top: 20px;
                }
                </style>
                <div id="square4"></div>
                """,
                unsafe_allow_html=True,
            )

        with col4:
            st.caption("Diagnosis")
            st.caption("Procedures")

        with col5:
            st.markdown(
                """
                <style>
                #square5 {
                    width: 20px;
                    height: 20px;
                    background: #98fb98;
                    border-radius: 3px;
                }
                </style>
                <div id="square5"></div>
                """,
                unsafe_allow_html=True,
            )

        with col6:
            st.caption("Symptoms")

        with col7:
            st.markdown(
                """
                <style>
                #square6 {
                    width: 20px;
                    height: 20px;
                    background: #87ceeb;
                    border-radius: 3px;
                }
                </style>
                <div id="square6"></div>
                """,
                unsafe_allow_html=True,
            )

        with col8:
            st.caption("Medications")

        explain_sample = {}
        for visit_sample in mimic3sample.samples:
            if visit_sample['patient_id'] == patient and visit_sample['visit_id'] == last_visit_id:
                # Work on a shallow copy so the shared dataset sample is not
                # modified when 'drugs_hist' is removed for the explainer.
                sample_copy = dict(visit_sample)
                if sample_copy.get('drugs_hist') is not None:
                    del sample_copy['drugs_hist']
                explain_sample['test'] = sample_copy

        model.eval()
        explain_dataset = SampleEHRDataset(list(explain_sample.values()), code_vocs="ATC")
        explainability(model, explain_dataset, selected_idx[0], visualization, algorithm, task, threshold)


    ####################### CARE AI module ##################################
    st.header('🩺🧠 Medical Agents Evaluation')
    st.caption("The section shown as follows is dedicated to the Explainability module, which is responsible for generating the analysis of the doctors' proposals and the collaborative discussion between the medical team members for the final decision on the patient's treatment.")

    model_name = st.selectbox("Select the LLM model", options=["meta/llama3-8b-instruct"])

    explanation = st.button("Generate explanation")
    if not explanation:
        st.stop()

    col1, col2 = st.columns([0.5, 0.6], gap="large")

    with col1:
        with open("streamlit_results/medical_scenario.txt", "r") as f:
            medical_scenario = f.read()
        st.subheader("📄 Medical Scenario")
        st.caption(f"The scenario shown as follows for the patient in the **hospital admission n°: {last_visit_id}** is provided by the medical team.")
        st.markdown('###')
        with st.expander("👁️ Read the medical scenario", expanded=True):
            container = st.container(height=1145)
            container.write(medical_scenario)

    with col2:
        st.subheader("👨‍⚕️🔎 Doctor Recruiter")
        st.caption("The doctor recruiter is responsible for recruiting the medical team to help the internist doctor make a final decision on the patient's during the collaborative discussion.")
        with st.status("Recruiting doctor...", expanded=False) as status:
            with open("streamlit_results/prompt_recruiter_doctors.txt", "r") as f:
                prompt_recruiter_doctors = f.read()
            text = doctor_recruiter(prompt_recruiter_doctors, model_name)
            if model_name == "meta/llama3-8b-instruct":
                # Drop the preamble this model puts before the JSON payload;
                # the reply is kept unchanged when the preamble is missing.
                text[0] = _strip_reply_preamble(text[0])
            json_data = json.loads(str(text[0]))
            # NOTE(review): this dumps the raw reply string (not json_data), so
            # the file holds a JSON-encoded string; confirm that is intended.
            with open("streamlit_results/recruited_doctors.json", "w") as f:
                json.dump(text[0], f, indent=4)

            doctors = json_data['doctors']
            for i, doctor in enumerate(doctors):
                role_label = doctor['role'].replace("_", " ")
                role = f"""**🥼 {role_label}**"""
                st.markdown(role)
                st.write(doctor['description'])
                if i != len(doctors)-1:
                    st.divider()

            status.update(label="Doctor recruited!", state="complete", expanded=True)
        st.button('Rerun')

    st.subheader("Analysis Proposition")
    with st.spinner("Doctors are thinking..."):
        with open("streamlit_results/prompt_internist_doctor.txt", "r") as f:
            prompt_internist_doctor = f.read()

        prompt_reunion = f"""Based on your assessment and the medical team's recommendations regarding {task} during the patient visit:\n"""
        prompt_reunion += f"""Confront with your medical colleagues, highlighting relevant aspects related to the patient's condition and the {task}. Underline the crucial elements that influence your decision on its justification or unjustification in 30 words.\n"""
        prompt_reunion += """\nAnalysis of doctors' proposals\n\n"""

        # Collect every recruited doctor's analysis into the discussion prompt.
        for i, doctor in enumerate(json_data['doctors']):
            role_label = doctor['role'].replace('_', ' ')
            agent_name = doctor['role'].replace(" ", "_")
            with st.status(f"The 👨‍⚕️ {role_label} is analysing ...", expanded=False) as status_doc:
                with st.chat_message(name="user", avatar=f"streamlit_images/{i}.png"):
                    analysis = f"""**Doctor**: {agent_name}\n\n"""
                    text = doctor_discussion(doctor['role'], prompt_internist_doctor, model_name)
                    analysis += "**Analysis**: " + text[0]
                    st.markdown(f"**Analysis**: {text[0]}")
                    status_doc.update(label="The 👨‍⚕️ {} analysed!".format(role_label), state="complete", expanded=True)
                    prompt_reunion += analysis
                    prompt_reunion += "\n--------------------------------------------------\n\n"

    image_col, title_col = st.columns([0.2, 0.8])
    with image_col:
        st.image("streamlit_images/collaborative.png")
    with title_col:
        st.subheader('Discussion')

    st.caption("The discussion shown as follows is based on the **Large Language Model** (LLM) **chosen**. The LLM is responsible for generating the discussion between the medical team members for the final decision on the patient's treatment.")
    with st.spinner("Doctors are discussing..."):
        internist_sys_message = """As an INTERNIST DOCTOR, you have the task of globally evaluating and managing the patient's health and pathology.\n"""
        internist_sys_message += """In the light of the entire discussion, you must provide a final schematic report to the doctor based on the recommendation and the doctors' opinions."""

        doc = multiagent_doctors(json_data, model_name)
        manager = care_discussion_start(doc, prompt_reunion, internist_sys_message, model_name)

        with st.chat_message(name="user", avatar="streamlit_images/internist.png"):
            internist = list(manager.chat_messages.values())
            # NOTE(review): the final report is read from a hard-coded message
            # position (index 6 of the first conversation); confirm it always exists.
            internist_message = internist[0][6]
            internist_opinion = internist_message['content']
            st.write(f"**{internist_message['name'].replace('_',' ')}**: {internist_opinion}")

        # Add a download button:
        st.download_button(
            label="Download PDF", 
            data=gen_pdf(patient, name, lastname, last_visit_id, list_output, medical_scenario, internist_opinion), 
            file_name=f"Medical_Report_Patient_{patient}.pdf", 
            mime="application/pdf", 
        )


# Build the dashboard only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()