ayushnoori committed on
Commit
b7df334
·
1 Parent(s): 11b8c2d

Significant update with multi-relation comparison across app

Browse files
Files changed (4) hide show
  1. pages/input.py +14 -10
  2. pages/predict.py +112 -89
  3. pages/validate.py +148 -122
  4. utils.py +31 -1
pages/input.py CHANGED
@@ -173,21 +173,21 @@ if "query" not in st.session_state:
173
  source_node_type_index = 0
174
  source_node_index = 0
175
  target_node_type_index = 0
176
- relation_index = 0
177
  filter_diseases_value = False
178
 
179
  if st.session_state.team == "Clalit":
180
  source_node_type_index = 2
181
  source_node_index = 0
182
  target_node_type_index = 3
183
- relation_index = 2
184
  filter_diseases_value = True
185
 
186
  else:
187
  source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
188
  source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
189
  target_node_type_index = st.session_state.query_options['target_node_type'].index(st.session_state.query['target_node_type'])
190
- relation_index = st.session_state.query_options['relation'].index(st.session_state.query['relation'])
191
  filter_diseases_value = st.session_state.query_options['filter_diseases']
192
 
193
  # Define error catching function
@@ -237,11 +237,11 @@ target_node_type = st.selectbox("Target Node Type", target_node_type_options,
237
  format_func = lambda x: x.replace("_", " "),
238
  index = catch_index_error(target_node_type_index, target_node_type_options))
239
 
240
- # Select relation
241
- relation_options = edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique()
242
- relation = st.selectbox("Edge Type", relation_options,
243
- format_func = lambda x: x.replace("_", "-"),
244
- index = catch_index_error(relation_index, relation_options))
245
 
246
  # Button to submit query
247
  if st.button("Submit Query"):
@@ -259,7 +259,7 @@ if st.button("Submit Query"):
259
  "source_node_type": source_node_type,
260
  "source_node": source_node,
261
  "target_node_type": target_node_type,
262
- "relation": relation
263
  }
264
 
265
  # Save query options to session state
@@ -267,7 +267,7 @@ if st.button("Submit Query"):
267
  "source_node_type": list(source_node_type_options),
268
  "source_node": list(source_node_options),
269
  "target_node_type": list(target_node_type_options),
270
- "relation": list(relation_options),
271
  "filter_diseases": filter_diseases
272
  }
273
 
@@ -275,6 +275,10 @@ if st.button("Submit Query"):
275
  if "validation" in st.session_state:
276
  del st.session_state.validation
277
 
 
 
 
 
278
  # # Write query to console
279
  # st.write("Current Query:")
280
  # st.write(st.session_state.query)
 
173
  source_node_type_index = 0
174
  source_node_index = 0
175
  target_node_type_index = 0
176
+ # relation_index = 0
177
  filter_diseases_value = False
178
 
179
  if st.session_state.team == "Clalit":
180
  source_node_type_index = 2
181
  source_node_index = 0
182
  target_node_type_index = 3
183
+ # relation_index = 2
184
  filter_diseases_value = True
185
 
186
  else:
187
  source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
188
  source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
189
  target_node_type_index = st.session_state.query_options['target_node_type'].index(st.session_state.query['target_node_type'])
190
+ # relation_index = st.session_state.query_options['relation'].index(st.session_state.query['relation'])
191
  filter_diseases_value = st.session_state.query_options['filter_diseases']
192
 
193
  # Define error catching function
 
237
  format_func = lambda x: x.replace("_", " "),
238
  index = catch_index_error(target_node_type_index, target_node_type_options))
239
 
240
+ # # Select relation
241
+ # relation_options = edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique()
242
+ # relation = st.selectbox("Edge Type", relation_options,
243
+ # format_func = lambda x: x.replace("_", "-"),
244
+ # index = catch_index_error(relation_index, relation_options))
245
 
246
  # Button to submit query
247
  if st.button("Submit Query"):
 
259
  "source_node_type": source_node_type,
260
  "source_node": source_node,
261
  "target_node_type": target_node_type,
262
+ # "relation": relation
263
  }
264
 
265
  # Save query options to session state
 
267
  "source_node_type": list(source_node_type_options),
268
  "source_node": list(source_node_options),
269
  "target_node_type": list(target_node_type_options),
270
+ # "relation": list(relation_options),
271
  "filter_diseases": filter_diseases
272
  }
273
 
 
275
  if "validation" in st.session_state:
276
  del st.session_state.validation
277
 
278
+ # Delete selected nodes from session state
279
+ if "selected_nodes" in st.session_state:
280
+ del st.session_state.selected_nodes
281
+
282
  # # Write query to console
283
  # st.write("Current Query:")
284
  # st.write(st.session_state.query)
pages/predict.py CHANGED
@@ -18,7 +18,7 @@ plt.rcParams['font.sans-serif'] = 'Arial'
18
 
19
  # Custom and other imports
20
  import project_config
21
- from utils import capitalize_after_slash, load_kg
22
 
23
  # Redirect to app.py if not logged in, otherwise show the navigation menu
24
  menu_with_redirect()
@@ -29,10 +29,9 @@ st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=
29
  # Main content
30
  # st.markdown(f"Hello, {st.session_state.name}!")
31
 
32
- st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
33
-
34
  # Print current query
35
- st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
 
36
 
37
  # Print split
38
  split = st.session_state.split
@@ -48,7 +47,7 @@ def get_embeddings():
48
  # best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
49
 
50
  # Get split name
51
- split = st.session_state.split
52
  avail_models = st.session_state.avail_models
