Spaces:
Running
Running
Commit
·
38d6cb6
1
Parent(s):
34a939d
Add PIE-Med Demo
Browse files- .gitattributes +2 -0
- .gitignore +1 -0
- PIE-Med.png +3 -0
- README.md +79 -1
- chat_utils.py +120 -0
- css/style.css +34 -0
- dashboard.py +470 -0
- font/OpenSans-Bold.ttf +0 -0
- font/OpenSans-Italic.ttf +0 -0
- font/OpenSans.ttf +0 -0
- model/diagnosis_prediction/best.ckpt +3 -0
- model/diagnosis_prediction/last.ckpt +3 -0
- model/medication_recommendation/best.ckpt +3 -0
- model/medication_recommendation/last.ckpt +3 -0
- model_utils.py +212 -0
- requirements.txt +0 -0
- static-kg/ANAT_DIAG.csv +3 -0
- static-kg/ATC3/DRUG_DIAG.csv +3 -0
- static-kg/ATC3/PC_DRUG.csv +3 -0
- static-kg/ATC3/SYMP_DRUG.csv +3 -0
- static-kg/DIAG_SYMP.csv +3 -0
- static-kg/DRUG_DIAG.csv +3 -0
- static-kg/PC_DRUG.csv +3 -0
- static-kg/SYMP_DRUG.csv +3 -0
- streamlit_images/0.png +3 -0
- streamlit_images/1.png +3 -0
- streamlit_images/2.png +3 -0
- streamlit_images/3.png +3 -0
- streamlit_images/4.png +3 -0
- streamlit_images/Internist.png +3 -0
- streamlit_images/collaborative.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
streamlit_results/
|
PIE-Med.png
ADDED
![]() |
Git LFS Details
|
README.md
CHANGED
@@ -11,4 +11,82 @@ license: cc-by-sa-4.0
|
|
11 |
short_description: '🩺 PIE-Med: Predicting, Interpreting and Explaining Medical'
|
12 |
---
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
short_description: '🩺 PIE-Med: Predicting, Interpreting and Explaining Medical'
|
12 |
---
|
13 |
|
14 |
+
# 🩺 *PIE*-Med: *P*redicting, *I*nterpreting and *E*xplaining Medical Recommendations
|
15 |
+
|
16 |
+
Welcome to the repository for **PIE-Med**, a cutting-edge system designed to enhance medical decision-making through the integration of Graph Neural Networks (GNNs), Explainable AI (XAI) techniques, and Large Language Models (LLMs).
|
17 |
+
|
18 |
+
## 🎥 Demo (or GIF)
|
19 |
+
[Watch our demo](https://drive.google.com/file/d/1e9VXslnBzOOp5QHh4GTrT-La1PdKxhzS/preview) to see PIE-Med in action and learn how it can transform healthcare recommendations!
|
20 |
+
|
21 |
+
## 📊 Data Source
|
22 |
+
We use the **[MIMIC-III](https://mimic.physionet.org/)** dataset, a freely accessible critical care database containing de-identified health information, including vital signs, laboratory test results, medications, and more. You can find more details about the dataset here:
|
23 |
+
|
24 |
+
## 🛠 Technologies Used
|
25 |
+
- **Python**: Core programming language
|
26 |
+
- **Pandas**: Data manipulation and analysis ([Pandas Documentation](https://pandas.pydata.org/))
|
27 |
+
- **PyHealth**: Medical data preprocessing ([PyHealth Documentation](https://pyhealth.readthedocs.io/en/latest/))
|
28 |
+
- **PyTorch Geometric**: Building and training GNNs ([PyTorch Geometric Documentation](https://pytorch-geometric.readthedocs.io/en/latest/))
|
29 |
+
- **Integrated Gradients & GNNExplainer**: Interpretability techniques ([PyTorch Geometric Documentation](https://pytorch-geometric.readthedocs.io/en/latest/))
|
30 |
+
- **Streamlit**: User interface development ([Streamlit Documentation](https://streamlit.io/))
|
31 |
+
- **Py AutoGen Multi-Agent Conversation Framework**: Multi-agent collaboration and explanation ([Py AutoGen Documentation](https://microsoft.github.io/autogen/))
|
32 |
+
|
33 |
+
The PIE-Med system's computational requirements depend on the configuration used. For resource-limited environments, the light configuration with an Intel i7 CPU and 16GB RAM offers a basic but functional setup, suitable for testing on small datasets. However, more demanding tasks, such as working with larger datasets or leveraging advanced machine learning techniques (e.g., Graph Neural Networks), benefit from cloud setups like the complete configuration, which includes a GPU (NVIDIA Tesla T4). In resource-constrained contexts, optimizing models and reducing dataset size would be crucial to ensure feasible performance.
|
34 |
+
|
35 |
+
## 🔬 Methodological Workflow
|
36 |
+
PIE-Med follows a comprehensive Predict→Interpret→Explain (PIE) paradigm:
|
37 |
+
|
38 |
+
1. **Prediction Phase**: We construct a heterogeneous patient graph from MIMIC-III data and apply GNNs to generate personalized medical recommendations.
|
39 |
+
2. **Interpretation Phase**: Integrated Gradients and GNNExplainer techniques are used to provide insights into the GNN's decision-making process.
|
40 |
+
3. **Explanation Phase**: A collaborative ensemble of LLM agents analyzes the model's outputs and generates comprehensive, understandable explanations.
|
41 |
+
|
42 |
+

|
43 |
+
|
44 |
+
## 🌟 Key Features
|
45 |
+
- **Integration of GNNs and LLMs**: Combining structured machine learning with natural language processing for robust recommendations.
|
46 |
+
- **Enhanced Interpretability**: Using XAI techniques to make the decision-making process transparent.
|
47 |
+
- **Collaborative Explanation**: Multi-agent LLMs provide detailed and understandable recommendations.
|
48 |
+
|
49 |
+
## 🚀 Getting Started
|
50 |
+
Follow these steps to set up and run PIE-Med on your local machine:
|
51 |
+
|
52 |
+
### Prerequisites
|
53 |
+
Ensure you have the following installed:
|
54 |
+
- Python 3.7+
|
55 |
+
|
56 |
+
### Installation
|
57 |
+
1. **Clone the repository**:
|
58 |
+
```bash
|
59 |
+
git clone https://github.com/picuslab/PIE-Med.git
|
60 |
+
cd PIE-Med
|
61 |
+
```
|
62 |
+
|
63 |
+
2. **Create a virtual environment**:
|
64 |
+
```bash
|
65 |
+
python -m venv venv
|
66 |
+
source venv/bin/activate # On Windows use `venv\Scripts\activate`
|
67 |
+
```
|
68 |
+
|
69 |
+
3. **Install the required packages**:
|
70 |
+
```bash
|
71 |
+
pip install -r requirements.txt
|
72 |
+
```
|
73 |
+
|
74 |
+
### Running the Application
|
75 |
+
1. **Run the Streamlit application**:
|
76 |
+
```bash
|
77 |
+
streamlit run dashboard.py
|
78 |
+
```
|
79 |
+
|
80 |
+
Open your web browser and go to `http://localhost:8501` to interact with the application.
|
81 |
+
|
82 |
+
## 📈 Conclusions
|
83 |
+
PIE-Med showcases the potential of combining GNNs, XAI, and LLMs to improve medical recommendations, enhancing both accuracy and interpretability. Our system effectively separates prediction from explanation, reducing biases and enhancing decision quality.
|
84 |
+
|
85 |
+
## ⚖ Ethical considerations
|
86 |
+
|
87 |
+
**PIE-Med** aims to support medical decision-making, but is not a substitute for professional medical advice. Users should confirm recommendations with authorised healthcare providers, as limitations of AI may affect accuracy. The system ensures transparency through interpretability techniques, but all results should be considered complementary to expert advice. **⚠️ Please note that the following repository is only a DEMO, with anonymised data used for illustrative purposes only**.
|
88 |
+
|
89 |
+
## 🙏 Acknowledgments
|
90 |
+
We extend our gratitude to the creators of the MIMIC-III database, the developers of the Python libraries used, and our research team for their contributions to this project.
|
91 |
+
|
92 |
+
👨💻 This project was developed by Antonio Romano, Giuseppe Riccio, Marco Postiglione and Vincenzo Moscato
|
chat_utils.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import autogen
|
3 |
+
|
4 |
+
from autogen import OpenAIWrapper, AssistantAgent, UserProxyAgent
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
|
8 |
+
def setting_config(model_name: str) -> tuple[List[dict], dict]:
|
9 |
+
if model_name.startswith("gpt"):
|
10 |
+
config_list = [
|
11 |
+
{
|
12 |
+
"model": model_name,
|
13 |
+
"base_url": "https://api.openai.com/v1",
|
14 |
+
"api_key": st.secrets.openai,
|
15 |
+
}
|
16 |
+
]
|
17 |
+
else:
|
18 |
+
config_list = [
|
19 |
+
{
|
20 |
+
"model": model_name, # "google/gemma-2-9b-it",
|
21 |
+
"base_url": "https://integrate.api.nvidia.com/v1",
|
22 |
+
"api_key": st.secrets.nvidia,
|
23 |
+
"max_tokens": 1000,
|
24 |
+
}
|
25 |
+
]
|
26 |
+
|
27 |
+
llm_config={
|
28 |
+
"timeout": 500,
|
29 |
+
"seed": 42,
|
30 |
+
"config_list": config_list,
|
31 |
+
"temperature": 0.5,
|
32 |
+
}
|
33 |
+
|
34 |
+
return config_list, llm_config
|
35 |
+
|
36 |
+
|
37 |
+
class TrackableUserProxyAgent(UserProxyAgent):
|
38 |
+
t = 0
|
39 |
+
def _process_received_message(self, message, sender, silent):
|
40 |
+
global t
|
41 |
+
with st.chat_message(sender.name, avatar="streamlit_images/{}.png".format(self.t)):
|
42 |
+
st.write(f"**{message['name'].replace('_',' ')}**: {message['content']}")
|
43 |
+
self.t += 1
|
44 |
+
if self.t == 4:
|
45 |
+
self.t = 0
|
46 |
+
st.divider()
|
47 |
+
|
48 |
+
return super()._process_received_message(message, sender, silent)
|
49 |
+
|
50 |
+
|
51 |
+
def doctor_recruiter(prompt_recruiter_doctors: str, model_name: str) -> str:
|
52 |
+
config_list, llm_config = setting_config(model_name)
|
53 |
+
client = OpenAIWrapper(api_key=config_list[0]['api_key'], config_list=config_list)
|
54 |
+
response = client.create(messages=[{"role": "user", "content": prompt_recruiter_doctors + "\nReturn ONLY JSON file (don't use Markdown tags like: ```json)."}],
|
55 |
+
temperature=0.3,
|
56 |
+
seed=42,
|
57 |
+
model=config_list[0]['model'])
|
58 |
+
text = client.extract_text_or_completion_object(response)
|
59 |
+
|
60 |
+
return text
|
61 |
+
|
62 |
+
|
63 |
+
def doctor_discussion(doctor_description: str, prompt_internist_doctor: str, model_name: str) -> str:
|
64 |
+
config_list, llm_config = setting_config(model_name)
|
65 |
+
doctor = OpenAIWrapper(api_key=config_list[0]['api_key'], config_list=config_list)
|
66 |
+
response = doctor.create(messages=[
|
67 |
+
{"role": "assistant",
|
68 |
+
"content": doctor_description},
|
69 |
+
{"role": "user",
|
70 |
+
"content": prompt_internist_doctor + "\nDon't use Markdown tags."}
|
71 |
+
],
|
72 |
+
temperature=0.3,
|
73 |
+
model=config_list[0]['model'])
|
74 |
+
text = doctor.extract_text_or_completion_object(response)
|
75 |
+
|
76 |
+
return text
|
77 |
+
|
78 |
+
|
79 |
+
def multiagent_doctors(json_data: dict, model_name: str) -> List[AssistantAgent]:
|
80 |
+
config_list, llm_config = setting_config(model_name)
|
81 |
+
doc = []
|
82 |
+
for i in range(len(json_data['doctors'])):
|
83 |
+
doc.append(AssistantAgent(
|
84 |
+
name=json_data['doctors'][i]['role'].replace(" ", "_"),
|
85 |
+
llm_config=llm_config,
|
86 |
+
system_message="As a " + json_data['doctors'][i]['role'].replace(" ", "_") + ". Discuss with other medical experts in the team to help the INTERNIST DOCTOR make a final decision. Avoid postponing further examinations and repeating opinions given in the analysis, but explain in a logical and concise manner why you are making this final decision."))
|
87 |
+
|
88 |
+
return doc
|
89 |
+
|
90 |
+
|
91 |
+
def care_discussion_start(doc: List[AssistantAgent], prompt_reunion: str, internist_sys_message: str, model_name: str) -> autogen.GroupChatManager:
|
92 |
+
config_list, llm_config = setting_config(model_name)
|
93 |
+
doc.append(TrackableUserProxyAgent(
|
94 |
+
name="internist_doctor",
|
95 |
+
human_input_mode="NEVER",
|
96 |
+
max_consecutive_auto_reply=1,
|
97 |
+
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith(("JUSTIFIABLE", "UNJUSTIFIABLE")),
|
98 |
+
code_execution_config=False,
|
99 |
+
llm_config=llm_config,
|
100 |
+
system_message=internist_sys_message
|
101 |
+
))
|
102 |
+
|
103 |
+
groupchat = autogen.GroupChat(agents=doc,
|
104 |
+
messages=[],
|
105 |
+
max_round=(len(doc)+1),
|
106 |
+
speaker_selection_method="round_robin",
|
107 |
+
role_for_select_speaker_messages='user',
|
108 |
+
)
|
109 |
+
|
110 |
+
manager = autogen.GroupChatManager(groupchat=groupchat,
|
111 |
+
llm_config=llm_config,
|
112 |
+
max_consecutive_auto_reply=1
|
113 |
+
)
|
114 |
+
|
115 |
+
doc[-1].initiate_chat(
|
116 |
+
manager,
|
117 |
+
message=prompt_reunion,
|
118 |
+
)
|
119 |
+
|
120 |
+
return manager
|
css/style.css
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<style>
|
2 |
+
div[data-testid="stToolbar"] {
|
3 |
+
visibility: hidden;
|
4 |
+
height: 0%;
|
5 |
+
position: fixed;
|
6 |
+
}
|
7 |
+
|
8 |
+
div[data-testid="stDecoration"] {
|
9 |
+
visibility: hidden;
|
10 |
+
height: 0%;
|
11 |
+
position: fixed;
|
12 |
+
}
|
13 |
+
|
14 |
+
div[data-testid="stStatusWidget"] {
|
15 |
+
visibility: hidden;
|
16 |
+
height: 0%;
|
17 |
+
position: fixed;
|
18 |
+
}
|
19 |
+
|
20 |
+
#MainMenu {
|
21 |
+
visibility: hidden;
|
22 |
+
height: 0%;
|
23 |
+
}
|
24 |
+
|
25 |
+
header {
|
26 |
+
visibility: hidden;
|
27 |
+
height: 0%;
|
28 |
+
}
|
29 |
+
|
30 |
+
footer {
|
31 |
+
visibility: hidden;
|
32 |
+
height: 0%;
|
33 |
+
}
|
34 |
+
</style>
|
dashboard.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from chat_utils import *
|
2 |
+
from model_utils import *
|
3 |
+
|
4 |
+
import json
|
5 |
+
import shutil
|
6 |
+
from faker import Faker
|
7 |
+
|
8 |
+
|
9 |
+
PATH_MED = "model/medication_recommendation/best.ckpt"
|
10 |
+
PATH_DIAG = "model/diagnosis_prediction/best.ckpt"
|
11 |
+
#shutil.rmtree(".cache/", ignore_errors=True)
|
12 |
+
|
13 |
+
|
14 |
+
def main():
|
15 |
+
|
16 |
+
# ---- SETTINGS PAGE ----
|
17 |
+
st.set_page_config(page_title="PIE-Med - Dashboard", page_icon="🩺", layout="wide")
|
18 |
+
|
19 |
+
# with open('css/style.css') as f:
|
20 |
+
# hide_streamlit_style = f.read()
|
21 |
+
# st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|
22 |
+
|
23 |
+
# ---- SESSION STATE ----
|
24 |
+
if 'patient' not in st.session_state:
|
25 |
+
st.session_state.patient = None
|
26 |
+
if 'name' not in st.session_state:
|
27 |
+
st.session_state.name = None
|
28 |
+
if 'lastname' not in st.session_state:
|
29 |
+
st.session_state.lastname = None
|
30 |
+
if 'gender_sign' not in st.session_state:
|
31 |
+
st.session_state.gender_sign = None
|
32 |
+
|
33 |
+
# ---- SIDE BAR ----
|
34 |
+
# st.sidebar.image(".\streamlit_images\logo_icon.png")
|
35 |
+
# st.sidebar.divider()
|
36 |
+
|
37 |
+
# ---- MAIN PAGE ----
|
38 |
+
st.title(":rainbow[PIE-Med]")
|
39 |
+
st.markdown("Welcome to PIE-Med 🩺!")
|
40 |
+
|
41 |
+
desc = st.empty()
|
42 |
+
desc1 = st.empty()
|
43 |
+
desc.caption("**PIE-Med** 🩺, a cutting-edge system designed to enhance medical decision-making through \
|
44 |
+
the integration of **Graph Neural Networks (GNNs)** ⚙️, **eXplainable AI (XAI)** ❓ techniques, \
|
45 |
+
and **Large Language Models (LLMs)** 🧠.")
|
46 |
+
desc1.caption("**⏳ WAIT MINUTES FOR THE LOADING OF THE MODELS AND THE DATASET**")
|
47 |
+
|
48 |
+
model_med_ig, model_med_gnn, model_diag_ig, model_diag_gnn, \
|
49 |
+
dataset, mimic3sample_med, mimic3sample_diag = load_gnn()
|
50 |
+
checkpoint_MED = torch.load(PATH_MED)
|
51 |
+
checkpoint_DIAG = torch.load(PATH_DIAG)
|
52 |
+
|
53 |
+
desc1.empty()
|
54 |
+
|
55 |
+
fake = Faker()
|
56 |
+
|
57 |
+
selected_patient = None
|
58 |
+
if selected_patient is None:
|
59 |
+
placeholder2 = st.empty()
|
60 |
+
with placeholder2.expander("⚠️ **Before using the framework, read the disclaimer for the use of Framework**"):
|
61 |
+
disclaimer = f"""
|
62 |
+
|
63 |
+
The use of our Healthcare framework based on MIMIC III (https://physionet.org/content/mimiciii/1.4/) is subject to the terms and warnings as follows:
|
64 |
+
|
65 |
+
**Research and Decision Support Purpose:** Our framework has been developed primarily for research and decision support in the healthcare context. The information and recommendations generated should not replace the professional judgment of qualified healthcare practitioners but may be utilized as support for the final decision by the doctor or the directly involved party.
|
66 |
+
|
67 |
+
**Data Origin:** The processed healthcare data originates from the MIMIC III database and undergoes enrichment and modeling through the application of Heterogeneous Graph Neural Network. It is important to note that the original data may contain variations and limitations, and the accuracy of the processed information depends on the quality of the input data.
|
68 |
+
|
69 |
+
**Medical Recommendations:** The drug and diagnosis recommendations generated by the framework are hypothetical and based on Graph Neural Network learning models. These should not be considered definitive prescriptions, and the final decision regarding patient treatment should be made by a qualified medical professional.
|
70 |
+
|
71 |
+
**Human Readable Explanations:** The embedded explainability system in the framework utilizes graph explainability models and Large Language Models (LLM) to generate understandable explanations for end-users, such as physicians. However, these explanations are interpretations of the model results and may not fully reflect the complexity of medical reasoning.
|
72 |
+
|
73 |
+
**Framework Limitations:** Our framework has intrinsic limitations, including those related to the quality of input data, the characteristics of the machine learning model, and the dynamics of the healthcare context. Users are encouraged to exercise caution in interpreting the provided information.
|
74 |
+
|
75 |
+
**User Responsibility:** Users accessing and utilizing our framework are responsible for the accurate interpretation of the provided information and for making appropriate decisions based on their clinical judgment. The creators assume no responsibility for any consequences arising from improper use or misinterpretation of the information generated by the framework.
|
76 |
+
|
77 |
+
By using our healthcare data processing framework, the user agrees to comply with these conditions. The continuous evolution of the fields of medicine and technology may necessitate periodic updates to this disclaimer.
|
78 |
+
|
79 |
+
"""
|
80 |
+
|
81 |
+
st.subheader("Disclaimer")
|
82 |
+
st.info(disclaimer)
|
83 |
+
agree = st.checkbox("I accept and have read the disclaimer!")
|
84 |
+
placeholder1 = st.empty()
|
85 |
+
placeholder1.warning("You must accept the disclaimer to use the framework!", icon="⚠️")
|
86 |
+
|
87 |
+
if not(agree):
|
88 |
+
st.stop()
|
89 |
+
|
90 |
+
placeholder1.empty()
|
91 |
+
placeholder2.info("You can now use the framework! 🎉 Please select the task and select a patient! 🩺")
|
92 |
+
task = st.sidebar.selectbox(label='Select __task__: ', index=None, placeholder="Select type of task", options=['medications', 'diagnosis'])
|
93 |
+
|
94 |
+
if task is None:
|
95 |
+
st.stop()
|
96 |
+
elif task == "medications":
|
97 |
+
mimic3sample = mimic3sample_med
|
98 |
+
elif task == "diagnosis":
|
99 |
+
mimic3sample = mimic3sample_diag
|
100 |
+
|
101 |
+
mimic_df = pd.DataFrame(mimic3sample.samples)
|
102 |
+
|
103 |
+
selected_patient = st.sidebar.selectbox(label='Select __patient__ n°: ', index=None, placeholder="Select a patient", options=mimic_df['patient_id'].unique())
|
104 |
+
while selected_patient is None:
|
105 |
+
st.stop()
|
106 |
+
|
107 |
+
desc.empty()
|
108 |
+
placeholder2.empty()
|
109 |
+
|
110 |
+
patient_dict = dataset.patients
|
111 |
+
patient_info = patient_dict[selected_patient]
|
112 |
+
gender = patient_info.gender
|
113 |
+
|
114 |
+
if selected_patient != st.session_state.patient:
|
115 |
+
if gender == "M":
|
116 |
+
first_name = fake.first_name_male()
|
117 |
+
last_name = fake.last_name_male()
|
118 |
+
gender_sign = "male_sign"
|
119 |
+
elif gender == "F":
|
120 |
+
first_name = fake.first_name_female()
|
121 |
+
last_name = fake.last_name_female()
|
122 |
+
gender_sign = "female_sign"
|
123 |
+
else:
|
124 |
+
first_name = "Name"
|
125 |
+
last_name = "Unknown"
|
126 |
+
|
127 |
+
st.session_state.patient = selected_patient
|
128 |
+
st.session_state.name = ":blue[" + first_name + "]"
|
129 |
+
st.session_state.lastname = last_name
|
130 |
+
st.session_state.gender_sign = gender_sign
|
131 |
+
|
132 |
+
patient = st.session_state.patient
|
133 |
+
name = st.session_state.name
|
134 |
+
lastname = st.session_state.lastname
|
135 |
+
gender_sign = st.session_state.gender_sign
|
136 |
+
|
137 |
+
mimic_df_patient = mimic_df[mimic_df['patient_id'] == selected_patient] # select all the rows with the selected patient
|
138 |
+
|
139 |
+
for i in range(len(mimic_df_patient)):
|
140 |
+
if i == len(mimic_df_patient) - 1:
|
141 |
+
last_visit = mimic_df_patient.iloc[[i]]
|
142 |
+
|
143 |
+
# ---- Patient info ----
|
144 |
+
# st.subheader(":blue[DASHBOARD OF] ")
|
145 |
+
st.warning("🚨 **NOTE** 🚨: The patient's name, shown below, was randomly generated for demonstration purposes.")
|
146 |
+
st.title("{} {} :{}:".format(name, lastname, gender_sign))
|
147 |
+
st.caption("Patient n°: {} - Gender: {} - Ethnicity: {}".format(patient, patient_info.gender, patient_info.ethnicity))
|
148 |
+
|
149 |
+
l1, r1 = st.columns([0.44, 0.56])
|
150 |
+
|
151 |
+
with l1:
|
152 |
+
st.subheader("📋 Medical history")
|
153 |
+
# st.caption("The following table shows the *complete* medical history of the patient n°: **{}**.".format(patient))
|
154 |
+
|
155 |
+
visit = st.selectbox(label='🏥 __Hospital admission__ n°: ', options=mimic_df_patient['visit_id'].unique())
|
156 |
+
if visit:
|
157 |
+
mimic_df_patient_visit = mimic_df_patient[mimic_df_patient['visit_id'] == visit] # select all the rows with the selected visit
|
158 |
+
if task == "medications":
|
159 |
+
mimic_df_patient_visit_filtered = mimic_df_patient_visit.drop(columns=['visit_id', 'patient_id', 'drugs_hist'])
|
160 |
+
elif task == "diagnosis":
|
161 |
+
mimic_df_patient_visit_filtered = mimic_df_patient_visit.drop(columns=['visit_id', 'patient_id'])
|
162 |
+
|
163 |
+
atc = InnerMap.load("ATC")
|
164 |
+
icd9 = InnerMap.load("ICD9CM")
|
165 |
+
icd9_proc = InnerMap.load("ICD9PROC")
|
166 |
+
|
167 |
+
for column in mimic_df_patient_visit_filtered.columns:
|
168 |
+
with st.expander("{}".format(column)):
|
169 |
+
try:
|
170 |
+
if column == "medications":
|
171 |
+
if task == "medications":
|
172 |
+
med_history = [[med, atc.lookup(med)] for med in mimic_df_patient_visit_filtered[column].explode() if med]
|
173 |
+
elif task == "diagnosis":
|
174 |
+
med_history = [[med, atc.lookup(med)] for med in (mimic_df_patient_visit_filtered[column].explode()).explode() if med]
|
175 |
+
st.dataframe(med_history, hide_index=True, column_config={"0": "ATC", "1": "Description"})
|
176 |
+
elif column == "diagnosis":
|
177 |
+
if task == "medications":
|
178 |
+
col_history = [[idx, icd9.lookup(idx)] for idx in (mimic_df_patient_visit_filtered[column].explode()).explode() if idx]
|
179 |
+
elif task == "diagnosis":
|
180 |
+
col_history = [[idx+'0', icd9.lookup(idx+'0')] if idx.startswith('E') else [idx, icd9.lookup(idx)] for idx in mimic_df_patient_visit_filtered[column].explode() if idx]
|
181 |
+
st.dataframe(col_history, hide_index=True, column_config={"0": "ICD9", "1": "Description"})
|
182 |
+
elif column == "symptoms":
|
183 |
+
col_history = [[idx, icd9.lookup(idx)] for idx in (mimic_df_patient_visit_filtered[column].explode()).explode() if idx]
|
184 |
+
st.dataframe(col_history, hide_index=True, column_config={"0": "ICD9", "1": "Description"})
|
185 |
+
elif column == "procedures":
|
186 |
+
col_history = [[idx, icd9_proc.lookup(idx)] for idx in (mimic_df_patient_visit_filtered[column].explode()).explode() if idx]
|
187 |
+
st.dataframe(col_history, hide_index=True, column_config={"0": "ICD9", "1": "Description"})
|
188 |
+
except:
|
189 |
+
st.write("No data available for this column.")
|
190 |
+
|
191 |
+
st.subheader(f"🧾 Recommended _{task}_")
|
192 |
+
st.caption(f"""The following {task} are recommended for the patient during the **hospital admission n°: \
|
193 |
+
{format(last_visit['visit_id'].item())}**. \n The recommendations are based on the \
|
194 |
+
output probabilities generated by the **GNN (_Graph Neural Network_)** model.""")
|
195 |
+
|
196 |
+
if task == "medications":
|
197 |
+
model_med_ig.load_state_dict(checkpoint_MED)
|
198 |
+
model_med_gnn.load_state_dict(checkpoint_MED)
|
199 |
+
model = model_med_ig
|
200 |
+
elif task == "diagnosis":
|
201 |
+
model_diag_ig.load_state_dict(checkpoint_DIAG)
|
202 |
+
model_diag_gnn.load_state_dict(checkpoint_DIAG)
|
203 |
+
model = model_diag_ig
|
204 |
+
|
205 |
+
# ---- Output model ----
|
206 |
+
model.eval()
|
207 |
+
output = model(last_visit['patient_id'],
|
208 |
+
last_visit['visit_id'],
|
209 |
+
last_visit['diagnosis'],
|
210 |
+
last_visit['procedures'],
|
211 |
+
last_visit['symptoms'],
|
212 |
+
last_visit['medications'])
|
213 |
+
|
214 |
+
list_output, list_indices = get_list_output(output['y_prob'], last_visit, task, mimic3sample)
|
215 |
+
list_output = [[idx, item] for idx, item in zip(*list_indices, *list_output) if item]
|
216 |
+
st.dataframe(list_output, column_config={"0": "ID", "1": f"Recommended {task}"}, height=None, width=None)
|
217 |
+
|
218 |
+
with r1:
|
219 |
+
st.subheader(f"""🗣 *Why* did the model recommend these {task}?""")
|
220 |
+
r1l1, r1c1, r1r1 = st.columns(3)
|
221 |
+
with r1l1:
|
222 |
+
visualization = st.radio("Visualization", options=["Explainable", "Interpretable"], horizontal=True)
|
223 |
+
with r1c1:
|
224 |
+
algorithm = st.radio("Algorithm", options=["IG", "GNNExplainer"], horizontal=True)
|
225 |
+
with r1r1:
|
226 |
+
threshold = st.slider("Threshold", min_value=10, max_value=50, value=15, step=5, format=None, key=None)
|
227 |
+
|
228 |
+
if task == "medications" and algorithm == "IG":
|
229 |
+
model = model_med_ig
|
230 |
+
elif task == "medications" and algorithm == "GNNExplainer":
|
231 |
+
model = model_med_gnn
|
232 |
+
elif task == "diagnosis" and algorithm == "IG":
|
233 |
+
model = model_diag_ig
|
234 |
+
elif task == "diagnosis" and algorithm == "GNNExplainer":
|
235 |
+
model = model_diag_gnn
|
236 |
+
|
237 |
+
st.caption(f"""The graph shown as follows provides an interpretation of the model's decision making process on the recommended \
|
238 |
+
*{task}* for the patient during the **hospital admission n°: {format(last_visit['visit_id'].item())}**. \
|
239 |
+
\n\n The interpretability is based on the **{algorithm} (_{task}_)** algorithm.""")
|
240 |
+
options = [item[1] for item in list_output if item]
|
241 |
+
selected_label = st.selectbox(f'Select the {task} to explain', index=None,
|
242 |
+
placeholder=f"Choice a {task} from Recommended {task} ranking to explain",
|
243 |
+
options=options)
|
244 |
+
|
245 |
+
if selected_label is None:
|
246 |
+
st.stop()
|
247 |
+
|
248 |
+
selected_idx = [item[0] for item in list_output if item[1] == selected_label]
|
249 |
+
|
250 |
+
st.caption("Legend of the graph:")
|
251 |
+
col1, col2, col3, col4, col5, col6, col7, col8 = st.columns([0.1, 0.3, 0.1, 0.3, 0.1, 0.3, 0.1, 0.3])
|
252 |
+
|
253 |
+
with col1:
|
254 |
+
st.markdown(
|
255 |
+
"""
|
256 |
+
<style>
|
257 |
+
#square1 {
|
258 |
+
width: 20px;
|
259 |
+
height: 20px;
|
260 |
+
background: #20b2aa;
|
261 |
+
border-radius: 3px;
|
262 |
+
}
|
263 |
+
</style>
|
264 |
+
<div id="square1"></div>
|
265 |
+
""",
|
266 |
+
unsafe_allow_html=True,
|
267 |
+
)
|
268 |
+
|
269 |
+
st.markdown(
|
270 |
+
"""
|
271 |
+
<style>
|
272 |
+
#square2 {
|
273 |
+
width: 20px;
|
274 |
+
height: 20px;
|
275 |
+
background: #fa8072;
|
276 |
+
border-radius: 3px;
|
277 |
+
margin-top: 20px;
|
278 |
+
}
|
279 |
+
</style>
|
280 |
+
<div id="square2"></div>
|
281 |
+
""",
|
282 |
+
unsafe_allow_html=True,
|
283 |
+
)
|
284 |
+
|
285 |
+
with col2:
|
286 |
+
st.caption("Patient")
|
287 |
+
|
288 |
+
st.caption("Visit")
|
289 |
+
|
290 |
+
with col3:
|
291 |
+
st.markdown(
|
292 |
+
"""
|
293 |
+
<style>
|
294 |
+
#square3 {
|
295 |
+
width: 20px;
|
296 |
+
height: 20px;
|
297 |
+
background: #cd853f;
|
298 |
+
border-radius: 3px;
|
299 |
+
}
|
300 |
+
</style>
|
301 |
+
<div id="square3"></div>
|
302 |
+
""",
|
303 |
+
unsafe_allow_html=True,
|
304 |
+
)
|
305 |
+
st.markdown(
|
306 |
+
"""
|
307 |
+
<style>
|
308 |
+
#square4 {
|
309 |
+
width: 20px;
|
310 |
+
height: 20px;
|
311 |
+
background: #da70d6;
|
312 |
+
border-radius: 3px;
|
313 |
+
margin-top: 20px;
|
314 |
+
}
|
315 |
+
</style>
|
316 |
+
<div id="square4"></div>
|
317 |
+
""",
|
318 |
+
unsafe_allow_html=True,
|
319 |
+
)
|
320 |
+
|
321 |
+
with col4:
|
322 |
+
st.caption("Diagnosis")
|
323 |
+
st.caption("Procedures")
|
324 |
+
|
325 |
+
with col5:
|
326 |
+
st.markdown(
|
327 |
+
"""
|
328 |
+
<style>
|
329 |
+
#square5 {
|
330 |
+
width: 20px;
|
331 |
+
height: 20px;
|
332 |
+
background: #98fb98;
|
333 |
+
border-radius: 3px;
|
334 |
+
}
|
335 |
+
</style>
|
336 |
+
<div id="square5"></div>
|
337 |
+
""",
|
338 |
+
unsafe_allow_html=True,
|
339 |
+
)
|
340 |
+
|
341 |
+
with col6:
|
342 |
+
st.caption("Symptoms")
|
343 |
+
|
344 |
+
with col7:
|
345 |
+
st.markdown(
|
346 |
+
"""
|
347 |
+
<style>
|
348 |
+
#square6 {
|
349 |
+
width: 20px;
|
350 |
+
height: 20px;
|
351 |
+
background: #87ceeb;
|
352 |
+
border-radius: 3px;
|
353 |
+
}
|
354 |
+
</style>
|
355 |
+
<div id="square6"></div>
|
356 |
+
""",
|
357 |
+
unsafe_allow_html=True,
|
358 |
+
)
|
359 |
+
|
360 |
+
with col8:
|
361 |
+
st.caption("Medications")
|
362 |
+
|
363 |
+
explain_sample = {}
|
364 |
+
for visit_sample in mimic3sample.samples:
|
365 |
+
if visit_sample['patient_id'] == patient and visit_sample['visit_id'] == last_visit['visit_id'].item():
|
366 |
+
if visit_sample.get('drugs_hist') != None:
|
367 |
+
del visit_sample['drugs_hist']
|
368 |
+
explain_sample['test'] = visit_sample
|
369 |
+
|
370 |
+
model.eval()
|
371 |
+
explain_dataset = SampleEHRDataset(list(explain_sample.values()), code_vocs="ATC")
|
372 |
+
explainability(model, explain_dataset, selected_idx[0], visualization, algorithm, task, threshold)
|
373 |
+
|
374 |
+
|
375 |
+
####################### CARE AI module ##################################
|
376 |
+
st.header('🩺🧠 Medical Agents Evaluation')
|
377 |
+
st.caption("The section shown as follows is dedicated to the Explainability module, which is responsible for generating the analysis of the doctors' proposals and the collaborative discussion between the medical team members for the final decision on the patient's treatment.")
|
378 |
+
|
379 |
+
model_name = st.selectbox("Select the LLM model", options=["gpt-4o-mini", "gpt-3.5-turbo", "google/gemma-2-9b-it", "meta/llama3-8b-instruct", "mistralai/mistral-7b-instruct-v0.2"])
|
380 |
+
|
381 |
+
explanation = st.button("Generate explanation")
|
382 |
+
if not(explanation):
|
383 |
+
st.stop()
|
384 |
+
|
385 |
+
col1, col2 = st.columns([0.5, 0.6], gap="large")
|
386 |
+
|
387 |
+
with col1:
|
388 |
+
with open("streamlit_results/medical_scenario.txt", "r") as f:
|
389 |
+
medical_scenario = f.read()
|
390 |
+
st.subheader("📄 Medical Scenario")
|
391 |
+
st.caption(f"The scenario shown as follows for the patient in the **hospital admission n°: {format(last_visit['visit_id'].item())}** is provided by the medical team.")
|
392 |
+
st.markdown('###')
|
393 |
+
with st.expander("👁️ Read the medical scenario", expanded=True):
|
394 |
+
container = st.container(height=1145)
|
395 |
+
container.write(medical_scenario)
|
396 |
+
|
397 |
+
with col2:
|
398 |
+
st.subheader("👨⚕️🔎 Doctor Recruiter")
|
399 |
+
st.caption("The doctor recruiter is responsible for recruiting the medical team to help the internist doctor make a final decision on the patient's during the collaborative discussion.")
|
400 |
+
with st.status("Recruiting doctor...", expanded=False) as status:
|
401 |
+
with open("streamlit_results/prompt_recruiter_doctors.txt", "r") as f:
|
402 |
+
prompt_recruiter_doctors = f.read()
|
403 |
+
text = doctor_recruiter(prompt_recruiter_doctors, model_name)
|
404 |
+
if model_name == "meta/llama3-8b-instruct":
|
405 |
+
text[0] = text[0].split("Here is the JSON file:\n\n")[1]
|
406 |
+
json_data = json.loads(str(text[0]))
|
407 |
+
with open("streamlit_results/recruited_doctors.json", "w") as f:
|
408 |
+
json.dump(text[0], f, indent=4)
|
409 |
+
|
410 |
+
for i, doctor in enumerate(json_data['doctors']):
|
411 |
+
role = f"""**🥼 {doctor['role'].replace("_", " ")}**"""
|
412 |
+
st.markdown(role)
|
413 |
+
st.write(doctor['description'])
|
414 |
+
if i != len(json_data['doctors'])-1:
|
415 |
+
st.divider()
|
416 |
+
|
417 |
+
status.update(label="Doctor recruited!", state="complete", expanded=True)
|
418 |
+
st.button('Rerun')
|
419 |
+
|
420 |
+
st.subheader("Analysis Proposition")
|
421 |
+
with st.spinner("Doctors are thinking..."):
|
422 |
+
with open("streamlit_results/prompt_internist_doctor.txt", "r") as f:
|
423 |
+
prompt_internist_doctor = f.read()
|
424 |
+
|
425 |
+
prompt_reunion = f"""Based on your assessment and the medical team's recommendations regarding {task} during the patient visit:\n"""
|
426 |
+
prompt_reunion += f"""Confront with your medical colleagues, highlighting relevant aspects related to the patient's condition and the {task}. Underline the crucial elements that influence your decision on its justification or unjustification in 30 words.\n"""
|
427 |
+
prompt_reunion += f"""\nAnalysis of doctors' proposals\n\n"""
|
428 |
+
|
429 |
+
for i in range(len(json_data['doctors'])):
|
430 |
+
with st.status(f"The 👨⚕️ {json_data['doctors'][i]['role'].replace('_', ' ')} is analysing ...", expanded=False) as status_doc:
|
431 |
+
with st.chat_message(name="user", avatar="streamlit_images/{}.png".format(i)):
|
432 |
+
analysis = """"""
|
433 |
+
analysis += f"""**Doctor**: {json_data['doctors'][i]['role'].replace(" ", "_")}\n\n"""
|
434 |
+
text = doctor_discussion(json_data['doctors'][i]['role'], prompt_internist_doctor, model_name)
|
435 |
+
analysis += "**Analysis**: " + text[0]
|
436 |
+
st.markdown(f"**Analysis**: {text[0]}")
|
437 |
+
status_doc.update(label="The 👨⚕️ {} analysed!".format(json_data['doctors'][i]['role'].replace('_', ' ')), state="complete", expanded=True)
|
438 |
+
prompt_reunion += f"""{analysis}"""
|
439 |
+
prompt_reunion += f"\n--------------------------------------------------\n\n"
|
440 |
+
|
441 |
+
image, text = st.columns([0.2, 0.8])
|
442 |
+
with image:
|
443 |
+
st.image("streamlit_images/collaborative.png")
|
444 |
+
with text:
|
445 |
+
st.subheader('Discussion')
|
446 |
+
|
447 |
+
st.caption("The discussion shown as follows is based on the **Large Language Model** (LLM) **chosen**. The LLM is responsible for generating the discussion between the medical team members for the final decision on the patient's treatment.")
|
448 |
+
with st.spinner("Doctors are discussing..."):
|
449 |
+
internist_sys_message = f"""As an INTERNIST DOCTOR, you have the task of globally evaluating and managing the patient's health and pathology.\n"""
|
450 |
+
internist_sys_message += f"""In the light of the entire discussion, you must provide a final schematic report to the doctor based on the recommendation and the doctors' opinions."""
|
451 |
+
|
452 |
+
doc = multiagent_doctors(json_data, model_name)
|
453 |
+
manager = care_discussion_start(doc, prompt_reunion, internist_sys_message, model_name)
|
454 |
+
|
455 |
+
with st.chat_message(name="user", avatar="streamlit_images/internist.png"):
|
456 |
+
internist = list(manager.chat_messages.values())
|
457 |
+
internist_opinion = internist[0][6]['content']
|
458 |
+
st.write(f"**{internist[0][6]['name'].replace('_',' ')}**: {internist_opinion}")
|
459 |
+
|
460 |
+
# Add a download button:
|
461 |
+
st.download_button(
|
462 |
+
label="Download PDF",
|
463 |
+
data=gen_pdf(patient, name, lastname, last_visit['visit_id'].item(), list_output, medical_scenario, internist_opinion),
|
464 |
+
file_name=f"Medical_Report_Patient_{patient}.pdf",
|
465 |
+
mime="application/pdf",
|
466 |
+
)
|
467 |
+
|
468 |
+
|
469 |
+
if __name__ == "__main__":
|
470 |
+
main()
|
font/OpenSans-Bold.ttf
ADDED
Binary file (131 kB). View file
|
|
font/OpenSans-Italic.ttf
ADDED
Binary file (137 kB). View file
|
|
font/OpenSans.ttf
ADDED
Binary file (131 kB). View file
|
|
model/diagnosis_prediction/best.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5be49c0eb4f53152c76d86203a492b5baea5f81081d710fd2671e051bae2db3e
|
3 |
+
size 6773953
|
model/diagnosis_prediction/last.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bf971bb5a3ac59951dc94d3fb4e82375d7f5f715988fc700d9644d9f1016359b
|
3 |
+
size 6773953
|
model/medication_recommendation/best.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:19ed90e26be751ea72cffd0da6562b799f8ec63e2e34721414aa333b4929992b
|
3 |
+
size 7673537
|
model/medication_recommendation/last.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:602ccd50078f050c87fc20fb0caecbc836f03dd1e74e7d9586dcd4ff7edf0230
|
3 |
+
size 7673537
|
model_utils.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import streamlit.components.v1 as components
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from typing import Tuple, List
|
8 |
+
from fpdf import FPDF
|
9 |
+
|
10 |
+
from pyhealth.medcode import InnerMap
|
11 |
+
from pyhealth.datasets import MIMIC3Dataset, SampleEHRDataset
|
12 |
+
from pyhealth.tasks import medication_recommendation_mimic3_fn, diagnosis_prediction_mimic3_fn
|
13 |
+
from pyhealth.models import GNN
|
14 |
+
from pyhealth.explainer import HeteroGraphExplainer
|
15 |
+
|
16 |
+
|
17 |
+
@st.cache_resource(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
|
18 |
+
def load_gnn() -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.nn.Module,
|
19 |
+
MIMIC3Dataset, SampleEHRDataset, SampleEHRDataset]:
|
20 |
+
dataset = MIMIC3Dataset(
|
21 |
+
root=st.secrets.s3,
|
22 |
+
tables=["DIAGNOSES_ICD","PROCEDURES_ICD","PRESCRIPTIONS","NOTEEVENTS_ICD"],
|
23 |
+
code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 4}})},
|
24 |
+
)
|
25 |
+
|
26 |
+
mimic3sample_med = dataset.set_task(task_fn=medication_recommendation_mimic3_fn)
|
27 |
+
mimic3sample_diag = dataset.set_task(task_fn=diagnosis_prediction_mimic3_fn)
|
28 |
+
|
29 |
+
model_med_ig = GNN(
|
30 |
+
dataset=mimic3sample_med,
|
31 |
+
convlayer="GraphConv",
|
32 |
+
feature_keys=["procedures", "diagnosis", "symptoms"],
|
33 |
+
label_key="medications",
|
34 |
+
k=0,
|
35 |
+
embedding_dim=128,
|
36 |
+
hidden_channels=128
|
37 |
+
)
|
38 |
+
|
39 |
+
model_med_gnn = GNN(
|
40 |
+
dataset=mimic3sample_med,
|
41 |
+
convlayer="GraphConv",
|
42 |
+
feature_keys=["procedures", "diagnosis", "symptoms"],
|
43 |
+
label_key="medications",
|
44 |
+
k=0,
|
45 |
+
embedding_dim=128,
|
46 |
+
hidden_channels=128
|
47 |
+
)
|
48 |
+
|
49 |
+
model_diag_ig = GNN(
|
50 |
+
dataset=mimic3sample_diag,
|
51 |
+
convlayer="GraphConv",
|
52 |
+
feature_keys=["procedures", "medications", "symptoms"],
|
53 |
+
label_key="diagnosis",
|
54 |
+
k=0,
|
55 |
+
embedding_dim=128,
|
56 |
+
hidden_channels=128
|
57 |
+
)
|
58 |
+
|
59 |
+
model_diag_gnn = GNN(
|
60 |
+
dataset=mimic3sample_diag,
|
61 |
+
convlayer="GraphConv",
|
62 |
+
feature_keys=["procedures", "medications", "symptoms"],
|
63 |
+
label_key="diagnosis",
|
64 |
+
k=0,
|
65 |
+
embedding_dim=128,
|
66 |
+
hidden_channels=128
|
67 |
+
)
|
68 |
+
|
69 |
+
return model_med_ig, model_med_gnn, model_diag_ig, model_diag_gnn, dataset, mimic3sample_med, mimic3sample_diag
|
70 |
+
|
71 |
+
|
72 |
+
@st.cache_data(hash_funcs={torch.Tensor: lambda _: None})
|
73 |
+
def get_list_output(y_prob: torch.Tensor, last_visit: pd.DataFrame, task: str, _mimic3sample: SampleEHRDataset,
|
74 |
+
top_k: int = 10) -> List[str]:
|
75 |
+
sorted_indices = []
|
76 |
+
for i in range(len(y_prob)):
|
77 |
+
top_indices = np.argsort(-y_prob[i, :])[:top_k]
|
78 |
+
sorted_indices.append(top_indices)
|
79 |
+
|
80 |
+
list_output = []
|
81 |
+
|
82 |
+
# get the list of all labels in the dataset
|
83 |
+
if task == "medications":
|
84 |
+
list_labels = _mimic3sample.get_all_tokens('medications')
|
85 |
+
atc = InnerMap.load("ATC")
|
86 |
+
elif task == "diagnosis":
|
87 |
+
list_labels = _mimic3sample.get_all_tokens('diagnosis')
|
88 |
+
icd9 = InnerMap.load("ICD9CM")
|
89 |
+
|
90 |
+
sorted_indices = list(sorted_indices)
|
91 |
+
# iterate over the top indexes for each sample in test_ds
|
92 |
+
for (i, sample), top in zip(last_visit.iterrows(), sorted_indices):
|
93 |
+
# create an empty list to store the recommended medications for this sample
|
94 |
+
sample_list_output = []
|
95 |
+
|
96 |
+
# iterate over the top indexes for this sample
|
97 |
+
for k in top:
|
98 |
+
# append the medication at the i-th index to the recommended medications list for this sample
|
99 |
+
if task == "medications":
|
100 |
+
sample_list_output.append(atc.lookup(list_labels[k]))
|
101 |
+
elif task == "diagnosis":
|
102 |
+
if list_labels[k].startswith("E"):
|
103 |
+
list_labels[k] = list_labels[k] + "0"
|
104 |
+
sample_list_output.append(icd9.lookup(list_labels[k]))
|
105 |
+
|
106 |
+
# append the recommended medications for this sample to the recommended medications list
|
107 |
+
list_output.append(sample_list_output)
|
108 |
+
|
109 |
+
return list_output, sorted_indices
|
110 |
+
|
111 |
+
|
112 |
+
def explainability(model: GNN, explain_dataset: SampleEHRDataset, selected_idx: int,
|
113 |
+
visualization: str, algorithm: str, task: str, threshold: int):
|
114 |
+
explainer = HeteroGraphExplainer(
|
115 |
+
algorithm=algorithm,
|
116 |
+
dataset=explain_dataset,
|
117 |
+
model=model,
|
118 |
+
label_key=task,
|
119 |
+
threshold_value=threshold,
|
120 |
+
top_k=threshold,
|
121 |
+
feat_size=128,
|
122 |
+
root="./streamlit_results/",
|
123 |
+
)
|
124 |
+
|
125 |
+
if task == "medications":
|
126 |
+
visit_drug = explainer.subgraph['visit', 'medication'].edge_index
|
127 |
+
visit_drug = visit_drug.T
|
128 |
+
|
129 |
+
n = 0
|
130 |
+
for vis_drug in visit_drug:
|
131 |
+
vis_drug = np.array(vis_drug)
|
132 |
+
if vis_drug[1] == selected_idx:
|
133 |
+
break
|
134 |
+
n += 1
|
135 |
+
elif task == "diagnosis":
|
136 |
+
visit_diag = explainer.subgraph['visit', 'diagnosis'].edge_index
|
137 |
+
visit_diag = visit_diag.T
|
138 |
+
|
139 |
+
n = 0
|
140 |
+
for vis_diag in visit_diag:
|
141 |
+
vis_diag = np.array(vis_diag)
|
142 |
+
if vis_diag[1] == selected_idx:
|
143 |
+
break
|
144 |
+
n += 1
|
145 |
+
|
146 |
+
explainer.explain(n=n)
|
147 |
+
if visualization == "Explainable":
|
148 |
+
explainer.explain_graph(k=0, human_readable=True, dashboard=True)
|
149 |
+
else:
|
150 |
+
explainer.explain_graph(k=0, human_readable=False, dashboard=True)
|
151 |
+
|
152 |
+
explainer.explain_results(n=n)
|
153 |
+
explainer.explain_results(n=n, doctor_type="Internist_Doctor")
|
154 |
+
|
155 |
+
HtmlFile = open("streamlit_results/explain_graph.html", 'r', encoding='utf-8')
|
156 |
+
source_code = HtmlFile.read()
|
157 |
+
components.html(source_code, height=520)
|
158 |
+
|
159 |
+
|
160 |
+
def gen_pdf(patient, name, lastname, visit, list_output, medical_scenario, internist_scenario):
|
161 |
+
pdf = FPDF()
|
162 |
+
pdf.add_page()
|
163 |
+
pdf.add_font("OpenSans", style="", fname="font/OpenSans.ttf")
|
164 |
+
pdf.add_font("OpenSans", style="B", fname="font/OpenSans-Bold.ttf")
|
165 |
+
|
166 |
+
# Title
|
167 |
+
pdf.set_font("OpenSans", 'B', 14)
|
168 |
+
pdf.cell(0, 10, 'Patient Medical Report', 0, 1, 'C', markdown=True)
|
169 |
+
pdf.ln(5)
|
170 |
+
|
171 |
+
# Patient Info
|
172 |
+
pdf.set_font("OpenSans", 'B', 10)
|
173 |
+
pdf.cell(0, 10, 'Patient Information', 0, 1, 'L', markdown=True)
|
174 |
+
pdf.set_font("OpenSans", '', 8)
|
175 |
+
pdf.cell(0, 3, f"Patient ID: **{patient}** - Name: **{name.split('[')[1].split(']')[0]}** Surname: **{lastname}** - Hospital admission n°: **{visit}**", 0, 1, 'L', markdown=True)
|
176 |
+
pdf.ln(5)
|
177 |
+
|
178 |
+
# Left column (Medical Scenario)
|
179 |
+
left_x = 10
|
180 |
+
right_x = 110
|
181 |
+
col_width = 90
|
182 |
+
|
183 |
+
# Right column (Recommendations)
|
184 |
+
pdf.set_xy(right_x, pdf.get_y())
|
185 |
+
pdf.set_font("OpenSans", 'B', 10)
|
186 |
+
pdf.cell(col_width - 20, 10, 'Recommendations', 0, 1, 'L')
|
187 |
+
pdf.set_xy(right_x, pdf.get_y())
|
188 |
+
pdf.set_font("OpenSans", '', 8)
|
189 |
+
for i, output in enumerate(list_output):
|
190 |
+
tensor_value = output[0].item() # Convert tensor to number
|
191 |
+
recommendation = output[1]
|
192 |
+
pdf.set_xy(right_x, pdf.get_y())
|
193 |
+
pdf.cell(col_width - 20, 3, f"Medication {i+1}: {tensor_value}, {recommendation}", 0, 1, 'L')
|
194 |
+
|
195 |
+
|
196 |
+
# Medical Scenario
|
197 |
+
pdf.set_xy(left_x, pdf.get_y() - 40)
|
198 |
+
pdf.set_font("OpenSans", 'B', 10)
|
199 |
+
pdf.cell(col_width, 10, 'Medical Scenario', 0, 1, 'L', markdown=True)
|
200 |
+
pdf.set_xy(left_x, pdf.get_y())
|
201 |
+
pdf.set_font("OpenSans", '', 8)
|
202 |
+
pdf.multi_cell(col_width, 3, medical_scenario, 0, 'L', markdown=True)
|
203 |
+
|
204 |
+
# internist_scenario
|
205 |
+
pdf.set_xy(left_x, pdf.get_y())
|
206 |
+
pdf.set_font("OpenSans", 'B', 10)
|
207 |
+
pdf.cell(0, 10, 'Internist Scenario', 0, 1, 'L', markdown=True)
|
208 |
+
pdf.set_font("OpenSans", '', 8)
|
209 |
+
pdf.multi_cell(0, 3, internist_scenario, 0, 'L', markdown=True)
|
210 |
+
pdf.ln(5)
|
211 |
+
|
212 |
+
return bytes(pdf.output())
|
requirements.txt
ADDED
Binary file (4.74 kB). View file
|
|
static-kg/ANAT_DIAG.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eca33af4263adec32d995f90b0cfee696d3f1fd30046037e12aeb9111d075fab
|
3 |
+
size 42197
|
static-kg/ATC3/DRUG_DIAG.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:92062a024d9249185b691d3aa2ce582f775dd2fc60a38c181ca0ca1dc7d66adc
|
3 |
+
size 43241
|
static-kg/ATC3/PC_DRUG.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7e961e24caa0f2ab7b551e9f543b3bdc6d67185c8f72aec0150f473583a9c927
|
3 |
+
size 24293
|
static-kg/ATC3/SYMP_DRUG.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dd000e18a437afad1380208b7b5a1c64973e7a662c402e0e98a5ec6ba990c9d2
|
3 |
+
size 147177
|
static-kg/DIAG_SYMP.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cdc8140412f83c08fc551f8d22ae3512f51032267920c5b5cd1a94e5a369b954
|
3 |
+
size 16005
|
static-kg/DRUG_DIAG.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2792c214c680bc127f4e7157c51ecaadcd4e2687bcad0c46edbe9f9ed0b45d23
|
3 |
+
size 45636
|
static-kg/PC_DRUG.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:928ecd97a131b5f7232530ad7dec556e0b13a9c3934489106140b2635e410521
|
3 |
+
size 25264
|
static-kg/SYMP_DRUG.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:937c413c1a2128c09acdee88c83c625690890eb137eaf4d7cd57b805ce690b1d
|
3 |
+
size 158485
|
streamlit_images/0.png
ADDED
![]() |
Git LFS Details
|
streamlit_images/1.png
ADDED
![]() |
Git LFS Details
|
streamlit_images/2.png
ADDED
![]() |
Git LFS Details
|
streamlit_images/3.png
ADDED
![]() |
Git LFS Details
|
streamlit_images/4.png
ADDED
![]() |
Git LFS Details
|
streamlit_images/Internist.png
ADDED
![]() |
Git LFS Details
|
streamlit_images/collaborative.png
ADDED
![]() |
Git LFS Details
|