ayushnoori commited on
Commit
d4ca2d2
·
1 Parent(s): dc3f347

Initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ media/pfp/*.png filter=lfs diff=lfs merge=lfs -text
37
+ data/*.csv filter=lfs diff=lfs merge=lfs -text
38
+ *.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore Mac temporary files
2
+ *.DS_Store
3
+ .DS_Store
4
+
5
+ # Ignore python cache files
6
+ __pycache__/
7
+
8
+ # Ignore code
9
+ code/*
10
+
11
+ # Ignore model files
12
+ data/*.pt
13
+ data/disease_splits/*
14
+ models/embeddings/*
15
+ models/checkpoints/*
16
+
17
+ # Ignore secrets
18
+ .streamlit/secrets.toml
19
+
20
+ # Ignore user DB
21
+ auth/*
.streamlit/config.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [client]
2
+ showSidebarNavigation = false
3
+
4
+ [theme]
5
+ base="light"
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: Clinical Drug Repurposing
3
- emoji: 🐢
4
- colorFrom: gray
5
- colorTo: green
6
  sdk: streamlit
7
- sdk_version: 1.36.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
  title: Clinical Drug Repurposing
3
+ emoji: ⚕️
4
+ colorFrom: red
5
+ colorTo: purple
6
  sdk: streamlit
7
+ sdk_version: 1.34.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ # User authentication
4
+ import gspread
5
+ from oauth2client.service_account import ServiceAccountCredentials
6
+ import hmac
7
+
8
+ # Standard imports
9
+ import pandas as pd
10
+
11
+ # Custom and other imports
12
+ import project_config
13
+ from menu import menu
14
+
15
+ # Initialize st.session_state.role to None
16
+ if "role" not in st.session_state:
17
+ st.session_state.role = None
18
+
19
+
20
+ # From https://stackoverflow.com/questions/55961295/serviceaccountcredentials-from-json-keyfile-name-equivalent-for-remote-json
21
+ # See also https://www.slingacademy.com/article/pandas-how-to-read-and-update-google-sheet-files/
22
+ # See also https://docs.streamlit.io/develop/tutorials/databases/private-gsheet
23
+ # Note that the secrets cannot be passed in a group in HuggingFace Spaces,
24
+ # which is required for the native Streamlit implementation
25
def create_keyfile_dict():
    """Assemble a Google service-account keyfile dict from Streamlit secrets.

    HuggingFace Spaces cannot pass secrets as a nested group (which the
    native Streamlit gsheets integration requires), so every field of the
    service-account JSON is stored as an individual secret and reassembled
    here into a plain dict.

    Returns:
        dict: keyfile fields suitable for
        ``ServiceAccountCredentials.from_json_keyfile_dict``.
    """
    # All fields of a standard service-account JSON keyfile, in the
    # conventional order. The 'spreadsheet' secret is read separately
    # by the caller and is deliberately not part of the keyfile.
    field_names = (
        "type",
        "project_id",
        "private_key_id",
        # Newlines in the private key must already be literal newlines in
        # the stored secret (replace \n by hand when pasting the secret).
        "private_key",
        "client_email",
        "client_id",
        "auth_uri",
        "token_uri",
        "auth_provider_x509_cert_url",
        "client_x509_cert_url",
        "universe_domain",
    )
    return {name: st.secrets[name] for name in field_names}
42
+
43
+
44
def check_password():
    """Returns `True` if the user had a correct password.

    Renders the login form and, on submit, validates the entered username
    and password against the user database: a local CSV when running on
    VDI/locally, otherwise a Google Sheet. All authentication results
    (username_correct, password_correct, role, name, team, profile_pic)
    are stored in st.session_state.
    """

    def login_form():
        """Form with widgets to collect user information"""
        # Header: center the logo in the middle of three columns
        col1, col2, col3 = st.columns(3)
        with col2:
            st.image(str(project_config.MEDIA_DIR / 'gravity_logo.svg'), width=300)

        # Submitting the form triggers password_entered via on_click
        with st.form("Credentials"):
            st.text_input("Username", key="username")
            st.text_input("Password", type="password", key="password")
            st.form_submit_button("Log In", on_click=password_entered)

    def password_entered():
        """Checks whether a password entered by the user is correct."""

        if project_config.VDI or project_config.LOCAL:

            # Read the user database from a local CSV
            user_db = pd.read_csv(project_config.AUTH_DIR / "crd_user_db.csv")

        else:

            # Define the scope
            scope = [
                'https://spreadsheets.google.com/feeds',
                'https://www.googleapis.com/auth/drive'
            ]

            # Add credentials to the account (keyfile rebuilt from secrets)
            creds = ServiceAccountCredentials.from_json_keyfile_dict(create_keyfile_dict(), scope)

            # Authenticate and create the client
            client = gspread.authorize(creds)

            # Open the spreadsheet and load the user database worksheet
            sheet = client.open_by_url(st.secrets['spreadsheet']).worksheet("user_db")
            data = sheet.get_all_records()
            user_db = pd.DataFrame(data)

        # Check if the username is in the database
        if st.session_state["username"] in user_db.username.values:

            st.session_state["username_correct"] = True

            # Check if the password is correct
            # NOTE(review): passwords appear to be stored and compared in
            # plaintext; compare_digest only defends against timing attacks,
            # not disclosure — consider storing salted hashes instead.
            if hmac.compare_digest(
                st.session_state["password"],
                user_db.loc[user_db.username == st.session_state["username"], "password"].values[0],
            ):

                st.session_state["password_correct"] = True

                # Check if the username is an admin
                if st.session_state["username"] in user_db[user_db.role == "admin"].username.values:
                    st.session_state["role"] = "admin"
                else:
                    st.session_state["role"] = "user"

                # Retrieve and store user name and team
                st.session_state["name"] = user_db.loc[user_db.username == st.session_state["username"], "name"].values[0]
                st.session_state["team"] = user_db.loc[user_db.username == st.session_state["username"], "team"].values[0]
                st.session_state["profile_pic"] = st.session_state["username"]

                # Don't store the password
                del st.session_state["password"]

            else:
                st.session_state["password_correct"] = False

        else:
            st.session_state["username_correct"] = False
            st.session_state["password_correct"] = False

    # Return True if the username + password is validated
    if st.session_state.get("password_correct", False):
        return True

    # Show inputs for username + password
    login_form()
    if "password_correct" in st.session_state:

        # Report the most specific failure to the user
        if not st.session_state["username_correct"]:
            st.error("User not found.")
        elif not st.session_state["password_correct"]:
            st.error("The password you entered is incorrect.")
        else:
            st.error("An unexpected error occurred.")

    return False
136
+
137
# Render the dynamic navigation menu (login link only until authenticated)
menu()

# Gate the rest of the app behind authentication
if not check_password():
    st.stop()

# Authenticated: go straight to the About page
st.switch_page("pages/about.py")
data/kg_edge_types.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ee79b2f5021304a4dd82581568e8a8c940f94b29cd1206f7730bdff6b82cab4
3
+ size 5288
data/kg_edges.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab7d0e23c56381abf8e214cc5d4fae4e6a8b98957c8f2e5272b4f800953b1461
3
+ size 2765378133
data/kg_node_types.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1a0afff52deec5f48689a22a479d14cd49333759e054624366687ec4ef306c8
3
+ size 192
data/kg_nodes.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a21c42a1ee345195038d438854ee5d4befa7b1984e5efa4865ed74825a75b6d9
3
+ size 8529743
media/about_header.svg ADDED
media/explore_header.svg ADDED
media/gravity_logo.png ADDED
media/gravity_logo.svg ADDED
media/input_header.svg ADDED
media/pfp/anoori.png ADDED

Git LFS Details

  • SHA256: 56f2cd51f6496ff1e43f0ce3fb63145a442772b16e3d456bba06cf86d78671cf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.53 MB
media/pfp/gravity.png ADDED

Git LFS Details

  • SHA256: 348a8c9cabd92f92e0e088f24c7ddb10120911c3c492f8e91baf02a870d464fd
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
media/pfp/mzitnik.png ADDED

Git LFS Details

  • SHA256: b514858118909ce8004028a1f87f3e7a259415d730d34371830e884a1343da2f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
media/pfp/ndagan.png ADDED

Git LFS Details

  • SHA256: d7d169b3a4cceca7bcb829ae02ea5ce912c94b5b66006343f1d5bfdc7296ce79
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
media/pfp/rbalicer.png ADDED

Git LFS Details

  • SHA256: c1b68144f018798483fb7485d9fa46350e6274bc9ece3d5548ec33226ca2690d
  • Pointer size: 131 Bytes
  • Size of remote file: 580 kB
media/predict_header.svg ADDED
media/validate_header.svg ADDED
menu.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://docs.streamlit.io/develop/tutorials/multipage/st.page_link-nav
2
+ import streamlit as st
3
+ import os
4
+ import project_config
5
+
6
def authenticated_menu():
    """Render the sidebar navigation for a logged-in user."""

    # Profile picture: fall back to the generic logo when the user has
    # no picture of their own under media/pfp/.
    pfp_dir = project_config.MEDIA_DIR / 'pfp'
    picture = str(pfp_dir / f"{st.session_state.profile_pic}.png")
    if not os.path.exists(picture):
        picture = str(pfp_dir / "gravity.png")
    st.sidebar.image(picture, use_column_width=True)
    st.sidebar.markdown("---")

    # Page links; Predict/Validate stay disabled until a query is submitted.
    # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
    no_query = "query" not in st.session_state
    st.sidebar.page_link("pages/about.py", label="About", icon="📖")
    st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
    st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍",
                         disabled=no_query)
    st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅",
                         disabled=no_query)
    # st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")

    # Admin-only page
    if st.session_state.role in ["admin"]:
        st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")

    # Logging out simply clears the whole session state
    st.sidebar.markdown("---")
    st.sidebar.button("Log Out", on_click=lambda: st.session_state.clear())
30
+
31
+
32
def unauthenticated_menu():
    """Render the sidebar navigation for a visitor who is not logged in."""
    # Only the login page is reachable before authentication.
    st.sidebar.page_link("app.py", label="Log In", icon="🔒")
36
+
37
+
38
def menu():
    """Render whichever sidebar menu matches the current auth state."""
    # A user counts as logged in once a non-None role has been stored.
    logged_in = "role" in st.session_state and st.session_state.role is not None
    if logged_in:
        authenticated_menu()
    else:
        unauthenticated_menu()
44
+
45
+
46
def menu_with_redirect():
    """Bounce unauthenticated users to the login page, then render the menu."""
    logged_out = "role" not in st.session_state or st.session_state.role is None
    if logged_out:
        # switch_page stops this script and reruns app.py
        st.switch_page("app.py")
    menu()
pages/about.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
from menu import menu_with_redirect

# Path manipulation
from pathlib import Path

# Custom and other imports
import project_config

# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Page header image
st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)

# Main content: greet the logged-in user (name is stored at login)
st.markdown(f"Hello, {st.session_state.name}! Welcome to GRAVITY, a **GR**aph **A**I **VI**sualization **T**ool to query and visualize knowledge graph-grounded biomedical AI models.")

# Subheader
st.subheader("Clinical Drug Repurposing", divider = "grey")

st.markdown("""
Here, we use GRAVITY to visualize the outputs of our clinical drug repurposing algorithm. The algorithm predicts the probability of a drug treating a disease based on the drug-disease relationship in the knowledge graph.
""")

# Center the call-to-action button in the middle of three columns
col1, col2, col3 = st.columns(3)

with col2:
    if st.button("Make Predictions"):
        st.switch_page("pages/input.py")
pages/admin.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
from menu import menu_with_redirect

# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Verify the user's role: this page is admin-only
if st.session_state.role not in ["admin"]:
    st.warning("You do not have permission to view this page.")
    st.stop()

st.title("User Management")
# Fix: user-facing message read "currently logged with the role" —
# corrected to "logged in with the role".
st.markdown(f"You are currently logged in with the role of {st.session_state.role}.")
pages/input.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+ import subprocess
8
+
9
+ # Path manipulation
10
+ import os
11
+ from pathlib import Path
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ # Custom and other imports
15
+ import project_config
16
+ from utils import load_kg
17
+
18
# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Page header image
st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)

st.subheader("Choose Disease Split", divider = "red")
25
+
26
with st.spinner('Loading disease splits...'):

    if project_config.VDI or project_config.LOCAL:

        # Read disease splits from the local copy
        # (synced from Kempner using sync_data.sh)
        disease_splits = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'disease_splits.csv',
                                     dtype = {'node_index': str, 'disease_split_index': str})

    else:

        # Read disease splits from HF.
        # FIX: hf_hub_download returns the *local file path* of the
        # downloaded artifact, not a DataFrame — the CSV must still be
        # read with pandas (same dtypes as the local branch) before the
        # groupby/indexing below can work.
        disease_splits_path = hf_hub_download(repo_id=project_config.HF_REPO,
                                              filename='disease_split/disease_splits.csv',
                                              token=st.secrets["HF_TOKEN"], repo_type="dataset")
        disease_splits = pd.read_csv(disease_splits_path,
                                     dtype = {'node_index': str, 'disease_split_index': str})

    # Group disease splits by disease_split_index column to count nodes per split
    disease_splits_grouped = disease_splits.groupby('disease_split_index').size().reset_index(name='node_count')

    # Subset to unique disease splits (rows where the node is its own split head)
    splits_df = disease_splits[disease_splits['node_index'] == disease_splits['disease_split_index']]
    splits_df = splits_df.drop_duplicates(subset='disease_split_index').reset_index(drop=True)
    splits_df = splits_df[['node_index', 'node_name', 'node_id']]

    # Merge with per-split node counts
    splits_df = splits_df.merge(disease_splits_grouped, left_on='node_index', right_on='disease_split_index', how='left')
    splits_df = splits_df.drop(columns='disease_split_index')
    splits_df['node_name'] = splits_df['node_name'].str.replace(' \\(disease\\)', '', regex=True)

    # Add a synthetic "all diseases" row at the beginning
    splits_df['node_index'] = splits_df['node_index'].astype(str)
    splits_df = pd.concat([pd.DataFrame([['all', 'all diseases', None, disease_splits.shape[0]]], columns=splits_df.columns), splits_df], ignore_index=True)

    # For each disease split, count the number of edges (rows in the split's
    # CSV file) via `wc -l` so the file is never read into memory
    edge_counts = []
    for index, row in splits_df.iterrows():
        # Count lines; subtract one for the header row
        file_name = project_config.DATA_DIR / 'disease_splits' / 'split_edges' / f'{row["node_index"]}.csv'
        edge_count = int(subprocess.check_output(['wc', '-l', file_name]).split()[0]) - 1
        edge_counts.append(edge_count)

    # Add edge counts to splits_df
    splits_df['edge_count'] = edge_counts

    # Get list of available models (fix comment typo: "modles")
    model_files = os.listdir(project_config.MODEL_DIR / 'embeddings')
    model_files = [f for f in model_files if f.endswith('_embeddings.pt')]

    def get_model_metadata(f):
        """Parse a model-embedding filename into its timestamp and parameters.

        Expected layout: YYYY_MM_DD_HH_MM_SS_<key=val>-<key=val>..._embeddings.pt
        Returns a dict of the key=val parameters plus 'date' and 'file'.
        """
        # First six underscore-separated fields form the timestamp
        metadata = f.split('_')
        date = '_'.join(metadata[:6])
        date = pd.to_datetime(date, format='%Y_%m_%d_%H_%M_%S')

        # Parameters are dash-separated key=value pairs
        params = metadata[6].split('-')
        params = {p.split('=')[0]: p.split('=')[1] for p in params}

        # Add the timestamp and original filename to the parameters
        params['date'] = date
        params['file'] = f
        return params

    # Get available models, only keeping the latest version per split
    avail_models = pd.DataFrame([get_model_metadata(f) for f in model_files])
    avail_models = avail_models.sort_values('date', ascending=False).drop_duplicates('test').reset_index(drop=True)

    # Mark splits for which a trained model exists
    splits_df['available'] = splits_df['node_index'].isin(avail_models['test'])

    # If the all-diseases model is available, mark the 'all diseases' row too
    if avail_models['test'].str.contains('all').any():
        splits_df.loc[splits_df['node_name'] == 'all diseases', 'available'] = True
103
+
104
####################################################################################################

# Select disease split from splits with available models
# Make dictionary with node_index: node_name, where name is value shown but index is used for query
# split_options = splits_df[splits_df['available']].copy()
split_options = splits_df.copy()
split_options = split_options.set_index('node_index')['node_name'].to_dict()

# Check if a split is already in session state; reuse it as the default
if "split" not in st.session_state:
    split_index = 0
else:
    split_index = list(split_options.keys()).index(st.session_state.split)

# Selectbox stores node_index values but displays node names
split = st.selectbox("Disease Split", list(split_options.keys()), format_func = lambda x: split_options[x],
                     index = split_index)

# Show the full splits dataframe with human-readable column names
splits_display = splits_df[['node_index', 'node_name', 'node_count', 'edge_count', 'available']].copy()
splits_display = splits_display.rename(columns = {'node_index': 'Split ID', 'node_name': 'Disease', 'node_count': 'Node Count', 'edge_count': 'Edge Count', 'available': 'Model Available'})
st.dataframe(splits_display, use_container_width = True, hide_index = True)

# Save split and available models to session state for the other pages
st.session_state.split = split
st.session_state.splits_df = splits_df
st.session_state.avail_models = avail_models

if st.button("Explore Split"):
    st.switch_page("pages/split.py")
133
+
134
####################################################################################################

st.subheader("Construct Query", divider = "red")

# # Checkbox to allow reverse edges
# allow_reverse_edges = st.checkbox("Allow reverse edges?", value = False)
# Reverse edges are currently always disabled
allow_reverse_edges = False

# Load knowledge graph nodes (cached helper)
kg_nodes = load_kg()

with st.spinner('Loading knowledge graph...'):
    # kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
    node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
    edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')

# Keep only forward edges unless reverse edges are allowed
if not allow_reverse_edges:
    edge_types = edge_types[edge_types.direction == 'forward']

# If query is not in session state, initialize the selectbox defaults
if "query" not in st.session_state:
    source_node_type_index = 0
    source_node_index = 0
    target_node_type_index = 0
    relation_index = 0

    # Clalit users get preset defaults
    # NOTE(review): assumes st.session_state.team is always set at login — verify
    if st.session_state.team == "Clalit":
        source_node_type_index = 2
        source_node_index = 0
        target_node_type_index = 3
        relation_index = 2

else:
    # Restore the previous selections from the saved query options
    source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
    source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
    target_node_type_index = st.session_state.query_options['target_node_type'].index(st.session_state.query['target_node_type'])
    relation_index = st.session_state.query_options['relation'].index(st.session_state.query['relation'])

# Select source node type
source_node_type_options = node_types['node_type']
source_node_type = st.selectbox("Source Node Type", source_node_type_options,
                                format_func = lambda x: x.replace("_", " "), index = source_node_type_index)

# If source node type is disease, add option to select only diseases in current split
if source_node_type == 'disease':

    # Get diseases in current split
    if split == 'all':
        split_diseases = disease_splits.drop_duplicates(subset='node_name')['node_name']
    else:
        split_diseases = disease_splits[disease_splits['disease_split_index'] == split]
        split_diseases = split_diseases.drop_duplicates(subset='node_name')['node_name']

    # Add checkbox to filter diseases
    filter_diseases = st.checkbox("Filter diseases to current split?", value = False)

# Select source node (optionally restricted to the current split's diseases)
if source_node_type == 'disease' and filter_diseases:
    # source_node_options = source_node_options[source_node_options.isin(split_diseases)]
    source_node_options = split_diseases
else:
    source_node_options = kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name']
source_node = st.selectbox("Source Node", source_node_options,
                           index = source_node_index)

# Select target node type (restricted to types reachable from the source type)
target_node_type_options = edge_types[edge_types.x_type == source_node_type].y_type.unique()
target_node_type = st.selectbox("Target Node Type", target_node_type_options,
                                format_func = lambda x: x.replace("_", " "), index = target_node_type_index)

# Select relation (restricted to edges between the chosen node types)
relation_options = edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique()
relation = st.selectbox("Edge Type", relation_options,
                        format_func = lambda x: x.replace("_", "-"), index = relation_index)

# Button to submit query
if st.button("Submit Query"):

    # Check if a trained model is available for the chosen split
    model_avail = splits_df.loc[splits_df['node_index'] == st.session_state.split, 'available'].values[0]
    if not model_avail:

        st.error("A trained model is not yet available for this disease split. Please select another disease split for which a trained model is available.", icon="🚨")

    else:

        # Save query to session state
        st.session_state.query = {
            "source_node_type": source_node_type,
            "source_node": source_node,
            "target_node_type": target_node_type,
            "relation": relation
        }

        # Save query options to session state (used to restore selections)
        st.session_state.query_options = {
            "source_node_type": list(source_node_type_options),
            "source_node": list(source_node_options),
            "target_node_type": list(target_node_type_options),
            "relation": list(relation_options)
        }

        # Delete any stale validation results from session state
        if "validation" in st.session_state:
            del st.session_state.validation

        # # Write query to console
        # st.write("Current Query:")
        # st.write(st.session_state.query)
        st.write("Query submitted.")

        # Switch to the Predict page
        st.switch_page("pages/predict.py")


st.subheader("Knowledge Graph", divider = "red")
display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
st.dataframe(display_data, use_container_width = True, hide_index = True)
pages/predict.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ # Path manipulation
12
+ from pathlib import Path
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ # Plotting
16
+ import matplotlib.pyplot as plt
17
+ plt.rcParams['font.sans-serif'] = 'Arial'
18
+
19
+ # Custom and other imports
20
+ import project_config
21
+ from utils import capitalize_after_slash, load_kg
22
+
23
# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Page header image
st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)

# Main content
# st.markdown(f"Hello, {st.session_state.name}!")

# Section title, e.g. "Drug Search" when the target node type is 'drug'
st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")

# Print the current query saved by the Input page
st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
36
+
37
@st.cache_data(show_spinner = 'Downloading AI model...')
def get_embeddings():
    """Resolve the file paths of the model artifacts for prediction.

    Returns:
        tuple: (embed_path, relation_weights_path, edge_types_path),
        all under MODEL_DIR / 'embeddings'.

    NOTE(review): st.cache_data on a zero-argument function that reads
    st.session_state means the cached result will not refresh when the
    selected split changes — confirm this is intended.
    """

    # # Get checkpoint name
    # best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"

    # # Get paths to embeddings, relation weights, and edge types
    # # with st.spinner('Downloading AI model...'):
    # embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
    #                              filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
    #                              token=st.secrets["HF_TOKEN"])
    # relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
    #                                         filename=(best_ckpt + "_relation_weights.pt"),
    #                                         token=st.secrets["HF_TOKEN"])
    # edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
    #                                   filename=(best_ckpt + "_edge_types.pt"),
    #                                   token=st.secrets["HF_TOKEN"])

    # Get split name
    # NOTE(review): `split` is read but never used below — the model lookup
    # is hard-coded to the 'all' (all-diseases) model. Confirm whether the
    # selected split's model should be used instead.
    split = st.session_state.split
    avail_models = st.session_state.avail_models

    # Get model name from avail_models (latest 'all' model)
    embed_name = avail_models[avail_models['test'] == 'all']['file'].values[0]
    relation_weights_name = embed_name.replace('_embeddings.pt', '_relation_weights.pt')
    edge_types_name = embed_name.replace('_embeddings.pt', '_edge_types.pt')

    # Convert filenames to full paths
    embed_path = project_config.MODEL_DIR / 'embeddings' / embed_name
    relation_weights_path = project_config.MODEL_DIR / 'embeddings' / relation_weights_name
    edge_types_path = project_config.MODEL_DIR / 'embeddings' / edge_types_name

    return embed_path, relation_weights_path, edge_types_path
70
+
71
@st.cache_data(show_spinner = 'Loading AI model...')
def load_embeddings(embed_path, relation_weights_path, edge_types_path):
    """Deserialize the node embeddings, relation weights, and edge types.

    Each argument is a path to a torch-serialized artifact; the three
    loaded objects are returned as a tuple in the same order.
    """
    return (
        torch.load(embed_path),
        torch.load(relation_weights_path),
        torch.load(edge_types_path),
    )
81
+
82
# Load knowledge graph and embeddings (paths resolved first, then loaded;
# both steps are cached via st.cache_data)
kg_nodes = load_kg()
embed_path, relation_weights_path, edge_types_path = get_embeddings()
embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
86
+
87
+ # # Print source node type
88
+ # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
89
+
90
+ # # Print source node
91
+ # st.write(f"Source Node: {st.session_state.query['source_node']}")
92
+
93
+ # # Print relation
94
+ # st.write(f"Edge Type: {st.session_state.query['relation']}")
95
+
96
+ # # Print target node type
97
+ # st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
98
+
99
+ # Compute predictions
100
+ with st.spinner('Computing predictions...'):
101
+
102
+ source_node_type = st.session_state.query['source_node_type']
103
+ source_node = st.session_state.query['source_node']
104
+ relation = st.session_state.query['relation']
105
+ target_node_type = st.session_state.query['target_node_type']
106
+
107
+ # Get source node index
108
+ src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
109
+
110
+ # Get relation index
111
+ edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
112
+
113
+ # Get target nodes indices
114
+ target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
115
+ dst_indices = target_nodes.node_index.values
116
+ src_indices = np.repeat(src_index, len(dst_indices))
117
+
118
+ # Retrieve cached embeddings and apply activation function
119
+ src_embeddings = embeddings[src_indices]
120
+ dst_embeddings = embeddings[dst_indices]
121
+ src_embeddings = F.leaky_relu(src_embeddings)
122
+ dst_embeddings = F.leaky_relu(dst_embeddings)
123
+
124
+ # Get relation weights
125
+ rel_weights = relation_weights[edge_type_index]
126
+
127
+ # Compute weighted dot product
128
+ scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
129
+ scores = torch.sigmoid(scores)
130
+
131
+ # Add scores to dataframe
132
+ target_nodes['score'] = scores.detach().numpy()
133
+ target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
134
+ target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
135
+
136
+ # Rename columns
137
+ display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
138
+ display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
139
+
140
+ # Define dictionary mapping node types to database URLs
141
+ map_dbs = {
142
+ 'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
143
+ 'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
144
+ 'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
145
+ 'disease': lambda x: x, # MONDO
146
+ # pad with 0s to 7 digits
147
+ 'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
148
+ 'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
149
+ 'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
150
+ 'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
151
+ 'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
152
+ 'anatomy': lambda x: x,
153
+ }
154
+
155
+ # Get name of database
156
+ display_database = display_data['Database'].values[0]
157
+
158
+ # Add URLs to database column
159
+ display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
160
+
161
+ # Check if validation data exists
162
+ if 'validation' in st.session_state:
163
+
164
+ # Checkbox to allow reverse edges
165
+ show_val = st.checkbox("Show ground truth validation?", value = False)
166
+
167
+ if show_val:
168
+
169
+ # Get validation data
170
+ val_results = st.session_state.validation.copy()
171
+
172
+ # Merge with predictions
173
+ val_display_data = pd.merge(display_data, val_results, left_on = 'ID', right_on = 'y_id', how='left')
174
+ val_display_data = val_display_data.fillna(0).drop(columns='y_id')
175
+
176
+ # Get new columns
177
+ val_relations = val_display_data.columns.difference(display_data.columns).tolist()
178
+
179
+ # Replace 0 with blank and 1 with check emoji in new columns
180
+ for col in val_relations:
181
+ val_display_data[col] = val_display_data[col].replace({0: '', 1: '✅'})
182
+
183
+ # Define a function to apply styles
184
+ def style_val(val):
185
+ if val == '✅':
186
+ return 'background-color: #C2EABD;' # text-align: center;
187
+ return 'background-color: #F5F5F5;' # text-align: center;
188
+
189
+ else:
190
+ show_val = False
191
+
192
+
193
+ # NODE SEARCH
194
+
195
+ # Use multiselect to search for specific nodes
196
+ selected_nodes = st.multiselect(f"Search for specific {target_node_type.replace('_', ' ')} nodes to determine their ranking.",
197
+ display_data.Name, placeholder = "Type to search...")
198
+
199
+ # Filter nodes
200
+ if len(selected_nodes) > 0:
201
+
202
+ if show_val:
203
+ # selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
204
+ selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].copy()
205
+ selected_display_data = selected_display_data.reset_index(drop=True).style.map(style_val, subset=val_relations)
206
+ else:
207
+ selected_display_data = display_data[display_data.Name.isin(selected_nodes)].copy()
208
+ selected_display_data = selected_display_data.reset_index(drop=True)
209
+
210
+ st.markdown(f"Out of {target_nodes.shape[0]} {target_node_type} nodes, the selected nodes rank as follows:")
211
+ selected_display_data_with_rank = selected_display_data.copy()
212
+ selected_display_data_with_rank['Rank'] = selected_display_data_with_rank['Rank'].apply(lambda x: f"{x} (top {(100*x/target_nodes.shape[0]):.2f}% of predictions)")
213
+
214
+ # Show filtered nodes
215
+ if target_node_type not in ['disease', 'anatomy']:
216
+ st.dataframe(selected_display_data_with_rank, use_container_width = True, hide_index = True,
217
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
218
+ help = "Click to visit external database.",
219
+ display_text = display_database)})
220
+ else:
221
+ st.dataframe(selected_display_data_with_rank, use_container_width = True)
222
+
223
+ # Show plot
224
+ st.markdown(f"In the plot below, the dashed lines represent the rank of the selected {target_node_type} nodes across all predictions for {source_node}.")
225
+
226
+ # Checkbox to show text labels
227
+ show_labels = st.checkbox("Show Text Labels?", value = False)
228
+
229
+ # Plot rank vs. score using matplotlib
230
+ fig, ax = plt.subplots(figsize = (10, 6))
231
+ ax.plot(display_data['Rank'], display_data['Score'], color = 'black', linewidth = 1.5, zorder = 2)
232
+ ax.set_xlabel('Rank', fontsize = 12)
233
+ ax.set_ylabel('Score', fontsize = 12)
234
+ ax.set_xlim(1, display_data['Rank'].max())
235
+
236
+ # Get color palette
237
+ # palette = plt.cm.get_cmap('tab10', len(selected_display_data))
238
+ palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
239
+
240
+ # Add vertical line for selected nodes
241
+ for i, node in selected_display_data.iterrows():
242
+ ax.scatter(node['Rank'], node['Score'], color = palette[i], zorder=3)
243
+ ax.axvline(node['Rank'], color = palette[i], linestyle = '--', linewidth = 1.5, label = node['Name'], zorder=3)
244
+ if show_labels:
245
+ ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = palette[i], zorder=3)
246
+
247
+ # Add legend
248
+ ax.legend(loc = 'upper right', fontsize = 10)
249
+ ax.grid(alpha = 0.2, zorder=0)
250
+
251
+ st.pyplot(fig)
252
+
253
+
254
+ # FULL RESULTS
255
+
256
+ # Show top ranked nodes
257
+ st.subheader("Model Predictions", divider = "blue")
258
+ top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
259
+
260
+ # Show full results
261
+ # full_results = val_display_data.iloc[:top_k] if show_val else display_data.iloc[:top_k]
262
+ full_results = val_display_data.iloc[:top_k].style.map(style_val, subset=val_relations) if show_val else display_data.iloc[:top_k]
263
+
264
+ if target_node_type not in ['disease', 'anatomy']:
265
+ st.dataframe(full_results, use_container_width = True, hide_index = True,
266
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
267
+ help = "Click to visit external database.",
268
+ display_text = display_database)})
269
+ else:
270
+ st.dataframe(full_results, use_container_width = True, hide_index = True,)
271
+
272
+ # Save to session state
273
+ st.session_state.predictions = display_data
274
+ st.session_state.display_database = display_database
pages/split.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ # Path manipulation
9
+ import os
10
+ from pathlib import Path
11
+
12
+ # Plotting
13
+ import matplotlib.pyplot as plt
14
+ plt.rcParams['font.sans-serif'] = 'Arial'
15
+
16
+ # Custom and other imports
17
+ import project_config
18
+
19
+ # Redirect to app.py if not logged in, otherwise show the navigation menu
20
+ menu_with_redirect()
21
+
22
+ # Back button with emoji
23
+ if st.button("◀️ Back"):
24
+ st.switch_page("pages/input.py")
25
+
26
+ # Get metadata from session state
27
+ split = st.session_state.split
28
+ splits_df = st.session_state.splits_df
29
+
30
+ with st.spinner('Loading disease splits...'):
31
+
32
+ # Read disease splits
33
+ disease_split_nodes = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'disease_splits.csv', dtype = {'disease_split_index': str})
34
+
35
+ # If split is all
36
+ if split == 'all':
37
+ disease_split_nodes = disease_split_nodes[['node_index', 'node_name', 'embedding_score', 'levenshtein_score', 'neighborhood_score', 'method', 'disease_split']]
38
+ disease_split_nodes = disease_split_nodes.rename(columns = {'node_index': 'Node ID', 'node_name': 'Disease', 'embedding_score': 'Embedding Score', 'levenshtein_score': 'Levenshtein Score', 'neighborhood_score': 'Neighborhood Score', 'method': 'Method', 'disease_split': 'Disease Split'})
39
+
40
+ else:
41
+ disease_split_nodes = disease_split_nodes[disease_split_nodes['disease_split_index'] == split]
42
+ disease_split_nodes = disease_split_nodes[['node_index', 'node_name', 'embedding_score', 'levenshtein_score', 'neighborhood_score', 'method']]
43
+ disease_split_nodes = disease_split_nodes.rename(columns = {'node_index': 'Node ID', 'node_name': 'Disease', 'embedding_score': 'Embedding Score', 'levenshtein_score': 'Levenshtein Score', 'neighborhood_score': 'Neighborhood Score', 'method': 'Method'})
44
+
45
+ # Read disease split edges
46
+ disease_split_edges = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'split_edges' / f'{split}.csv')
47
+
48
+ # Subset and rename columns
49
+ disease_split_edges = disease_split_edges[['relation', 'x_index', 'x_type', 'x_name', 'y_index', 'y_type', 'y_name']]
50
+ disease_split_edges['relation'] = disease_split_edges['relation'].str.replace('_', ' ').str.title()
51
+ disease_split_edges = disease_split_edges.rename(columns = {'relation': 'Relation', 'x_index': 'Source ID', 'x_type': 'Source Type', 'x_name': 'Source Name', 'y_index': 'Target ID', 'y_type': 'Target Type', 'y_name': 'Target Name'})
52
+
53
+ st.subheader("Nodes in Disease Split", divider = "blue")
54
+ st.markdown(f"**Disease Split:** {splits_df[splits_df['node_index'] == split]['node_name'].values[0]}")
55
+ st.markdown(f"**Number of Nodes:** {disease_split_nodes.shape[0]}")
56
+
57
+ # Show as dataframe
58
+ st.dataframe(disease_split_nodes, use_container_width = True, hide_index = True)
59
+
60
+ st.markdown("Below, we show the number of nodes by method of inclusion in the disease split.")
61
+
62
+ # Plotting the bar plot
63
+ method_counts = disease_split_nodes['Method'].value_counts().reset_index()
64
+ method_counts.columns = ['Method', 'Count']
65
+ method_counts['Method Length'] = method_counts['Method'].apply(len)
66
+ method_counts = method_counts.sort_values('Method Length')
67
+
68
+ # Plotting the bar plot
69
+ plt.figure(figsize=(10, 6))
70
+ bars = plt.bar(method_counts['Method'], method_counts['Count'], color='#B8D4F7', edgecolor='black')
71
+ plt.xlabel('Method', fontsize=16, fontweight='bold')
72
+ plt.ylabel('Count', fontsize=16, fontweight='bold')
73
+ plt.xticks(rotation=45, ha='right', fontsize=12)
74
+ plt.tight_layout()
75
+
76
+ # Adding labels on top of each bar
77
+ for bar in bars:
78
+ yval = bar.get_height()
79
+ plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', fontsize=12)
80
+ plt.ylim(0, max(method_counts['Count'])*1.1)
81
+
82
+ # Show plot
83
+ st.pyplot(plt)
84
+
85
+
86
+ st.subheader("Edges in Disease Split", divider = "green")
87
+ st.markdown(f"**Number of Edges:** {disease_split_edges.shape[0]}")
88
+
89
+ # Show as dataframe
90
+ st.dataframe(disease_split_edges, use_container_width = True, hide_index = True)
91
+
92
+ if disease_split_edges.shape[0] > 0:
93
+
94
+ # Make bar plot of number of edges by relation
95
+ st.markdown("Below, we show the number of edges by relation in the disease split.")
96
+ relation_counts = disease_split_edges['Relation'].value_counts().reset_index()
97
+ relation_counts.columns = ['Relation', 'Count']
98
+
99
+ # Plotting the bar plot
100
+ plt.figure(figsize=(10, 6))
101
+ bars = plt.bar(relation_counts['Relation'], relation_counts['Count'], color='#C1ECC5', edgecolor='black')
102
+ plt.xlabel('Relation', fontsize=16, fontweight='bold')
103
+ plt.ylabel('Count', fontsize=16, fontweight='bold')
104
+ plt.xticks(rotation=45, ha='right', fontsize=12)
105
+ plt.tight_layout()
106
+
107
+ # Adding labels on top of each bar
108
+ for bar in bars:
109
+ yval = bar.get_height()
110
+ plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', fontsize=12)
111
+ plt.ylim(0, max(relation_counts['Count'])*1.1)
112
+
113
+ # Show plot
114
+ st.pyplot(plt)
pages/validate.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ # Path manipulation
9
+ from pathlib import Path
10
+
11
+ # Plotting
12
+ import matplotlib.pyplot as plt
13
+ plt.rcParams['font.sans-serif'] = 'Arial'
14
+ import matplotlib.colors as mcolors
15
+
16
+ # Custom and other imports
17
+ import project_config
18
+ from utils import load_kg, load_kg_edges
19
+
20
+ # Redirect to app.py if not logged in, otherwise show the navigation menu
21
+ menu_with_redirect()
22
+
23
+ # Header
24
+ st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
25
+
26
+ # Main content
27
+ # st.markdown(f"Hello, {st.session_state.name}!")
28
+
29
+ st.subheader("Validate Predictions", divider = "green")
30
+
31
+ # Print current query
32
+ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
33
+
34
+ # Coming soon
35
+ # st.write("Coming soon...")
36
+
37
+ source_node_type = st.session_state.query['source_node_type']
38
+ source_node = st.session_state.query['source_node']
39
+ relation = st.session_state.query['relation']
40
+ target_node_type = st.session_state.query['target_node_type']
41
+ predictions = st.session_state.predictions
42
+
43
+ kg_nodes = load_kg()
44
+ kg_edges = load_kg_edges()
45
+
46
+ # Convert tuple to hex
47
+ def rgba_to_hex(rgba):
48
+ return mcolors.to_hex(rgba[:3])
49
+
50
+ with st.spinner('Searching known relationships...'):
51
+
52
+ # Subset existing edges
53
+ edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
54
+ edge_subset = edge_subset[edge_subset.y_type == target_node_type]
55
+
56
+ # Merge edge subset with predictions
57
+ edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
58
+ edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
59
+ edges_in_kg = edges_in_kg.drop(columns = 'y_id')
60
+
61
+ # Rename relation to ground-truth
62
+ edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
63
+ edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
64
+
65
+ # If there exist edges in KG
66
+ if len(edges_in_kg) > 0:
67
+
68
+ with st.spinner('Saving validation results...'):
69
+
70
+ # Cast long to wide
71
+ val_results = edge_subset[['relation', 'y_id']].pivot_table(index='y_id', columns='relation', aggfunc='size', fill_value=0)
72
+ val_results = (val_results > 0).astype(int).reset_index()
73
+ val_results.columns = [val_results.columns[0]] + [x.replace('_', ' ').title() for x in val_results.columns[1:]]
74
+
75
+ # Save validation results to session state
76
+ st.session_state.validation = val_results
77
+
78
+ with st.spinner('Plotting known relationships...'):
79
+
80
+ # Define a color map for different relations
81
+ color_map = plt.get_cmap('tab10')
82
+
83
+ # Group by relation and create separate plots
84
+ relations = edges_in_kg['Known Relation'].unique()
85
+ for idx, relation in enumerate(relations):
86
+
87
+ relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
88
+
89
+ # Get a color from the color map
90
+ color = color_map(idx % color_map.N)
91
+
92
+ fig, ax = plt.subplots(figsize=(10, 3))
93
+ ax.plot(predictions['Rank'], predictions['Score'])
94
+ ax.set_xlabel('Rank', fontsize=12)
95
+ ax.set_ylabel('Score', fontsize=12)
96
+ ax.set_xlim(1, predictions['Rank'].max())
97
+
98
+ for i, node in relation_data.iterrows():
99
+ ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])
100
+ # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
101
+
102
+ # ax.set_title(f'{relation.replace("_", "-")}')
103
+ # ax.legend()
104
+ color_hex = rgba_to_hex(color)
105
+
106
+ # Write header in color of relation
107
+ st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h2>", unsafe_allow_html=True)
108
+
109
+ # Show plot
110
+ st.pyplot(fig)
111
+
112
+ # Drop known relation column
113
+ relation_data = relation_data.drop(columns = 'Known Relation')
114
+ if target_node_type not in ['disease', 'anatomy']:
115
+ st.dataframe(relation_data, use_container_width=True,
116
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
117
+ help = "Click to visit external database.",
118
+ display_text = st.session_state.display_database)})
119
+ else:
120
+ st.dataframe(relation_data, use_container_width=True)
121
+
122
+ else:
123
+
124
+ st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")
project_config.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ PROJECT CONFIGURATION FILE
3
+ This file contains the configuration variables for the project. The variables are used
4
+ in the other scripts to define the paths to the data and results directories. The variables
5
+ are also used to set the random seed for reproducibility.
6
+ '''
7
+
8
+ # Import libraries
9
+ from pathlib import Path
10
+ import socket
11
+ import getpass
12
+
13
+ def check_internet_connection():
14
+ try:
15
+ # Connect to one of the DNS servers
16
+ socket.create_connection(("8.8.8.8", 53), timeout=5)
17
+ return True
18
+ except OSError:
19
+ return False
20
+
21
+ def check_local_machine():
22
+ hostname = socket.gethostname()
23
+ username = getpass.getuser()
24
+
25
+ return hostname, username
26
+
27
+ # Define global variable indicating whether on VDI or not
28
+ VDI = not check_internet_connection()
29
+ print(f"VDI: {VDI}")
30
+
31
+ # Define global variable to check if running locally
32
+ hostname, username = check_local_machine()
33
+ LOCAL = True if username == 'an583' else False
34
+
35
+ # Define HF repo variable
36
+ HF_REPO = 'ayushnoori/clinical-drug-repurposing'
37
+
38
+ # Define project configuration variables
39
+ PROJECT_DIR = Path(__file__).resolve().parent
40
+ DATA_DIR = PROJECT_DIR / 'data'
41
+ AUTH_DIR = PROJECT_DIR / 'auth'
42
+ MODEL_DIR = PROJECT_DIR / 'models'
43
+ MEDIA_DIR = PROJECT_DIR / 'media'
44
+ SEED = 42
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ scikit-learn
4
+ matplotlib
5
+ seaborn
6
+ pathlib
7
+ torch
8
+ altair<5
9
+ gspread
10
+ oauth2client
11
+ huggingface_hub
sync_data.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Set permissions with chmod +x sync_data.sh
4
+ # Run with ./sync_data.sh
5
+
6
+ # Ask the user for the environment
7
+ echo "Would you like to sync results from O2 or Kempner?"
8
+ echo "1) O2"
9
+ echo "2) Kempner"
10
+ read -p "Enter your choice (1 or 2): " env_choice
11
+
12
+ # Ask the user which results folder to sync
13
+ echo "Which results folder do you want to sync?"
14
+ echo "1) data: disease splits"
15
+ echo "2) disease split models: checkpoints"
16
+ echo "3) disease split models: embeddings"
17
+ read -p "Enter your choice (1, 2, or 3): " folder_choice
18
+
19
+ # Map user input to folder names
20
+ case $folder_choice in
21
+ 1) SRC_FOLDER="Data/DrugKG/2_harmonize_KG/disease_splits";;
22
+ 2) SRC_FOLDER="Results/GALAXY/disease_splits/checkpoints";;
23
+ 3) SRC_FOLDER="Results/GALAXY/disease_splits/embeddings";;
24
+ *) echo "Invalid folder choice. Please enter 1, 2, or 3."; exit 1;;
25
+ esac
26
+
27
+ case $env_choice in
28
+ 1) SRC_DIR="[email protected]:/n/data1/hms/dbmi/zitnik/lab/users/an252/NeuroKG/neuroKG/$SRC_FOLDER";;
29
+ 2) SRC_DIR="[email protected]:/n/holylabs/LABS/mzitnik_lab/Users/anoori/neuroKG/$SRC_FOLDER";;
30
+ *) echo "Invalid source server choice. Please enter 1 or 2."; exit 1;;
31
+ esac
32
+
33
+ # Map user input to destination folder names
34
+ case $folder_choice in
35
+ 1) DST_DIR="data/disease_splits";;
36
+ 2) DST_DIR="models/checkpoints";; # Don't need checkpoints for this application
37
+ 3) DST_DIR="models/embeddings";;
38
+ *) echo "Invalid folder choice. Please enter 1, 2, or 3."; exit 1;;
39
+ esac
40
+
41
+ # Sync source and destination folders with specific file types for checkpoints or embeddings
42
+ # Note, local files not present in the source will be deleted
43
+ if [[ $folder_choice -eq 2 || $folder_choice -eq 3 ]]; then
44
+ echo "Syncing only .ckpt or .pt files from $SRC_DIR to $DST_DIR..."
45
+ rsync -avz -e ssh --include="*.ckpt" --include="*.pt" --exclude="*" --delete $SRC_DIR/ $DST_DIR
46
+ else
47
+ echo "Syncing $SRC_DIR to $DST_DIR..."
48
+ rsync -avz -e ssh --delete $SRC_DIR/ $DST_DIR
49
+ fi
50
+
51
+ echo "Synchronization complete."
utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import project_config
4
+ import base64
5
+
6
+
7
+ @st.cache_data(show_spinner = 'Loading knowledge graph nodes...')
8
+ def load_kg():
9
+ # with st.spinner('Loading knowledge graph...'):
10
+ kg_nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
11
+ return kg_nodes
12
+
13
+
14
+ @st.cache_data(show_spinner = 'Loading knowledge graph edges...')
15
+ def load_kg_edges():
16
+ # with st.spinner('Loading knowledge graph...'):
17
+ kg_edges = pd.read_csv(project_config.DATA_DIR / 'kg_edges.csv', dtype = {'edge_index': int, 'x_index': int, 'y_index': int}, low_memory = False)
18
+ return kg_edges
19
+
20
+
21
+ def capitalize_after_slash(s):
22
+ # Split the string by slashes first
23
+ parts = s.split('/')
24
+ # Capitalize each part separately
25
+ capitalized_parts = [part.title() for part in parts]
26
+ # Rejoin the parts with slashes
27
+ capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
28
+ return capitalized_string