53
 
54
  # Get model name from available models
@@ -79,6 +78,7 @@ def get_embeddings():
79
 
80
  return embed_path, relation_weights_path, edge_types_path
81
 
 
82
  @st.cache_data(show_spinner = 'Loading AI model...')
83
  def load_embeddings(embed_path, relation_weights_path, edge_types_path):
84
 
@@ -94,6 +94,7 @@ def load_embeddings(embed_path, relation_weights_path, edge_types_path):
94
  kg_nodes = load_kg()
95
  embed_path, relation_weights_path, edge_types_path = get_embeddings()
96
  embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
 
97
 
98
  # # Print source node type
99
  # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
@@ -107,67 +108,79 @@ embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_
107
  # # Print target node type
108
  # st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
109
 
110
- # Compute predictions
111
- with st.spinner('Computing predictions...'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- source_node_type = st.session_state.query['source_node_type']
114
- source_node = st.session_state.query['source_node']
115
- relation = st.session_state.query['relation']
116
- target_node_type = st.session_state.query['target_node_type']
117
 
118
- # Get source node index
119
- src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
120
 
121
- # Get relation index
122
- edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
123
 
124
- # Get target nodes indices
 
 
 
 
 
 
125
  target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
126
- dst_indices = target_nodes.node_index.values
127
- src_indices = np.repeat(src_index, len(dst_indices))
128
-
129
- # Retrieve cached embeddings and apply activation function
130
- src_embeddings = embeddings[src_indices]
131
- dst_embeddings = embeddings[dst_indices]
132
- src_embeddings = F.leaky_relu(src_embeddings)
133
- dst_embeddings = F.leaky_relu(dst_embeddings)
134
-
135
- # Get relation weights
136
- rel_weights = relation_weights[edge_type_index]
137
-
138
- # Compute weighted dot product
139
- scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
140
- scores = torch.sigmoid(scores)
141
-
142
- # Add scores to dataframe
143
- target_nodes['score'] = scores.detach().numpy()
144
- target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
145
- target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
146
-
147
- # Rename columns
148
- display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
149
- display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
150
-
151
- # Define dictionary mapping node types to database URLs
152
- map_dbs = {
153
- 'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
154
- 'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
155
- 'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
156
- 'disease': lambda x: x, # MONDO
157
- # pad with 0s to 7 digits
158
- 'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
159
- 'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
160
- 'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
161
- 'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
162
- 'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
163
- 'anatomy': lambda x: x,
164
- }
165
-
166
- # Get name of database
167
- display_database = display_data['Database'].values[0]
168
-
169
- # Add URLs to database column
170
- display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
171
 
172
  # Check if validation data exists
173
  if 'validation' in st.session_state:
@@ -203,9 +216,12 @@ with st.spinner('Computing predictions...'):
203
 
204
  # NODE SEARCH
205
 
 
 
206
  # Use multiselect to search for specific nodes
207
- selected_nodes = st.multiselect(f"Search for specific {target_node_type.replace('_', ' ')} nodes to determine their ranking.",
208
- display_data.Name, placeholder = "Type to search...")
 
209
 
210
  # Filter nodes
211
  if len(selected_nodes) > 0:
@@ -213,7 +229,7 @@ with st.spinner('Computing predictions...'):
213
  if show_val:
214
  # selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
215
  selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].copy()
216
- selected_display_data = selected_display_data.reset_index(drop=True).style.map(style_val, subset=val_relations)
217
  else:
218
  selected_display_data = display_data[display_data.Name.isin(selected_nodes)].copy()
219
  selected_display_data = selected_display_data.reset_index(drop=True)
@@ -222,12 +238,15 @@ with st.spinner('Computing predictions...'):
222
  selected_display_data_with_rank = selected_display_data.copy()
223
  selected_display_data_with_rank['Rank'] = selected_display_data_with_rank['Rank'].apply(lambda x: f"{x} (top {(100*x/target_nodes.shape[0]):.2f}% of predictions)")
224
 
 
 
 
225
  # Show filtered nodes
226
  if target_node_type not in ['disease', 'anatomy']:
227
  st.dataframe(selected_display_data_with_rank, use_container_width = True, hide_index = True,
228
  column_config={"Database": st.column_config.LinkColumn(width = "small",
229
- help = "Click to visit external database.",
230
- display_text = display_database)})
231
  else:
232
  st.dataframe(selected_display_data_with_rank, use_container_width = True)
233
 
@@ -260,30 +279,26 @@ with st.spinner('Computing predictions...'):
260
  ax.grid(alpha = 0.2, zorder=0)
261
 
262
  st.pyplot(fig)
263
-
264
-
265
  # FULL RESULTS
266
 
267
  # Show top ranked nodes
268
- st.subheader("Model Predictions", divider = "blue")
269
  top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
270
 
271
  # Show full results
272
  # full_results = val_display_data.iloc[:top_k] if show_val else display_data.iloc[:top_k]
273
  full_results = val_display_data.iloc[:top_k].style.map(style_val, subset=val_relations) if show_val else display_data.iloc[:top_k]
274
-
275
  if target_node_type not in ['disease', 'anatomy']:
276
  st.dataframe(full_results, use_container_width = True, hide_index = True,
277
  column_config={"Database": st.column_config.LinkColumn(width = "small",
278
- help = "Click to visit external database.",
279
- display_text = display_database)})
280
  else:
281
  st.dataframe(full_results, use_container_width = True, hide_index = True,)
282
 
283
- # Save to session state
284
- st.session_state.predictions = display_data
285
- st.session_state.display_database = display_database
286
-
287
  # If validation not in session state
