Commit · d4ca2d2
Parent(s): dc3f347

Initial commit

Files changed:
- .gitattributes +3 -0
- .gitignore +21 -0
- .streamlit/config.toml +5 -0
- README.md +4 -4
- app.py +142 -0
- data/kg_edge_types.csv +3 -0
- data/kg_edges.csv +3 -0
- data/kg_node_types.csv +3 -0
- data/kg_nodes.csv +3 -0
- media/about_header.svg +1 -0
- media/explore_header.svg +1 -0
- media/gravity_logo.png +0 -0
- media/gravity_logo.svg +1 -0
- media/input_header.svg +1 -0
- media/pfp/anoori.png +3 -0
- media/pfp/gravity.png +3 -0
- media/pfp/mzitnik.png +3 -0
- media/pfp/ndagan.png +3 -0
- media/pfp/rbalicer.png +3 -0
- media/predict_header.svg +1 -0
- media/validate_header.svg +1 -0
- menu.py +51 -0
- pages/about.py +30 -0
- pages/admin.py +13 -0
- pages/input.py +253 -0
- pages/predict.py +274 -0
- pages/split.py +114 -0
- pages/validate.py +124 -0
- project_config.py +44 -0
- requirements.txt +11 -0
- sync_data.sh +51 -0
- utils.py +28 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+media/pfp/*.png filter=lfs diff=lfs merge=lfs -text
+data/*.csv filter=lfs diff=lfs merge=lfs -text
+*.csv filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,21 @@
+# Ignore Mac temporary files
+*.DS_Store
+.DS_Store
+
+# Ignore python cache files
+__pycache__/
+
+# Ignore code
+code/*
+
+# Ignore model files
+data/*.pt
+data/disease_splits/*
+models/embeddings/*
+models/checkpoints/*
+
+# Ignore secrets
+.streamlit/secrets.toml
+
+# Ignore user DB
+auth/*
.streamlit/config.toml
ADDED
@@ -0,0 +1,5 @@
+[client]
+showSidebarNavigation = false
+
+[theme]
+base="light"
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
 title: Clinical Drug Repurposing
-emoji:
-colorFrom:
-colorTo:
+emoji: ⚕️
+colorFrom: red
+colorTo: purple
 sdk: streamlit
-sdk_version: 1.
+sdk_version: 1.34.0
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,142 @@
+import streamlit as st
+
+# User authentication
+import gspread
+from oauth2client.service_account import ServiceAccountCredentials
+import hmac
+
+# Standard imports
+import pandas as pd
+
+# Custom and other imports
+import project_config
+from menu import menu
+
+# Initialize st.session_state.role to None
+if "role" not in st.session_state:
+    st.session_state.role = None
+
+
+# From https://stackoverflow.com/questions/55961295/serviceaccountcredentials-from-json-keyfile-name-equivalent-for-remote-json
+# See also https://www.slingacademy.com/article/pandas-how-to-read-and-update-google-sheet-files/
+# See also https://docs.streamlit.io/develop/tutorials/databases/private-gsheet
+# Note that the secrets cannot be passed in a group in HuggingFace Spaces,
+# which is required for the native Streamlit implementation
+def create_keyfile_dict():
+    variables_keys = {
+        # "spreadsheet": st.secrets['spreadsheet'],
+        "type": st.secrets['type'],
+        "project_id": st.secrets['project_id'],
+        "private_key_id": st.secrets['private_key_id'],
+        # Have to replace \n with new lines (^l in Word) by hand
+        "private_key": st.secrets['private_key'],
+        "client_email": st.secrets['client_email'],
+        "client_id": st.secrets['client_id'],
+        "auth_uri": st.secrets['auth_uri'],
+        "token_uri": st.secrets['token_uri'],
+        "auth_provider_x509_cert_url": st.secrets['auth_provider_x509_cert_url'],
+        "client_x509_cert_url": st.secrets['client_x509_cert_url'],
+        "universe_domain": st.secrets['universe_domain']
+    }
+    return variables_keys
+
+
+def check_password():
+    """Returns `True` if the user had a correct password."""
+
+    def login_form():
+        """Form with widgets to collect user information"""
+        # Header
+        col1, col2, col3 = st.columns(3)
+        with col2:
+            st.image(str(project_config.MEDIA_DIR / 'gravity_logo.svg'), width=300)
+
+        with st.form("Credentials"):
+            st.text_input("Username", key="username")
+            st.text_input("Password", type="password", key="password")
+            st.form_submit_button("Log In", on_click=password_entered)
+
+    def password_entered():
+        """Checks whether a password entered by the user is correct."""
+
+        if project_config.VDI or project_config.LOCAL:
+
+            # Read the user database
+            user_db = pd.read_csv(project_config.AUTH_DIR / "crd_user_db.csv")
+
+        else:
+
+            # Define the scope
+            scope = [
+                'https://spreadsheets.google.com/feeds',
+                'https://www.googleapis.com/auth/drive'
+            ]
+
+            # Add credentials to the account
+            creds = ServiceAccountCredentials.from_json_keyfile_dict(create_keyfile_dict(), scope)
+
+            # Authenticate and create the client
+            client = gspread.authorize(creds)
+
+            # Open the spreadsheet
+            sheet = client.open_by_url(st.secrets['spreadsheet']).worksheet("user_db")
+            data = sheet.get_all_records()
+            user_db = pd.DataFrame(data)
+
+        # Check if the username is in the database
+        if st.session_state["username"] in user_db.username.values:
+
+            st.session_state["username_correct"] = True
+
+            # Check if the password is correct
+            if hmac.compare_digest(
+                st.session_state["password"],
+                user_db.loc[user_db.username == st.session_state["username"], "password"].values[0],
+            ):
+
+                st.session_state["password_correct"] = True
+
+                # Check if the username is an admin
+                if st.session_state["username"] in user_db[user_db.role == "admin"].username.values:
+                    st.session_state["role"] = "admin"
+                else:
+                    st.session_state["role"] = "user"
+
+                # Retrieve and store user name and team
+                st.session_state["name"] = user_db.loc[user_db.username == st.session_state["username"], "name"].values[0]
+                st.session_state["team"] = user_db.loc[user_db.username == st.session_state["username"], "team"].values[0]
+                st.session_state["profile_pic"] = st.session_state["username"]
+
+                # Don't store the password
+                del st.session_state["password"]
+
+            else:
+                st.session_state["password_correct"] = False
+
+        else:
+            st.session_state["username_correct"] = False
+            st.session_state["password_correct"] = False
+
+    # Return True if the username + password is validated
+    if st.session_state.get("password_correct", False):
+        return True
+
+    # Show inputs for username + password
+    login_form()
+    if "password_correct" in st.session_state:
+
+        if not st.session_state["username_correct"]:
+            st.error("User not found.")
+        elif not st.session_state["password_correct"]:
+            st.error("The password you entered is incorrect.")
+        else:
+            st.error("An unexpected error occurred.")
+
+    return False
+
+menu() # Render the dynamic menu!
+
+if not check_password():
+    st.stop()
+
+st.switch_page("pages/about.py")
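Note on the password check in `password_entered` above: `hmac.compare_digest` runs in constant time, so the comparison does not leak how many leading characters of the stored password matched, the way a plain `==` can. A minimal sketch of the idea (the credential strings below are placeholders, not values from this repo):

    import hmac

    # Placeholder credentials for illustration only.
    entered = "hunter2"
    stored = "hunter2"

    # Constant-time comparison: runtime does not depend on where the
    # strings first differ, which blunts timing attacks.
    if hmac.compare_digest(entered, stored):
        print("Password accepted.")
    else:
        print("Password rejected.")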
data/kg_edge_types.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2ee79b2f5021304a4dd82581568e8a8c940f94b29cd1206f7730bdff6b82cab4
+size 5288
data/kg_edges.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab7d0e23c56381abf8e214cc5d4fae4e6a8b98957c8f2e5272b4f800953b1461
+size 2765378133
data/kg_node_types.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1a0afff52deec5f48689a22a479d14cd49333759e054624366687ec4ef306c8
+size 192
data/kg_nodes.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a21c42a1ee345195038d438854ee5d4befa7b1984e5efa4865ed74825a75b6d9
+size 8529743
media/about_header.svg
ADDED
media/explore_header.svg
ADDED
media/gravity_logo.png
ADDED
media/gravity_logo.svg
ADDED
media/input_header.svg
ADDED
media/pfp/anoori.png
ADDED (Git LFS)
media/pfp/gravity.png
ADDED (Git LFS)
media/pfp/mzitnik.png
ADDED (Git LFS)
media/pfp/ndagan.png
ADDED (Git LFS)
media/pfp/rbalicer.png
ADDED (Git LFS)
media/predict_header.svg
ADDED
media/validate_header.svg
ADDED
menu.py
ADDED
@@ -0,0 +1,51 @@
+# From https://docs.streamlit.io/develop/tutorials/multipage/st.page_link-nav
+import streamlit as st
+import os
+import project_config
+
+def authenticated_menu():
+
+    # Insert profile picture
+    pfp_path = str(project_config.MEDIA_DIR / 'pfp' / f"{st.session_state.profile_pic}.png")
+    if not os.path.exists(pfp_path):
+        pfp_path = str(project_config.MEDIA_DIR / 'pfp' / "gravity.png")
+    st.sidebar.image(pfp_path, use_column_width=True)
+    st.sidebar.markdown("---")
+
+    # Show a navigation menu for authenticated users
+    # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
+    st.sidebar.page_link("pages/about.py", label="About", icon="📖")
+    st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
+    st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍",
+                         disabled=("query" not in st.session_state))
+    st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅",
+                         disabled=("query" not in st.session_state))
+    # st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
+    if st.session_state.role in ["admin"]:
+        st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
+
+    # Show the logout button
+    st.sidebar.markdown("---")
+    st.sidebar.button("Log Out", on_click=lambda: st.session_state.clear())
+
+
+def unauthenticated_menu():
+
+    # Show a navigation menu for unauthenticated users
+    st.sidebar.page_link("app.py", label="Log In", icon="🔒")
+
+
+def menu():
+    # Determine if a user is logged in or not, then show the correct navigation menu
+    if "role" not in st.session_state or st.session_state.role is None:
+        unauthenticated_menu()
+        return
+    authenticated_menu()
+
+
+def menu_with_redirect():
+    # Redirect users to the main page if not logged in, otherwise continue to
+    # render the navigation menu
+    if "role" not in st.session_state or st.session_state.role is None:
+        st.switch_page("app.py")
+    menu()
pages/about.py
ADDED
@@ -0,0 +1,30 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Path manipulation
+from pathlib import Path
+
+# Custom and other imports
+import project_config
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Header
+st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)
+
+# Main content
+st.markdown(f"Hello, {st.session_state.name}! Welcome to GRAVITY, a **GR**aph **A**I **VI**sualization **T**ool to query and visualize knowledge graph-grounded biomedical AI models.")
+
+# Subheader
+st.subheader("Clinical Drug Repurposing", divider = "grey")
+
+st.markdown("""
+Here, we use GRAVITY to visualize the outputs of our clinical drug repurposing algorithm. The algorithm predicts the probability of a drug treating a disease based on the drug-disease relationship in the knowledge graph.
+""")
+
+col1, col2, col3 = st.columns(3)
+
+with col2:
+    if st.button("Make Predictions"):
+        st.switch_page("pages/input.py")
pages/admin.py
ADDED
@@ -0,0 +1,13 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Verify the user's role
+if st.session_state.role not in ["admin"]:
+    st.warning("You do not have permission to view this page.")
+    st.stop()
+
+st.title("User Management")
+st.markdown(f"You are currently logged in with the role of {st.session_state.role}.")
pages/input.py
ADDED
@@ -0,0 +1,253 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Standard imports
+import numpy as np
+import pandas as pd
+import subprocess
+
+# Path manipulation
+import os
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+
+# Custom and other imports
+import project_config
+from utils import load_kg
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Header
+st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
+
+st.subheader("Choose Disease Split", divider = "red")
+
+with st.spinner('Loading disease splits...'):
+
+    if project_config.VDI or project_config.LOCAL:
+
+        # Read disease splits
+        # Load from Kempner using sync_data.sh
+        disease_splits = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'disease_splits.csv',
+                                     dtype = {'node_index': str, 'disease_split_index': str})
+
+    else:
+
+        # Read disease splits from HF
+        # hf_hub_download returns a local file path; read it into a dataframe
+        disease_splits_path = hf_hub_download(repo_id=project_config.HF_REPO,
+                                              filename='disease_split/disease_splits.csv',
+                                              token=st.secrets["HF_TOKEN"], repo_type="dataset")
+        disease_splits = pd.read_csv(disease_splits_path,
+                                     dtype = {'node_index': str, 'disease_split_index': str})
+
+    # Group disease splits by disease_split_index column
+    disease_splits_grouped = disease_splits.groupby('disease_split_index').size().reset_index(name='node_count')
+
+    # Subset to unique disease splits
+    splits_df = disease_splits[disease_splits['node_index'] == disease_splits['disease_split_index']]
+    splits_df = splits_df.drop_duplicates(subset='disease_split_index').reset_index(drop=True)
+    splits_df = splits_df[['node_index', 'node_name', 'node_id']]
+
+    # Merge with counts
+    splits_df = splits_df.merge(disease_splits_grouped, left_on='node_index', right_on='disease_split_index', how='left')
+    splits_df = splits_df.drop(columns='disease_split_index')
+    splits_df['node_name'] = splits_df['node_name'].str.replace(' \\(disease\\)', '', regex=True)
+
+    # Add row for all to beginning
+    splits_df['node_index'] = splits_df['node_index'].astype(str)
+    splits_df = pd.concat([pd.DataFrame([['all', 'all diseases', None, disease_splits.shape[0]]], columns=splits_df.columns), splits_df], ignore_index=True)
+
+    # For each disease split, count the number of edges (rows in the CSV file in the
+    # disease_splits directory) without reading the file into memory
+    edge_counts = []
+    for index, row in splits_df.iterrows():
+        # Count lines
+        file_name = project_config.DATA_DIR / 'disease_splits' / 'split_edges' / f'{row["node_index"]}.csv'
+        edge_count = int(subprocess.check_output(['wc', '-l', file_name]).split()[0]) - 1
+        edge_counts.append(edge_count)
+
+    # Add edge counts to splits_df
+    splits_df['edge_count'] = edge_counts
+
+    # Get list of available models
+    model_files = os.listdir(project_config.MODEL_DIR / 'embeddings')
+    model_files = [f for f in model_files if f.endswith('_embeddings.pt')]
+
+    # Get model metadata
+    def get_model_metadata(f):
+
+        # Get metadata
+        metadata = f.split('_')
+        date = '_'.join(metadata[:6])
+        date = pd.to_datetime(date, format='%Y_%m_%d_%H_%M_%S')
+
+        # Parameters
+        params = metadata[6].split('-')
+        params = {p.split('=')[0]: p.split('=')[1] for p in params}
+
+        # Add date to params
+        params['date'] = date
+        params['file'] = f
+        return params
+
+    # Get available models, only keep latest version per split
+    avail_models = pd.DataFrame([get_model_metadata(f) for f in model_files])
+    avail_models = avail_models.sort_values('date', ascending=False).drop_duplicates('test').reset_index(drop=True)
+    # avail_models.loc[avail_models['test'] == 'all', 'test'] = 'all diseases'
+
+    # Add column to indicate if model is available
+    splits_df['available'] = splits_df['node_index'].isin(avail_models['test'])
+
+    # If all diseases model is available, set all diseases to available
+    if avail_models['test'].str.contains('all').any():
+        splits_df.loc[splits_df['node_name'] == 'all diseases', 'available'] = True
+
+####################################################################################################
+
+# Select disease split from splits with available models
+# Make dictionary with node_index: node_name, where the name is shown but the index is used for the query
+# split_options = splits_df[splits_df['available']].copy()
+split_options = splits_df.copy()
+split_options = split_options.set_index('node_index')['node_name'].to_dict()
+
+# Check if split is in session state
+if "split" not in st.session_state:
+    split_index = 0
+else:
+    split_index = list(split_options.keys()).index(st.session_state.split)
+
+split = st.selectbox("Disease Split", list(split_options.keys()), format_func = lambda x: split_options[x],
+                     index = split_index)
+
+# Show all splits dataframe
+splits_display = splits_df[['node_index', 'node_name', 'node_count', 'edge_count', 'available']].copy()
+splits_display = splits_display.rename(columns = {'node_index': 'Split ID', 'node_name': 'Disease', 'node_count': 'Node Count', 'edge_count': 'Edge Count', 'available': 'Model Available'})
+st.dataframe(splits_display, use_container_width = True, hide_index = True)
+
+# Save split and available models to session state
+st.session_state.split = split
+st.session_state.splits_df = splits_df
+st.session_state.avail_models = avail_models
+
+if st.button("Explore Split"):
+    st.switch_page("pages/split.py")
+
+####################################################################################################
+
+st.subheader("Construct Query", divider = "red")
+
+# # Checkbox to allow reverse edges
+# allow_reverse_edges = st.checkbox("Allow reverse edges?", value = False)
+allow_reverse_edges = False
+
+# Load knowledge graph
+kg_nodes = load_kg()
+
+with st.spinner('Loading knowledge graph...'):
+    # kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
+    node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
+    edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')
+
+if not allow_reverse_edges:
+    edge_types = edge_types[edge_types.direction == 'forward']
+
+# If query is not in session state, initialize it
+if "query" not in st.session_state:
+    source_node_type_index = 0
+    source_node_index = 0
+    target_node_type_index = 0
+    relation_index = 0
+
+    if st.session_state.team == "Clalit":
+        source_node_type_index = 2
+        source_node_index = 0
+        target_node_type_index = 3
+        relation_index = 2
+
+else:
+    source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
+    source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
+    target_node_type_index = st.session_state.query_options['target_node_type'].index(st.session_state.query['target_node_type'])
+    relation_index = st.session_state.query_options['relation'].index(st.session_state.query['relation'])
+
+# Select source node type
+source_node_type_options = node_types['node_type']
+source_node_type = st.selectbox("Source Node Type", source_node_type_options,
+                                format_func = lambda x: x.replace("_", " "), index = source_node_type_index)
+
+# If source node type is disease, add option to select only diseases in current split
+if source_node_type == 'disease':
+
+    # Get diseases in current split
+    if split == 'all':
+        split_diseases = disease_splits.drop_duplicates(subset='node_name')['node_name']
+    else:
+        split_diseases = disease_splits[disease_splits['disease_split_index'] == split]
+        split_diseases = split_diseases.drop_duplicates(subset='node_name')['node_name']
+
+    # Add checkbox to filter diseases
+    filter_diseases = st.checkbox("Filter diseases to current split?", value = False)
+
+# Select source node
+if source_node_type == 'disease' and filter_diseases:
+    # source_node_options = source_node_options[source_node_options.isin(split_diseases)]
+    source_node_options = split_diseases
+else:
+    source_node_options = kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name']
+source_node = st.selectbox("Source Node", source_node_options,
+                           index = source_node_index)
+
+# Select target node type
+target_node_type_options = edge_types[edge_types.x_type == source_node_type].y_type.unique()
+target_node_type = st.selectbox("Target Node Type", target_node_type_options,
+                                format_func = lambda x: x.replace("_", " "), index = target_node_type_index)
+
+# Select relation
+relation_options = edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique()
+relation = st.selectbox("Edge Type", relation_options,
+                        format_func = lambda x: x.replace("_", "-"), index = relation_index)
+
+# Button to submit query
+if st.button("Submit Query"):
+
+    # Check if model is available for split
+    model_avail = splits_df.loc[splits_df['node_index'] == st.session_state.split, 'available'].values[0]
+    if not model_avail:
+
+        st.error("A trained model is not yet available for this disease split. Please select another disease split for which a trained model is available.", icon="🚨")
+
+    else:
+
+        # Save query to session state
+        st.session_state.query = {
+            "source_node_type": source_node_type,
+            "source_node": source_node,
+            "target_node_type": target_node_type,
+            "relation": relation
+        }
+
+        # Save query options to session state
+        st.session_state.query_options = {
+            "source_node_type": list(source_node_type_options),
+            "source_node": list(source_node_options),
+            "target_node_type": list(target_node_type_options),
+            "relation": list(relation_options)
+        }
+
+        # Delete validation from session state
+        if "validation" in st.session_state:
+            del st.session_state.validation
+
+        # # Write query to console
+        # st.write("Current Query:")
+        # st.write(st.session_state.query)
+        st.write("Query submitted.")
+
+        # Switch to the Predict page
+        st.switch_page("pages/predict.py")
+
+
+st.subheader("Knowledge Graph", divider = "red")
+display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
+display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
+st.dataframe(display_data, use_container_width = True, hide_index = True)
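Note on `get_model_metadata` in pages/input.py above: it assumes embedding filenames begin with a `%Y_%m_%d_%H_%M_%S` timestamp followed by dash-separated `key=value` pairs. A minimal sketch of how such a name is parsed (the filename below is hypothetical, not taken from the repo):

    import pandas as pd

    # Hypothetical filename: six underscore-separated timestamp fields,
    # then dash-separated key=value parameters, then the fixed suffix.
    f = "2024_05_15_13_05_33_test=all_embeddings.pt"

    metadata = f.split('_')
    date = pd.to_datetime('_'.join(metadata[:6]), format='%Y_%m_%d_%H_%M_%S')
    params = {p.split('=')[0]: p.split('=')[1] for p in metadata[6].split('-')}

    print(date)    # 2024-05-15 13:05:33
    print(params)  # {'test': 'all'}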
pages/predict.py
ADDED
@@ -0,0 +1,274 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Standard imports
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Path manipulation
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+
+# Plotting
+import matplotlib.pyplot as plt
+plt.rcParams['font.sans-serif'] = 'Arial'
+
+# Custom and other imports
+import project_config
+from utils import capitalize_after_slash, load_kg
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Header
+st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
+
+# Main content
+# st.markdown(f"Hello, {st.session_state.name}!")
+
+st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
+
+# Print current query
+st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
+
+@st.cache_data(show_spinner = 'Downloading AI model...')
+def get_embeddings():
+
+    # # Get checkpoint name
+    # best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
+
+    # # Get paths to embeddings, relation weights, and edge types
+    # # with st.spinner('Downloading AI model...'):
+    # embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
+    #                              filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
+    #                              token=st.secrets["HF_TOKEN"])
+    # relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
+    #                                         filename=(best_ckpt + "_relation_weights.pt"),
+    #                                         token=st.secrets["HF_TOKEN"])
+    # edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
+    #                                   filename=(best_ckpt + "_edge_types.pt"),
+    #                                   token=st.secrets["HF_TOKEN"])
+
+    # Get split name
+    split = st.session_state.split
+    avail_models = st.session_state.avail_models
+
+    # Get model name from avail_models
+    embed_name = avail_models[avail_models['test'] == 'all']['file'].values[0]
+    relation_weights_name = embed_name.replace('_embeddings.pt', '_relation_weights.pt')
+    edge_types_name = embed_name.replace('_embeddings.pt', '_edge_types.pt')
+
+    # Convert to paths
+    embed_path = project_config.MODEL_DIR / 'embeddings' / embed_name
+    relation_weights_path = project_config.MODEL_DIR / 'embeddings' / relation_weights_name
+    edge_types_path = project_config.MODEL_DIR / 'embeddings' / edge_types_name
+
+    return embed_path, relation_weights_path, edge_types_path
+
+@st.cache_data(show_spinner = 'Loading AI model...')
+def load_embeddings(embed_path, relation_weights_path, edge_types_path):
+
+    # Load embeddings, relation weights, and edge types
+    # with st.spinner('Loading AI model...'):
+    embeddings = torch.load(embed_path)
+    relation_weights = torch.load(relation_weights_path)
+    edge_types = torch.load(edge_types_path)
+
+    return embeddings, relation_weights, edge_types
+
+# Load knowledge graph and embeddings
+kg_nodes = load_kg()
+embed_path, relation_weights_path, edge_types_path = get_embeddings()
+embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
+
+# # Print source node type
+# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
+
+# # Print source node
+# st.write(f"Source Node: {st.session_state.query['source_node']}")
+
+# # Print relation
+# st.write(f"Edge Type: {st.session_state.query['relation']}")
+
+# # Print target node type
+# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
+
+# Compute predictions
+with st.spinner('Computing predictions...'):
+
+    source_node_type = st.session_state.query['source_node_type']
+    source_node = st.session_state.query['source_node']
+    relation = st.session_state.query['relation']
+    target_node_type = st.session_state.query['target_node_type']
+
+    # Get source node index
+    src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
+
+    # Get relation index
+    edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
+
+    # Get target node indices
+    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
+    dst_indices = target_nodes.node_index.values
+    src_indices = np.repeat(src_index, len(dst_indices))
+
+    # Retrieve cached embeddings and apply activation function
+    src_embeddings = embeddings[src_indices]
+    dst_embeddings = embeddings[dst_indices]
+    src_embeddings = F.leaky_relu(src_embeddings)
+    dst_embeddings = F.leaky_relu(dst_embeddings)
+
+    # Get relation weights
+    rel_weights = relation_weights[edge_type_index]
+
+    # Compute weighted dot product
+    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
+    scores = torch.sigmoid(scores)
+
+    # Add scores to dataframe
+    target_nodes['score'] = scores.detach().numpy()
+    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
+    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
+
+    # Rename columns
+    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
+    display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
+
+    # Define dictionary mapping node types to database URLs
+    map_dbs = {
+        'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
+        'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
+        'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
+        'disease': lambda x: x, # MONDO
+        # pad with 0s to 7 digits
+        'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
+        'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
+        'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
+        'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
+        'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
+        'anatomy': lambda x: x,
+    }
+
+    # Get name of database
+    display_database = display_data['Database'].values[0]
+
+    # Add URLs to database column
+    display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
+
+# Check if validation data exists
+if 'validation' in st.session_state:
+
+    # Checkbox to show ground truth validation
+    show_val = st.checkbox("Show ground truth validation?", value = False)
+
+    if show_val:
+
+        # Get validation data
+        val_results = st.session_state.validation.copy()
+
+        # Merge with predictions
+        val_display_data = pd.merge(display_data, val_results, left_on = 'ID', right_on = 'y_id', how='left')
+        val_display_data = val_display_data.fillna(0).drop(columns='y_id')
+
+        # Get new columns
+        val_relations = val_display_data.columns.difference(display_data.columns).tolist()
+
+        # Replace 0 with blank and 1 with check emoji in new columns
+        for col in val_relations:
+            val_display_data[col] = val_display_data[col].replace({0: '', 1: '✅'})
+
+        # Define a function to apply styles
+        def style_val(val):
+            if val == '✅':
+                return 'background-color: #C2EABD;'
+            return 'background-color: #F5F5F5;'
+
+else:
+    show_val = False
+
+
+# NODE SEARCH
+
+# Use multiselect to search for specific nodes
+selected_nodes = st.multiselect(f"Search for specific {target_node_type.replace('_', ' ')} nodes to determine their ranking.",
+                                display_data.Name, placeholder = "Type to search...")
+
+# Filter nodes
+if len(selected_nodes) > 0:
+
+    if show_val:
+        selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].copy()
+    else:
+        selected_display_data = display_data[display_data.Name.isin(selected_nodes)].copy()
+    selected_display_data = selected_display_data.reset_index(drop=True)
+
+    st.markdown(f"Out of {target_nodes.shape[0]} {target_node_type} nodes, the selected nodes rank as follows:")
+    selected_display_data_with_rank = selected_display_data.copy()
+    selected_display_data_with_rank['Rank'] = selected_display_data_with_rank['Rank'].apply(lambda x: f"{x} (top {(100*x/target_nodes.shape[0]):.2f}% of predictions)")
+
+    # Apply ground-truth styling only at render time, so the dataframe itself stays mutable
+    if show_val:
+        selected_display_data_with_rank = selected_display_data_with_rank.style.map(style_val, subset=val_relations)
+
+    # Show filtered nodes
+    if target_node_type not in ['disease', 'anatomy']:
+        st.dataframe(selected_display_data_with_rank, use_container_width = True, hide_index = True,
+                     column_config={"Database": st.column_config.LinkColumn(width = "small",
+                                                                            help = "Click to visit external database.",
+                                                                            display_text = display_database)})
+    else:
+        st.dataframe(selected_display_data_with_rank, use_container_width = True)
+
+    # Show plot
+    st.markdown(f"In the plot below, the dashed lines represent the rank of the selected {target_node_type} nodes across all predictions for {source_node}.")
+
+    # Checkbox to show text labels
+    show_labels = st.checkbox("Show Text Labels?", value = False)
+
+    # Plot rank vs. score using matplotlib
+    fig, ax = plt.subplots(figsize = (10, 6))
+    ax.plot(display_data['Rank'], display_data['Score'], color = 'black', linewidth = 1.5, zorder = 2)
+    ax.set_xlabel('Rank', fontsize = 12)
+    ax.set_ylabel('Score', fontsize = 12)
+    ax.set_xlim(1, display_data['Rank'].max())
+
+    # Get color palette
+    # palette = plt.cm.get_cmap('tab10', len(selected_display_data))
+    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
+
+    # Add vertical line for selected nodes
+    for i, node in selected_display_data.iterrows():
+        ax.scatter(node['Rank'], node['Score'], color = palette[i], zorder=3)
+        ax.axvline(node['Rank'], color = palette[i], linestyle = '--', linewidth = 1.5, label = node['Name'], zorder=3)
+        if show_labels:
+            ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = palette[i], zorder=3)
+
+    # Add legend
+    ax.legend(loc = 'upper right', fontsize = 10)
+    ax.grid(alpha = 0.2, zorder=0)
+
+    st.pyplot(fig)
+
+
+# FULL RESULTS
+
+# Show top ranked nodes
+st.subheader("Model Predictions", divider = "blue")
+top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
+
+# Show full results
+# full_results = val_display_data.iloc[:top_k] if show_val else display_data.iloc[:top_k]
+full_results = val_display_data.iloc[:top_k].style.map(style_val, subset=val_relations) if show_val else display_data.iloc[:top_k]
+
+if target_node_type not in ['disease', 'anatomy']:
+    st.dataframe(full_results, use_container_width = True, hide_index = True,
+                 column_config={"Database": st.column_config.LinkColumn(width = "small",
+                                                                        help = "Click to visit external database.",
+                                                                        display_text = display_database)})
+else:
+    st.dataframe(full_results, use_container_width = True, hide_index = True)
+
+# Save to session state
+st.session_state.predictions = display_data
+st.session_state.display_database = display_database
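Note on the scoring step in pages/predict.py above: it is a DistMult-style decoder, i.e. an element-wise product of source embedding, per-relation weight vector, and target embedding, summed over the embedding dimension and squashed with a sigmoid. A self-contained sketch with random tensors (the shapes are illustrative, not those of the trained model):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    embeddings = torch.randn(5, 8)   # node embeddings: node 0 = source, nodes 1-4 = candidates
    rel_weights = torch.randn(8)     # diagonal relation weights for one edge type

    # Repeat the source once per candidate, mirroring np.repeat in the page.
    src = F.leaky_relu(embeddings[[0, 0, 0, 0]])
    dst = F.leaky_relu(embeddings[[1, 2, 3, 4]])

    # Weighted dot product per pair, mapped to a (0, 1) score.
    scores = torch.sigmoid(torch.sum(src * rel_weights * dst, dim=1))
    print(scores)  # one score per candidate target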
pages/split.py
ADDED
@@ -0,0 +1,114 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Standard imports
+import numpy as np
+import pandas as pd
+
+# Path manipulation
+import os
+from pathlib import Path
+
+# Plotting
+import matplotlib.pyplot as plt
+plt.rcParams['font.sans-serif'] = 'Arial'
+
+# Custom and other imports
+import project_config
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Back button with emoji
+if st.button("◀️ Back"):
+    st.switch_page("pages/input.py")
+
+# Get metadata from session state
+split = st.session_state.split
+splits_df = st.session_state.splits_df
+
+with st.spinner('Loading disease splits...'):
+
+    # Read disease splits
+    disease_split_nodes = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'disease_splits.csv', dtype = {'disease_split_index': str})
+
+    # If split is all
+    if split == 'all':
+        disease_split_nodes = disease_split_nodes[['node_index', 'node_name', 'embedding_score', 'levenshtein_score', 'neighborhood_score', 'method', 'disease_split']]
+        disease_split_nodes = disease_split_nodes.rename(columns = {'node_index': 'Node ID', 'node_name': 'Disease', 'embedding_score': 'Embedding Score', 'levenshtein_score': 'Levenshtein Score', 'neighborhood_score': 'Neighborhood Score', 'method': 'Method', 'disease_split': 'Disease Split'})
+
+    else:
+        disease_split_nodes = disease_split_nodes[disease_split_nodes['disease_split_index'] == split]
+        disease_split_nodes = disease_split_nodes[['node_index', 'node_name', 'embedding_score', 'levenshtein_score', 'neighborhood_score', 'method']]
+        disease_split_nodes = disease_split_nodes.rename(columns = {'node_index': 'Node ID', 'node_name': 'Disease', 'embedding_score': 'Embedding Score', 'levenshtein_score': 'Levenshtein Score', 'neighborhood_score': 'Neighborhood Score', 'method': 'Method'})
+
+    # Read disease split edges
+    disease_split_edges = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'split_edges' / f'{split}.csv')
+
+    # Subset and rename columns
+    disease_split_edges = disease_split_edges[['relation', 'x_index', 'x_type', 'x_name', 'y_index', 'y_type', 'y_name']]
+    disease_split_edges['relation'] = disease_split_edges['relation'].str.replace('_', ' ').str.title()
+    disease_split_edges = disease_split_edges.rename(columns = {'relation': 'Relation', 'x_index': 'Source ID', 'x_type': 'Source Type', 'x_name': 'Source Name', 'y_index': 'Target ID', 'y_type': 'Target Type', 'y_name': 'Target Name'})

+st.subheader("Nodes in Disease Split", divider = "blue")
+st.markdown(f"**Disease Split:** {splits_df[splits_df['node_index'] == split]['node_name'].values[0]}")
+st.markdown(f"**Number of Nodes:** {disease_split_nodes.shape[0]}")
+
+# Show as dataframe
+st.dataframe(disease_split_nodes, use_container_width = True, hide_index = True)
+
+st.markdown("Below, we show the number of nodes by method of inclusion in the disease split.")
+
+# Count nodes by method of inclusion
+method_counts = disease_split_nodes['Method'].value_counts().reset_index()
+method_counts.columns = ['Method', 'Count']
+method_counts['Method Length'] = method_counts['Method'].apply(len)
+method_counts = method_counts.sort_values('Method Length')
+
+# Plotting the bar plot
+plt.figure(figsize=(10, 6))
+bars = plt.bar(method_counts['Method'], method_counts['Count'], color='#B8D4F7', edgecolor='black')
+plt.xlabel('Method', fontsize=16, fontweight='bold')
+plt.ylabel('Count', fontsize=16, fontweight='bold')
+plt.xticks(rotation=45, ha='right', fontsize=12)
+plt.tight_layout()
+
+# Adding labels on top of each bar
+for bar in bars:
+    yval = bar.get_height()
+    plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', fontsize=12)
+plt.ylim(0, max(method_counts['Count'])*1.1)
+
+# Show plot
+st.pyplot(plt)
+
+
+st.subheader("Edges in Disease Split", divider = "green")
+st.markdown(f"**Number of Edges:** {disease_split_edges.shape[0]}")
+
+# Show as dataframe
+st.dataframe(disease_split_edges, use_container_width = True, hide_index = True)
+
+if disease_split_edges.shape[0] > 0:
+
+    # Make bar plot of number of edges by relation
+    st.markdown("Below, we show the number of edges by relation in the disease split.")
+    relation_counts = disease_split_edges['Relation'].value_counts().reset_index()
+    relation_counts.columns = ['Relation', 'Count']
+
+    # Plotting the bar plot
+    plt.figure(figsize=(10, 6))
+    bars = plt.bar(relation_counts['Relation'], relation_counts['Count'], color='#C1ECC5', edgecolor='black')
+    plt.xlabel('Relation', fontsize=16, fontweight='bold')
+    plt.ylabel('Count', fontsize=16, fontweight='bold')
+    plt.xticks(rotation=45, ha='right', fontsize=12)
+    plt.tight_layout()
+
+    # Adding labels on top of each bar
+    for bar in bars:
+        yval = bar.get_height()
+        plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', fontsize=12)
+    plt.ylim(0, max(relation_counts['Count'])*1.1)
+
+    # Show plot
+    st.pyplot(plt)
pages/validate.py
ADDED
@@ -0,0 +1,124 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Standard imports
+import numpy as np
+import pandas as pd
+
+# Path manipulation
+from pathlib import Path
+
+# Plotting
+import matplotlib.pyplot as plt
+plt.rcParams['font.sans-serif'] = 'Arial'
+import matplotlib.colors as mcolors
+
+# Custom and other imports
+import project_config
+from utils import load_kg, load_kg_edges
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Header
+st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
+
+# Main content
+# st.markdown(f"Hello, {st.session_state.name}!")
+
+st.subheader("Validate Predictions", divider = "green")
+
+# Print current query
+st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
+
+# Coming soon
+# st.write("Coming soon...")
+
+source_node_type = st.session_state.query['source_node_type']
+source_node = st.session_state.query['source_node']
+relation = st.session_state.query['relation']
+target_node_type = st.session_state.query['target_node_type']
+predictions = st.session_state.predictions
+
+kg_nodes = load_kg()
+kg_edges = load_kg_edges()
+
+# Convert tuple to hex
+def rgba_to_hex(rgba):
+    return mcolors.to_hex(rgba[:3])
+
+with st.spinner('Searching known relationships...'):
+
+    # Subset existing edges
+    edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
+    edge_subset = edge_subset[edge_subset.y_type == target_node_type]
+
+    # Merge edge subset with predictions
+    edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
+    edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
+    edges_in_kg = edges_in_kg.drop(columns = 'y_id')
+
+    # Rename relation to ground-truth
+    edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
+    edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
+
+# If there exist edges in KG
+if len(edges_in_kg) > 0:
+
+    with st.spinner('Saving validation results...'):
+
+        # Cast long to wide
+        val_results = edge_subset[['relation', 'y_id']].pivot_table(index='y_id', columns='relation', aggfunc='size', fill_value=0)
+        val_results = (val_results > 0).astype(int).reset_index()
+        val_results.columns = [val_results.columns[0]] + [x.replace('_', ' ').title() for x in val_results.columns[1:]]
+
+        # Save validation results to session state
+        st.session_state.validation = val_results
+
+    with st.spinner('Plotting known relationships...'):
+
+        # Define a color map for different relations
+        color_map = plt.get_cmap('tab10')
+
+        # Group by relation and create separate plots
+        relations = edges_in_kg['Known Relation'].unique()
+        for idx, relation in enumerate(relations):
+
+            relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
+
+            # Get a color from the color map
+            color = color_map(idx % color_map.N)
+
+            fig, ax = plt.subplots(figsize=(10, 3))
+            ax.plot(predictions['Rank'], predictions['Score'])
+            ax.set_xlabel('Rank', fontsize=12)
+            ax.set_ylabel('Score', fontsize=12)
+            ax.set_xlim(1, predictions['Rank'].max())
+
+            for i, node in relation_data.iterrows():
+                ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])
+                # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
+
+            # ax.set_title(f'{relation.replace("_", "-")}')
+            # ax.legend()
+            color_hex = rgba_to_hex(color)
+
+            # Write header in color of relation
+            st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)
+
+            # Show plot
+            st.pyplot(fig)
+
+            # Drop known relation column
+            relation_data = relation_data.drop(columns = 'Known Relation')
+            if target_node_type not in ['disease', 'anatomy']:
+                st.dataframe(relation_data, use_container_width=True,
+                             column_config={"Database": st.column_config.LinkColumn(width = "small",
+                                                                                    help = "Click to visit external database.",
+                                                                                    display_text = st.session_state.display_database)})
+            else:
+                st.dataframe(relation_data, use_container_width=True)
+
+else:
+
+    st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")
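Note on the long-to-wide cast in pages/validate.py above: `pivot_table` with `aggfunc='size'` turns one row per known (target, relation) edge into one row per target with a 0/1 indicator column per relation, which the Predict page later renders as check marks. A small sketch with made-up IDs (the values are illustrative only):

    import pandas as pd

    # Illustrative long-format edges: one row per known (target, relation) pair.
    edge_subset = pd.DataFrame({
        'y_id': ['DB001', 'DB001', 'DB002'],
        'relation': ['indication', 'off_label_use', 'indication'],
    })

    # Count occurrences per (y_id, relation), then binarize to 0/1 indicators.
    val_results = edge_subset.pivot_table(index='y_id', columns='relation', aggfunc='size', fill_value=0)
    val_results = (val_results > 0).astype(int).reset_index()
    print(val_results)
    #     y_id  indication  off_label_use
    # 0  DB001           1              1
    # 1  DB002           1              0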
project_config.py
ADDED
@@ -0,0 +1,44 @@
+'''
+PROJECT CONFIGURATION FILE
+This file contains the configuration variables for the project. The variables are used
+in the other scripts to define the paths to the data and results directories. The variables
+are also used to set the random seed for reproducibility.
+'''
+
+# Import libraries
+from pathlib import Path
+import socket
+import getpass
+
+def check_internet_connection():
+    try:
+        # Connect to one of the DNS servers
+        socket.create_connection(("8.8.8.8", 53), timeout=5)
+        return True
+    except OSError:
+        return False
+
+def check_local_machine():
+    hostname = socket.gethostname()
+    username = getpass.getuser()
+
+    return hostname, username
+
+# Define global variable indicating whether on VDI or not
+VDI = not check_internet_connection()
+print(f"VDI: {VDI}")
+
+# Define global variable to check if running locally
+hostname, username = check_local_machine()
+LOCAL = True if username == 'an583' else False
+
+# Define HF repo variable
+HF_REPO = 'ayushnoori/clinical-drug-repurposing'
+
+# Define project configuration variables
+PROJECT_DIR = Path(__file__).resolve().parent
+DATA_DIR = PROJECT_DIR / 'data'
+AUTH_DIR = PROJECT_DIR / 'auth'
+MODEL_DIR = PROJECT_DIR / 'models'
+MEDIA_DIR = PROJECT_DIR / 'media'
+SEED = 42
requirements.txt
ADDED
@@ -0,0 +1,11 @@
+numpy
+pandas
+scikit-learn
+matplotlib
+seaborn
+pathlib
+torch
+altair<5
+gspread
+oauth2client
+huggingface_hub
sync_data.sh
ADDED
@@ -0,0 +1,51 @@
+#!/bin/bash
+
+# Set permissions with chmod +x sync_data.sh
+# Run with ./sync_data.sh
+
+# Ask the user for the environment
+echo "Would you like to sync results from O2 or Kempner?"
+echo "1) O2"
+echo "2) Kempner"
+read -p "Enter your choice (1 or 2): " env_choice
+
+# Ask the user which results folder to sync
+echo "Which results folder do you want to sync?"
+echo "1) data: disease splits"
+echo "2) disease split models: checkpoints"
+echo "3) disease split models: embeddings"
+read -p "Enter your choice (1, 2, or 3): " folder_choice
+
+# Map user input to folder names
+case $folder_choice in
+    1) SRC_FOLDER="Data/DrugKG/2_harmonize_KG/disease_splits";;
+    2) SRC_FOLDER="Results/GALAXY/disease_splits/checkpoints";;
+    3) SRC_FOLDER="Results/GALAXY/disease_splits/embeddings";;
+    *) echo "Invalid folder choice. Please enter 1, 2, or 3."; exit 1;;
+esac
+
+case $env_choice in
+    1) SRC_DIR="[email protected]:/n/data1/hms/dbmi/zitnik/lab/users/an252/NeuroKG/neuroKG/$SRC_FOLDER";;
+    2) SRC_DIR="[email protected]:/n/holylabs/LABS/mzitnik_lab/Users/anoori/neuroKG/$SRC_FOLDER";;
+    *) echo "Invalid source server choice. Please enter 1 or 2."; exit 1;;
+esac
+
+# Map user input to destination folder names
+case $folder_choice in
+    1) DST_DIR="data/disease_splits";;
+    2) DST_DIR="models/checkpoints";; # Don't need checkpoints for this application
+    3) DST_DIR="models/embeddings";;
+    *) echo "Invalid folder choice. Please enter 1, 2, or 3."; exit 1;;
+esac
+
+# Sync source and destination folders with specific file types for checkpoints or embeddings
+# Note, local files not present in the source will be deleted
+if [[ $folder_choice -eq 2 || $folder_choice -eq 3 ]]; then
+    echo "Syncing only .ckpt or .pt files from $SRC_DIR to $DST_DIR..."
+    rsync -avz -e ssh --include="*.ckpt" --include="*.pt" --exclude="*" --delete $SRC_DIR/ $DST_DIR
+else
+    echo "Syncing $SRC_DIR to $DST_DIR..."
+    rsync -avz -e ssh --delete $SRC_DIR/ $DST_DIR
+fi
+
+echo "Synchronization complete."
utils.py
ADDED
@@ -0,0 +1,28 @@
+import streamlit as st
+import pandas as pd
+import project_config
+import base64
+
+
+@st.cache_data(show_spinner = 'Loading knowledge graph nodes...')
+def load_kg():
+    # with st.spinner('Loading knowledge graph...'):
+    kg_nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
+    return kg_nodes
+
+
+@st.cache_data(show_spinner = 'Loading knowledge graph edges...')
+def load_kg_edges():
+    # with st.spinner('Loading knowledge graph...'):
+    kg_edges = pd.read_csv(project_config.DATA_DIR / 'kg_edges.csv', dtype = {'edge_index': int, 'x_index': int, 'y_index': int}, low_memory = False)
+    return kg_edges
+
+
+def capitalize_after_slash(s):
+    # Split the string by slashes first
+    parts = s.split('/')
+    # Capitalize each part separately
+    capitalized_parts = [part.title() for part in parts]
+    # Rejoin the parts with slashes
+    capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
+    return capitalized_string
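Note on `capitalize_after_slash` above: it title-cases each slash-separated part and then replaces underscores with spaces, which is how raw node-type strings become display headers on the Predict page. Expected behavior on the KG node-type strings:

    from utils import capitalize_after_slash

    print(capitalize_after_slash('gene/protein'))        # Gene/Protein
    print(capitalize_after_slash('effect/phenotype'))    # Effect/Phenotype
    print(capitalize_after_slash('biological_process'))  # Biological Process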