Commit · 48e3a32
1 Parent(s): d4ca2d2

Update input

- pages/input.py +26 -35
- pages/predict.py +16 -0
- pages/split.py +2 -2
- pages/validate.py +8 -3
- project_config.py +1 -0
pages/input.py
CHANGED
@@ -31,42 +31,32 @@ with st.spinner('Loading disease splits...'):
         # Load from Kempner using sync_data.sh
         disease_splits = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'disease_splits.csv',
                                      dtype = {'node_index': str, 'disease_split_index': str})
-
+
+        splits_df = pd.read_csv(project_config.DATA_DIR / 'disease_splits' / 'disease_splits_summary.csv',
+                                dtype = {'node_index': str, 'disease_split_index': str})
+
     else:
 
         # Read disease splits from HF
         disease_splits = hf_hub_download(repo_id=project_config.HF_REPO,
-                                         filename='
+                                         filename='data/disease_splits/disease_splits.csv',
                                          token=st.secrets["HF_TOKEN"], repo_type="dataset")
-
-
-
-
-
-
-
-
-
-    #
-
-
-
-
-    # Add
-    splits_df['
-    splits_df = pd.concat([pd.DataFrame([['all', 'all diseases', None, disease_splits.shape[0]]], columns=splits_df.columns), splits_df], ignore_index=True)
-
-    # For each disease split, count number of edges (number of rows in CSV file in disease_splits directory)
-    # Do not read file in
-    edge_counts = []
-    for index, row in splits_df.iterrows():
-        # Count lines
-        file_name = project_config.DATA_DIR / 'disease_splits' / 'split_edges' / f'{row["node_index"]}.csv'
-        edge_count = int(subprocess.check_output(['wc', '-l', file_name]).split()[0]) - 1
-        edge_counts.append(edge_count)
-
-    # Add edge counts to splits_df
-    splits_df['edge_count'] = edge_counts
+
+        disease_splits = hf_hub_download(repo_id=project_config.HF_REPO,
+                                         filename='data/disease_splits/disease_splits_summary.csv',
+                                         token=st.secrets["HF_TOKEN"], repo_type="dataset")
+
+    # # For each disease split, count number of edges (number of rows in CSV file in disease_splits directory)
+    # # Do not read file in
+    # edge_counts = []
+    # for index, row in splits_df.iterrows():
+    #     # Count lines
+    #     file_name = project_config.DATA_DIR / 'disease_splits' / 'split_edges' / f'{row["node_index"]}.csv'
+    #     edge_count = int(subprocess.check_output(['wc', '-l', file_name]).split()[0]) - 1
+    #     edge_counts.append(edge_count)
+
+    # # Add edge counts to splits_df
+    # splits_df['edge_count'] = edge_counts
 
     # Get list of available modles
     model_files = os.listdir(project_config.MODEL_DIR / 'embeddings')
@@ -92,7 +82,6 @@ with st.spinner('Loading disease splits...'):
     # Get available models, only keep latest version per split
     avail_models = pd.DataFrame([get_model_metadata(f) for f in model_files])
     avail_models = avail_models.sort_values('date', ascending=False).drop_duplicates('test').reset_index(drop=True)
-    # avail_models.loc[avail_models['test'] == 'all', 'test'] = 'all diseases'
 
     # Add column to indicate if model is available
    splits_df['available'] = splits_df['node_index'].isin(avail_models['test'])
@@ -104,9 +93,12 @@ with st.spinner('Loading disease splits...'):
     ####################################################################################################
 
     # Select disease split from splits with available models
+    splits_display = splits_df[['node_index', 'node_name', 'node_count', 'edge_count', 'available']].copy()
+    splits_display['node_name'] = splits_display['node_name'].str.replace(' \\(disease\\)', '', regex=True)
+
     # Make dictionary with node_index: node_name, where name is value shown but index is used for query
-    # split_options =
-    split_options =
+    # split_options = splits_display[splits_display['available']].copy()
+    split_options = splits_display.copy()
     split_options = split_options.set_index('node_index')['node_name'].to_dict()
 
     # Check if split is in session state
@@ -119,7 +111,6 @@ with st.spinner('Loading disease splits...'):
                                  index = split_index)
 
     # Show all splits dataframe
-    splits_display = splits_df[['node_index', 'node_name', 'node_count', 'edge_count', 'available']].copy()
     splits_display = splits_display.rename(columns = {'node_index': 'Split ID', 'node_name': 'Disease', 'node_count': 'Node Count', 'edge_count': 'Edge Count', 'available': 'Model Available'})
     st.dataframe(splits_display, use_container_width = True, hide_index = True)
 
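Note: the hunk above replaces the on-the-fly edge counting with a precomputed disease_splits_summary.csv, loaded either from the synced local copy or from the HF dataset repo. A minimal sketch of that load pattern, assuming the project_config constants shown in the diff (load_splits_summary is a hypothetical helper name, not code from this commit):

# Minimal sketch (not part of the commit): hf_hub_download returns a cached local
# file path, so the same pd.read_csv call works for both branches.
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download

import project_config

def load_splits_summary() -> pd.DataFrame:
    if project_config.LOCAL:
        # Local copy synced from the cluster
        path = project_config.DATA_DIR / 'disease_splits' / 'disease_splits_summary.csv'
    else:
        # Download (and cache) the file from the HF dataset repo
        path = hf_hub_download(repo_id=project_config.HF_REPO,
                               filename='data/disease_splits/disease_splits_summary.csv',
                               token=st.secrets["HF_TOKEN"], repo_type="dataset")
    return pd.read_csv(path, dtype={'node_index': str, 'disease_split_index': str})
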
pages/predict.py
CHANGED
@@ -34,6 +34,13 @@ st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'
 # Print current query
 st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
 
+# Print split
+split = st.session_state.split
+splits_df = st.session_state.splits_df
+num_nodes = splits_df[splits_df['node_index'] == split]['node_count'].values[0]
+num_edges = splits_df[splits_df['node_index'] == split]['edge_count'].values[0]
+st.markdown(f"**Disease Split:** {st.session_state.split} ({num_nodes} nodes, {num_edges} edges)")
+
 @st.cache_data(show_spinner = 'Downloading AI model...')
 def get_embeddings():
 
@@ -272,3 +279,12 @@ with st.spinner('Computing predictions...'):
     # Save to session state
     st.session_state.predictions = display_data
    st.session_state.display_database = display_database
+
+    # If validation not in session state
+    if 'validation' not in st.session_state:
+
+        col1, col2, col3 = st.columns(3)
+
+        with col2:
+            if st.button("Validate Predictions"):
+                st.switch_page("pages/validate.py")
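Both predict.py additions assume the input page has already stored the selected split and the summary dataframe in st.session_state. A condensed sketch of the same pattern (the single-row lookup via .iloc[0] is a minor variation on the .values[0] indexing in the diff):

# Condensed sketch of the predict.py additions; assumes st.session_state.split and
# st.session_state.splits_df were set on the input page as shown above.
import streamlit as st

split = st.session_state.split
row = st.session_state.splits_df.query("node_index == @split").iloc[0]
st.markdown(f"**Disease Split:** {split} ({row['node_count']} nodes, {row['edge_count']} edges)")

# Offer navigation to the validation page once predictions exist
if 'validation' not in st.session_state:
    col1, col2, col3 = st.columns(3)
    with col2:
        if st.button("Validate Predictions"):
            st.switch_page("pages/validate.py")
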
pages/split.py
CHANGED
@@ -76,7 +76,7 @@ plt.tight_layout()
 # Adding labels on top of each bar
 for bar in bars:
     yval = bar.get_height()
-    plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', fontsize=12)
+    plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', ha='center', fontsize=12)
 plt.ylim(0, max(method_counts['Count'])*1.1)
 
 # Show plot
@@ -107,7 +107,7 @@ if disease_split_edges.shape[0] > 0:
     # Adding labels on top of each bar
     for bar in bars:
         yval = bar.get_height()
-        plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', fontsize=12)
+        plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', ha='center', fontsize=12)
     plt.ylim(0, max(relation_counts['Count'])*1.1)
 
     # Show plot
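The only change in split.py is the added ha='center', which centers each count label over its bar instead of anchoring it at the bar's left edge. A self-contained matplotlib sketch with made-up counts:

# Standalone illustration of the ha='center' fix (hypothetical counts, not repo data).
import matplotlib.pyplot as plt

counts = {'Method A': 12, 'Method B': 7, 'Method C': 3}
bars = plt.bar(list(counts.keys()), list(counts.values()))

# Adding labels on top of each bar, horizontally centered over each bar
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', ha='center', fontsize=12)
plt.ylim(0, max(counts.values())*1.1)
plt.show()
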
pages/validate.py
CHANGED
@@ -31,9 +31,14 @@ st.subheader("Validate Predictions", divider = "green")
 # Print current query
 st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
 
-#
-
-
+# Print split
+split = st.session_state.split
+splits_df = st.session_state.splits_df
+num_nodes = splits_df[splits_df['node_index'] == split]['node_count'].values[0]
+num_edges = splits_df[splits_df['node_index'] == split]['edge_count'].values[0]
+st.markdown(f"**Disease Split:** {st.session_state.split} ({num_nodes} nodes, {num_edges} edges)")
+
+# Get query and predictions
 source_node_type = st.session_state.query['source_node_type']
 source_node = st.session_state.query['source_node']
 relation = st.session_state.query['relation']
project_config.py
CHANGED
@@ -31,6 +31,7 @@ print(f"VDI: {VDI}")
 # Define global variable to check if running locally
 hostname, username = check_local_machine()
 LOCAL = True if username == 'an583' else False
+print(f"LOCAL: {LOCAL}")
 
 # Define HF repo variable
 HF_REPO = 'ayushnoori/clinical-drug-repurposing'