288
  if 'validation' not in st.session_state:
289
 
@@ -293,10 +308,15 @@ with st.spinner('Computing predictions...'):
293
  if st.button("Validate Predictions"):
294
  st.switch_page("pages/validate.py")
295
 
 
 
 
 
296
 
297
- ####################################################################################################
298
 
299
- relation_options = st.session_state.query_options['relation']
 
300
 
301
  if len(relation_options) > 1:
302
 
@@ -316,11 +336,12 @@ with st.spinner('Computing predictions...'):
316
 
317
  with relation_1_col:
318
  relation_1 = st.selectbox("Select first relation:", relation_options,
319
- format_func = lambda x: x.replace("_", "-"), index = relation_1_index)
320
 
321
  with relation_2_col:
322
- relation_2 = st.selectbox("Select second relation:", relation_options,
323
- format_func = lambda x: x.replace("_", "-"), index = relation_2_index)
 
324
 
325
  # Get relation index
326
  rel_1_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation_1, target_node_type)][0]
@@ -355,18 +376,18 @@ with st.spinner('Computing predictions...'):
355
  target_nodes = target_nodes.sort_values(by = 'rel_1_score', ascending = False)
356
  target_nodes['rel_1_rank'] = np.arange(1, target_nodes.shape[0] + 1)
357
 
358
- # Rename relations
359
- relation_1 = relation_1.replace("_", " ").title()
360
- relation_2 = relation_2.replace("_", " ").title()
361
-
362
  # Compute correlation coefficient of scores
363
  corr = target_nodes['rel_1_score'].corr(target_nodes['rel_2_score'])
364
  spearman_corr = target_nodes['rel_1_score'].corr(target_nodes['rel_2_score'], method = 'spearman')
365
 
366
- st.markdown(f"The correlation coefficient between {relation_1} and {relation_2} scores is:")
367
  st.markdown(f"**Pearson's $r$:** {corr:.2f} (Score)")
368
  st.markdown(f"**Spearman's $\\rho$:** {spearman_corr:.2f} (Rank)")
369
 
 
 
 
 
370
  # Rename columns
371
  display_comp = target_nodes[['node_id', 'node_name', 'rel_1_rank', 'rel_2_rank', 'rel_1_score', 'rel_2_score', 'node_source']].copy()
372
  display_comp = display_comp.rename(columns = {
@@ -398,7 +419,7 @@ with st.spinner('Computing predictions...'):
398
  rel_2_min = target_nodes[rel_2_column].min()
399
  rel_1_max = target_nodes[rel_1_column].max()
400
  rel_2_max = target_nodes[rel_2_column].max()
401
- ax.plot([0, rel_1_max], [0, rel_2_max], color = 'red',
402
  linestyle = '--', zorder = 3) # label = 'Equal Rank',
403
  ax.set_xlim(rel_1_min, rel_1_max)
404
  ax.set_ylim(rel_2_min, rel_2_max)
@@ -448,7 +469,7 @@ with st.spinner('Computing predictions...'):
448
  st.dataframe(display_comp_styled, use_container_width = True, hide_index = True,
449
  column_config={"Database": st.column_config.LinkColumn(width = "small",
450
  help = "Click to visit external database.",
451
- display_text = display_database)})
452
 
453
  else:
454
 
@@ -456,4 +477,6 @@ with st.spinner('Computing predictions...'):
456
  st.dataframe(display_comp, use_container_width = True, hide_index = True,
457
  column_config={"Database": st.column_config.LinkColumn(width = "small",
458
  help = "Click to visit external database.",
459
- display_text = display_database)})
 
 
 
18
 
19
  # Custom and other imports
20
  import project_config
21
+ from utils import capitalize_after_slash, load_kg, map_dbs, map_db_names
22
 
23
  # Redirect to app.py if not logged in, otherwise show the navigation menu
24
  menu_with_redirect()
 
29
  # Main content
30
  # st.markdown(f"Hello, {st.session_state.name}!")
31
 
 
 
32
  # Print current query
33
+ # st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
34
+ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
35
 
36
  # Print split
37
  split = st.session_state.split
 
47
  # best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
48
 
49
  # Get split name
50
+ # split = st.session_state.split
51
  avail_models = st.session_state.avail_models
52
 
53
  # Get model name from available models
 
78
 
79
  return embed_path, relation_weights_path, edge_types_path
80
 
81
+
82
  @st.cache_data(show_spinner = 'Loading AI model...')
83
  def load_embeddings(embed_path, relation_weights_path, edge_types_path):
84
 
 
94
  kg_nodes = load_kg()
95
  embed_path, relation_weights_path, edge_types_path = get_embeddings()
96
  embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
97
+ edge_types_df = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')
98
 
99
  # # Print source node type
100
  # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
 
108
  # # Print target node type
109
  # st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
110
 
111
+ source_node_type = st.session_state.query['source_node_type']
112
+ source_node = st.session_state.query['source_node']
113
+ # relation = st.session_state.query['relation']
114
+ target_node_type = st.session_state.query['target_node_type']
115
+
116
+ # Get relation options
117
+ relation_options = edge_types_df[(edge_types_df.x_type == source_node_type) & (edge_types_df.y_type == target_node_type)].relation.unique()
118
+
119
+ # Add relation selector
120
+ relation = st.selectbox("Relation Type", relation_options, format_func = lambda x: x.replace("_", "-"))
121
+ display_dbs = {}
122
+
123
+ # Get source node index
124
+ src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
125
+
126
+
127
+ @st.experimental_fragment()
128
+ def compute_scores():
129
+
130
+ # Compute predictions
131
+ with st.spinner('Computing predictions...'):
132
+
133
+ # Get target nodes indices
134
+ target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
135
+ dst_indices = target_nodes.node_index.values
136
+ src_indices = np.repeat(src_index, len(dst_indices))
137
+
138
+ # Retrieve cached embeddings and apply activation function
139
+ src_embeddings = embeddings[src_indices]
140
+ dst_embeddings = embeddings[dst_indices]
141
+ src_embeddings = F.leaky_relu(src_embeddings)
142
+ dst_embeddings = F.leaky_relu(dst_embeddings)
143
+
144
+ for relation_i in relation_options:
145
+
146
+ # Get relation index
147
+ edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation_i, target_node_type)][0]
148
+
149
+ # Get relation weights
150
+ rel_weights = relation_weights[edge_type_index]
151
+
152
+ # Compute weighted dot product
153
+ scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
154
+ scores = torch.sigmoid(scores).detach().numpy()
155
+
156
+ # Add scores to dataframe
157
+ target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
158
+ target_nodes['score'] = scores
159
+ target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
160
+ target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
161
+
162
+ # Rename columns
163
+ display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
164
+ display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
165
 
166
+ # Add URLs to database column
167
+ display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
 
 
168
 
169
+ # Save to display databases
170
+ display_dbs[relation_i] = display_data
171
 
172
+ # Compute scores
173
+ compute_scores()
174
 
175
+ # Save to session state
176
+ st.session_state.predictions_rel = display_dbs
177
+
178
+ @st.experimental_fragment()
179
+ def visualize_scores():
180
+
181
+ # Get values
182
  target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
183
+ display_data = display_dbs[relation]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  # Check if validation data exists
186
  if 'validation' in st.session_state:
 
216
 
217
  # NODE SEARCH
218
 
219
+ st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
220
+
221
  # Use multiselect to search for specific nodes
222
+ selected_nodes = st.multiselect(f"Search for specific {target_node_type.replace('_', ' ')} nodes to determine their rankings.",
223
+ display_data.Name, placeholder = "Type to search...", key = 'selected_nodes',
224
+ default = st.session_state.selected_nodes if 'selected_nodes' in st.session_state else None)
225
 
226
  # Filter nodes
227
  if len(selected_nodes) > 0:
 
229
  if show_val:
230
  # selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
231
  selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].copy()
232
+ selected_display_data = selected_display_data.reset_index(drop=True)
233
  else:
234
  selected_display_data = display_data[display_data.Name.isin(selected_nodes)].copy()
235
  selected_display_data = selected_display_data.reset_index(drop=True)
 
238
  selected_display_data_with_rank = selected_display_data.copy()
239
  selected_display_data_with_rank['Rank'] = selected_display_data_with_rank['Rank'].apply(lambda x: f"{x} (top {(100*x/target_nodes.shape[0]):.2f}% of predictions)")
240
 
241
+ if show_val:
242
+ selected_display_data_with_rank = selected_display_data_with_rank.style.map(style_val, subset=val_relations)
243
+
244
  # Show filtered nodes
245
  if target_node_type not in ['disease', 'anatomy']:
246
  st.dataframe(selected_display_data_with_rank, use_container_width = True, hide_index = True,
247
  column_config={"Database": st.column_config.LinkColumn(width = "small",
248
+ help = "Click to visit external database.",
249
+ display_text = map_db_names[target_node_type])})
250
  else:
251
  st.dataframe(selected_display_data_with_rank, use_container_width = True)
252
 
 
279
  ax.grid(alpha = 0.2, zorder=0)
280
 
281
  st.pyplot(fig)
282
+
283
+
284
  # FULL RESULTS
285
 
286
  # Show top ranked nodes
287
+ st.subheader(f"{relation.replace('_', ' ').title()} Predictions", divider = "blue")
288
  top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
289
 
290
  # Show full results
291
  # full_results = val_display_data.iloc[:top_k] if show_val else display_data.iloc[:top_k]
292
  full_results = val_display_data.iloc[:top_k].style.map(style_val, subset=val_relations) if show_val else display_data.iloc[:top_k]
293
+
294
  if target_node_type not in ['disease', 'anatomy']:
295
  st.dataframe(full_results, use_container_width = True, hide_index = True,
296
  column_config={"Database": st.column_config.LinkColumn(width = "small",
297
+ help = "Click to visit external database.",
298
+ display_text = map_db_names[target_node_type])})
299
  else:
300
  st.dataframe(full_results, use_container_width = True, hide_index = True,)
301
 
 
 
 
 
302
  # If validation not in session state
303
  if 'validation' not in st.session_state:
304
 
 
308
  if st.button("Validate Predictions"):
309
  st.switch_page("pages/validate.py")
310
 
311
+ visualize_scores()
312
+
313
+
314
+ ####################################################################################################
315
 
316
+ # relation_options = st.session_state.query_options['relation']
317
 
318
+ @st.experimental_fragment()
319
+ def compare_scores():
320
 
321
  if len(relation_options) > 1:
322
 
 
336
 
337
  with relation_1_col:
338
  relation_1 = st.selectbox("Select first relation:", relation_options,
339
+ format_func = lambda x: x.replace("_", "-"), index = relation_1_index)
340
 
341
  with relation_2_col:
342
+ relation_2_options = [rel for rel in relation_options if rel != relation_1]
343
+ relation_2 = st.selectbox("Select second relation:", relation_2_options,
344
+ format_func = lambda x: x.replace("_", "-"), index = relation_2_index)
345
 
346
  # Get relation index
347
  rel_1_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation_1, target_node_type)][0]
 
376
  target_nodes = target_nodes.sort_values(by = 'rel_1_score', ascending = False)
377
  target_nodes['rel_1_rank'] = np.arange(1, target_nodes.shape[0] + 1)
378
 
 
 
 
 
379
  # Compute correlation coefficient of scores
380
  corr = target_nodes['rel_1_score'].corr(target_nodes['rel_2_score'])
381
  spearman_corr = target_nodes['rel_1_score'].corr(target_nodes['rel_2_score'], method = 'spearman')
382
 
383
+ st.markdown(f"The correlation coefficient between {relation_1.replace('_', ' ')} and {relation_2.replace('_', ' ')} scores is:")
384
  st.markdown(f"**Pearson's $r$:** {corr:.2f} (Score)")
385
  st.markdown(f"**Spearman's $\\rho$:** {spearman_corr:.2f} (Rank)")
386
 
387
+ # Rename relations
388
+ relation_1 = relation_1.replace("_", " ").title()
389
+ relation_2 = relation_2.replace("_", " ").title()
390
+
391
  # Rename columns
392
  display_comp = target_nodes[['node_id', 'node_name', 'rel_1_rank', 'rel_2_rank', 'rel_1_score', 'rel_2_score', 'node_source']].copy()
393
  display_comp = display_comp.rename(columns = {
 
419
  rel_2_min = target_nodes[rel_2_column].min()
420
  rel_1_max = target_nodes[rel_1_column].max()
421
  rel_2_max = target_nodes[rel_2_column].max()
422
+ ax.plot([0, rel_1_max], [0, rel_2_max], color = 'red', linewidth = 1.5,
423
  linestyle = '--', zorder = 3) # label = 'Equal Rank',
424
  ax.set_xlim(rel_1_min, rel_1_max)
425
  ax.set_ylim(rel_2_min, rel_2_max)
 
469
  st.dataframe(display_comp_styled, use_container_width = True, hide_index = True,
470
  column_config={"Database": st.column_config.LinkColumn(width = "small",
471
  help = "Click to visit external database.",
472
+ display_text = map_db_names[target_node_type])})
473
 
474
  else:
475
 
 
477
  st.dataframe(display_comp, use_container_width = True, hide_index = True,
478
  column_config={"Database": st.column_config.LinkColumn(width = "small",
479
  help = "Click to visit external database.",
480
+ display_text = map_db_names[target_node_type])})
481
+
482
+ compare_scores()
pages/validate.py CHANGED
@@ -14,11 +14,11 @@ plt.rcParams['font.sans-serif'] = 'Arial'
14
  import matplotlib.colors as mcolors
15
 
16
  # Import metrics
17
- from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score
18
 
19
  # Custom and other imports
20
  import project_config
21
- from utils import load_kg, load_kg_edges
22
 
23
  # Redirect to app.py if not logged in, otherwise show the navigation menu
24
  menu_with_redirect()
@@ -32,7 +32,8 @@ st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width
32
  st.subheader("Validate Predictions", divider = "green")
33
 
34
  # Print current query
35
- st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
 
36
 
37
  # Print split
38
  split = st.session_state.split
@@ -44,9 +45,14 @@ st.markdown(f"**Disease Split:** {st.session_state.split} ({num_nodes} nodes, {n
44
  # Get query and predictions
45
  source_node_type = st.session_state.query['source_node_type']
46
  source_node = st.session_state.query['source_node']
47
- relation = st.session_state.query['relation']
48
  target_node_type = st.session_state.query['target_node_type']
49
- predictions = st.session_state.predictions
 
 
 
 
 
50
 
51
  @st.experimental_fragment()
52
  def plot_options():
@@ -58,8 +64,7 @@ def plot_options():
58
 
59
  # Slider for x-axis limit
60
  axis_limits = st.slider('Define the range of ranks to visualize.',
61
- min_value=0, max_value=predictions['Rank'].max(),
62
- value=(0, predictions['Rank'].max()), step=1000)
63
 
64
  # Update session state
65
  st.session_state.show_lines = show_lines
@@ -72,12 +77,12 @@ plot_options()
72
  if 'show_lines' not in st.session_state:
73
  st.session_state.show_lines = False
74
  if 'axis_limits' not in st.session_state:
75
- st.session_state.axis_limits = (0, predictions['Rank'].max())
76
 
77
  # Button to update plot
78
- col1, col2, col3 = st.columns([4, 2, 4])
79
  with col2:
80
- update_button = st.button('Generate Plot')
81
 
82
  # Horizontal line
83
  st.markdown('---')
@@ -90,24 +95,14 @@ if update_button:
90
  # Convert tuple to hex
91
  def rgba_to_hex(rgba):
92
  return mcolors.to_hex(rgba[:3])
93
-
 
94
  with st.spinner('Searching known relationships...'):
95
-
96
- # Subset existing edges
97
  edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
98
  edge_subset = edge_subset[edge_subset.y_type == target_node_type]
99
 
100
- # Merge edge subset with predictions
101
- edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
102
- edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
103
- edges_in_kg = edges_in_kg.drop(columns = 'y_id')
104
-
105
- # Rename relation to ground-truth
106
- edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
107
- edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
108
-
109
  # If there exist edges in KG
110
- if len(edges_in_kg) > 0:
111
 
112
  with st.spinner('Saving validation results...'):
113
 
@@ -119,118 +114,149 @@ if update_button:
119
  # Save validation results to session state
120
  st.session_state.validation = val_results
121
 
122
- with st.spinner('Plotting known relationships...'):
 
123
 
124
- # Define a color map for different relations
125
- color_map = plt.get_cmap('tab10')
126
 
127
- # Group by relation and create separate plots
128
- relations = edges_in_kg['Known Relation'].unique()
129
- for idx, relation in enumerate(relations):
130
 
131
- relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
 
 
 
 
 
 
 
 
 
 
132
 
133
  # Get a color from the color map
134
  color = color_map(idx % color_map.N)
135
 
136
- fig, ax = plt.subplots(figsize=(10, 5))
137
- ax.plot(predictions['Rank'], predictions['Score'], color = 'black', linewidth = 1.5, zorder = 2)
138
- ax.set_xlabel('Rank', fontsize=12)
139
- ax.set_ylabel('Score', fontsize=12)
140
- # ax.set_xlim(1, predictions['Rank'].max())
141
- # ax.set_xlim(axis_limits)
142
- ax.set_xlim(st.session_state.axis_limits)
143
-
144
- for i, node in relation_data.iterrows():
145
- if st.session_state.show_lines:
146
- ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'], zorder = 3)
147
- ax.scatter(node['Rank'], node['Score'], color=color, zorder=3) # s=15
148
- # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
149
-
150
- # Also calculate and plot recall at K
151
- ax2 = ax.twinx()
152
-
153
- # Calculate recall at K for all Rank
154
- recall_at_k = []
155
- for k in range(1, predictions['Rank'].max() + 1):
156
- recall = 100*len(relation_data[relation_data['Rank'] <= k]) / len(relation_data)
157
- recall_at_k.append(recall)
158
-
159
- ax2.plot(range(1, predictions['Rank'].max() + 1), recall_at_k,
160
- color = 'red', linestyle = '--', label = 'Recall at K', zorder = 4, linewidth = 2)
161
-
162
- # Set labels
163
- ax2.set_ylabel('Recall at K (%)', fontsize=12, color='red')
164
-
165
- # Add grid
166
- ax.grid(True, linestyle=':', alpha=0.5, zorder=0)
167
-
168
- # ax.set_title(f'{relation.replace("_", "-")}')
169
- # ax.legend()
170
  color_hex = rgba_to_hex(color)
171
 
172
  # Write header in color of relation
173
  st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)
174
 
175
- # Show plot
176
- st.pyplot(fig)
177
-
178
- # Create recall at K table
179
- k_vals = [10, 50, 100, 500, 1000, 5000, 10000]
180
- recall_at_k = []
181
- for k in k_vals:
182
- recall = 100*len(relation_data[relation_data['Rank'] <= k]) / len(relation_data)
183
- recall = f"{recall:.2f}%"
184
- recall_at_k.append(recall)
185
- recall_df = pd.DataFrame({'K': k_vals, 'Recall': recall_at_k})
186
-
187
- # Transpose and display recall at K
188
- recall_df = recall_df.T
189
- recall_df.columns = [f"k = {k:.0f}" for k in recall_df.iloc[0]]
190
- recall_df = recall_df.drop('K')
191
- st.markdown('**Recall at $k$:**')
192
- st.dataframe(recall_df, use_container_width=True)
193
-
194
- # Compute other statistics
195
- st.markdown('**Statistics:**')
196
-
197
- # Binarize score
198
- pred_threshold = 0.5
199
- raw_score = predictions['Score']
200
- binary_score = (raw_score > pred_threshold).astype(int)
201
- true_label = np.zeros(len(predictions))
202
-
203
- # Set true label to 1 for known relations
204
-
205
- # Reset index
206
- predictions_idx = predictions.copy().reset_index(drop = True)
207
- true_label[predictions_idx[predictions_idx['ID'].isin(relation_data['ID'])].index] = 1
208
-
209
- # Compute scores
210
- accuracy = accuracy_score(true_label, binary_score)
211
- ap = average_precision_score(true_label, raw_score)
212
- f1 = f1_score(true_label, binary_score, average = 'micro')
213
- try:
214
- auc = roc_auc_score(true_label, raw_score)
215
- except ValueError:
216
- auc = 0.5
217
-
218
- # Create dataframe
219
- stats_df = pd.DataFrame({'Accuracy': [accuracy], 'AUC': [auc], 'AP': [ap], 'F1': [f1]})
220
- stats_df.index = ["Value"]
221
- st.dataframe(stats_df, use_container_width=True)
222
-
223
- # Drop known relation column
224
- st.markdown('**Known Relationships:**')
225
- relation_data = relation_data.drop(columns = 'Known Relation')
226
- if target_node_type not in ['disease', 'anatomy']:
227
- st.dataframe(relation_data, use_container_width=True,
228
- column_config={"Database": st.column_config.LinkColumn(width = "small",
229
- help = "Click to visit external database.",
230
- display_text = st.session_state.display_database)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  else:
232
- st.dataframe(relation_data, use_container_width=True)
 
233
 
234
  else:
235
 
236
- st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")
 
14
  import matplotlib.colors as mcolors
15
 
16
  # Import metrics
17
+ from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score, balanced_accuracy_score
18
 
19
  # Custom and other imports
20
  import project_config
21
+ from utils import load_kg, load_kg_edges, map_db_names
22
 
23
  # Redirect to app.py if not logged in, otherwise show the navigation menu
24
  menu_with_redirect()
 
32
  st.subheader("Validate Predictions", divider = "green")
33
 
34
  # Print current query
35
+ # st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
36
+ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
37
 
38
  # Print split
39
  split = st.session_state.split
 
45
  # Get query and predictions
46
  source_node_type = st.session_state.query['source_node_type']
47
  source_node = st.session_state.query['source_node']
48
+ # relation = st.session_state.query['relation']
49
  target_node_type = st.session_state.query['target_node_type']
50
+ predictions_rel = st.session_state.predictions_rel
51
+
52
+ # Get relation options
53
+ edge_types_df = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')
54
+ relation_options = edge_types_df[(edge_types_df.x_type == source_node_type) & (edge_types_df.y_type == target_node_type)].relation.unique()
55
+ max_rank = predictions_rel[relation_options[0]]['Rank'].max()
56
 
57
  @st.experimental_fragment()
58
  def plot_options():
 
64
 
65
  # Slider for x-axis limit
66
  axis_limits = st.slider('Define the range of ranks to visualize.',
67
+ min_value=0, max_value=max_rank, value=(0, max_rank), step=1000)
 
68
 
69
  # Update session state
70
  st.session_state.show_lines = show_lines
 
77
  if 'show_lines' not in st.session_state:
78
  st.session_state.show_lines = False
79
  if 'axis_limits' not in st.session_state:
80
+ st.session_state.axis_limits = (0, max_rank)
81
 
82
  # Button to update plot
83
+ col1, col2, col3 = st.columns([2, 2, 2])
84
  with col2:
85
+ update_button = st.button('Generate Plot and Metrics')
86
 
87
  # Horizontal line
88
  st.markdown('---')
 
95
  # Convert tuple to hex
96
  def rgba_to_hex(rgba):
97
  return mcolors.to_hex(rgba[:3])
98
+
99
+ # Subset existing edges
100
  with st.spinner('Searching known relationships...'):
 
 
101
  edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
102
  edge_subset = edge_subset[edge_subset.y_type == target_node_type]
103
 
 
 
 
 
 
 
 
 
 
104
  # If there exist edges in KG
105
+ if len(edge_subset) > 0:
106
 
107
  with st.spinner('Saving validation results...'):
108
 
 
114
  # Save validation results to session state
115
  st.session_state.validation = val_results
116
 
117
+ # Define a color map for different relations
118
+ color_map = plt.get_cmap('tab10')
119
 
120
+ for idx, relation in enumerate(relation_options):
 
121
 
122
+ # Get predictions for specific relation
123
+ predictions = predictions_rel[relation]
 
124
 
125
+ # Merge edge subset with predictions
126
+ edge_subset_rel = edge_subset[['relation', 'y_id']].copy()
127
+ edges_in_kg = pd.merge(predictions, edge_subset_rel, left_on = 'ID', right_on = 'y_id', how = 'right')
128
+ edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
129
+ edges_in_kg = edges_in_kg.drop(columns = 'y_id')
130
+
131
+ # Rename relation to ground-truth
132
+ edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
133
+ edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
134
+
135
+ with st.spinner('Plotting known relationships...'):
136
 
137
  # Get a color from the color map
138
  color = color_map(idx % color_map.N)
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  color_hex = rgba_to_hex(color)
141
 
142
  # Write header in color of relation
143
  st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)
144
 
145
+
146
+ # Group by relation and create separate plots
147
+ # relations = edges_in_kg['Known Relation'].unique()
148
+ # for idx, relation in enumerate(relations):
149
+
150
+ relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
151
+
152
+ if len(relation_data) > 0:
153
+
154
+ fig, ax = plt.subplots(figsize=(10, 5))
155
+ ax.plot(predictions['Rank'], predictions['Score'], color = 'black', linewidth = 1.5, zorder = 2)
156
+ ax.set_xlabel('Rank', fontsize=12)
157
+ ax.set_ylabel('Score', fontsize=12)
158
+ # ax.set_xlim(1, predictions['Rank'].max())
159
+ # ax.set_xlim(axis_limits)
160
+ ax.set_xlim(st.session_state.axis_limits)
161
+
162
+ for i, node in relation_data.iterrows():
163
+ if st.session_state.show_lines:
164
+ ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'], zorder = 3)
165
+ ax.scatter(node['Rank'], node['Score'], color=color, zorder=3) # s=15
166
+ # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
167
+
168
+ # Also calculate and plot recall at K
169
+ ax2 = ax.twinx()
170
+
171
+ # Calculate recall at K for all Rank
172
+ recall_at_k = []
173
+ for k in range(1, predictions['Rank'].max() + 1):
174
+ recall = 100*len(relation_data[relation_data['Rank'] <= k]) / len(relation_data)
175
+ recall_at_k.append(recall)
176
+
177
+ ax2.plot(range(1, predictions['Rank'].max() + 1), recall_at_k,
178
+ color = 'red', linestyle = '--', label = 'Recall at K', zorder = 4, linewidth = 2)
179
+
180
+ # Set labels
181
+ ax2.set_ylabel('Recall at K (%)', fontsize=12, color='red')
182
+
183
+ # Add grid
184
+ ax.grid(True, linestyle=':', alpha=0.5, zorder=0)
185
+
186
+ # ax.set_title(f'{relation.replace("_", "-")}')
187
+ # ax.legend()
188
+
189
+ # Show plot
190
+ st.pyplot(fig)
191
+
192
+ # Create recall at K table
193
+ k_vals = [10, 50, 100, 500, 1000, 5000, 10000]
194
+ recall_at_k = []
195
+ for k in k_vals:
196
+ recall = 100*len(relation_data[relation_data['Rank'] <= k]) / len(relation_data)
197
+ recall = f"{recall:.2f}%"
198
+ recall_at_k.append(recall)
199
+ recall_df = pd.DataFrame({'K': k_vals, 'Recall': recall_at_k})
200
+
201
+ # Transpose and display recall at K
202
+ recall_df = recall_df.T
203
+ recall_df.columns = [f"k = {k:.0f}" for k in recall_df.iloc[0]]
204
+ recall_df = recall_df.drop('K')
205
+ st.markdown('**Recall at $k$:**')
206
+ st.dataframe(recall_df, use_container_width=True)
207
+
208
+ # Compute other statistics
209
+ st.markdown('**Statistics:**')
210
+
211
+ # Binarize score
212
+ pred_threshold = 0.5
213
+ raw_score = predictions['Score']
214
+ binary_score = (raw_score > pred_threshold).astype(int)
215
+ true_label = np.zeros(len(predictions))
216
+
217
+ # Set true label to 1 for known relations
218
+
219
+ # Reset index
220
+ predictions_idx = predictions.copy().reset_index(drop = True)
221
+ true_label[predictions_idx[predictions_idx['ID'].isin(relation_data['ID'])].index] = 1
222
+
223
+ # Compute scores
224
+ accuracy = accuracy_score(true_label, binary_score)
225
+ balanced_accuracy = balanced_accuracy_score(true_label, binary_score)
226
+ accuracy = f"{100*accuracy:.2f}%"
227
+ balanced_accuracy = f"{100*balanced_accuracy:.2f}%"
228
+ ap = average_precision_score(true_label, raw_score)
229
+ f1 = f1_score(true_label, binary_score, average = 'micro')
230
+ try:
231
+ auc = roc_auc_score(true_label, raw_score)
232
+ except ValueError:
233
+ auc = 0.5
234
+
235
+ # Create dataframe
236
+ stats_df = pd.DataFrame({
237
+ 'Acc.': [accuracy], 'Balanced Acc.': [balanced_accuracy],
238
+ 'AUC': [auc], 'AP': [ap], 'F1': [f1]
239
+ })
240
+ stats_df.index = ["Value"]
241
+ st.dataframe(stats_df, use_container_width=True)
242
+
243
+ # Drop known relation column
244
+ st.markdown('**Known Relationships:**')
245
+ relation_data = relation_data.drop(columns = 'Known Relation')
246
+ relation_data['Rank'] = relation_data['Rank'].apply(lambda x: f"{x} (top {(100*x/predictions.shape[0]):.2f}%)")
247
+
248
+ if target_node_type not in ['disease', 'anatomy']:
249
+ st.dataframe(relation_data, use_container_width=True, hide_index = True,
250
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
251
+ help = "Click to visit external database.",
252
+ display_text = map_db_names[target_node_type],)})
253
+ else:
254
+ st.dataframe(relation_data, use_container_width=True, hide_index = True)
255
+
256
  else:
257
+
258
+ st.error(f"No ground truth {relation.replace('_', ' ')} edges found for {source_node} in the knowledge graph.", icon="✖️")
259
 
260
  else:
261
 
262
+ st.error(f"No ground truth {target_node_type} relationships found for {source_node} in the knowledge graph.", icon="✖️")
utils.py CHANGED
@@ -25,4 +25,34 @@ def capitalize_after_slash(s):
25
  capitalized_parts = [part.title() for part in parts]
26
  # Rejoin the parts with slashes
27
  capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
28
- return capitalized_string
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  capitalized_parts = [part.title() for part in parts]
26
  # Rejoin the parts with slashes
27
  capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
28
+ return capitalized_string
29
+
30
+
31
+ # Define dictionary mapping node types to database URLs
32
def _padded_term_url(prefix, width=7):
    """Return a URL builder that appends a zero-padded identifier to `prefix`.

    The returned callable converts its argument with ``str`` before padding,
    so identifiers supplied as integers are accepted as well as strings
    (``str.zfill`` does not exist on ``int``). String identifiers produce
    exactly the same URL as calling ``x.zfill(width)`` directly.
    """
    def build(x):
        return f"{prefix}{str(x).zfill(width)}"
    return build


# All three Gene Ontology namespaces resolve through the same AmiGO term page.
_go_term_url = _padded_term_url("https://amigo.geneontology.org/amigo/term/GO:")

# Each value is a callable taking a node identifier and returning the string
# to use as its external database link. 'disease' and 'anatomy' have no URL
# template here and return the identifier unchanged.
map_dbs = {
    'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
    'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
    # HPO and GO identifiers are padded with 0s to 7 digits
    'effect/phenotype': _padded_term_url("https://hpo.jax.org/app/browse/term/HP:"),
    'disease': lambda x: x,  # MONDO
    'biological_process': _go_term_url,
    'molecular_function': _go_term_url,
    'cellular_component': _go_term_url,
    'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
    'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
    'anatomy': lambda x: x,
}
45
+
46
+ # Define dictionary mapping node types to database names
47
# Short database label for every node type, keyed by the node type string.
# Built from an ordered table of (node type, label) pairs.
map_db_names = dict((
    ('gene/protein', 'NCBI'),
    ('drug', 'DrugBank'),
    ('effect/phenotype', 'HPO'),
    ('disease', 'MONDO'),
    ('biological_process', 'GO: BP'),
    ('molecular_function', 'GO: MF'),
    ('cellular_component', 'GO: CC'),
    ('exposure', 'CTD'),
    ('pathway', 'Reactome'),
    ('anatomy', 'UBERON'),
))