Erva Ulusoy commited on
Commit
14c3500
·
1 Parent(s): 85b27f1

include second degree edges (major update)

Browse files
Files changed (2) hide show
  1. ProtHGT_app.py +75 -41
  2. visualize_kg.py +94 -4
ProtHGT_app.py CHANGED
@@ -70,7 +70,7 @@ with st.expander("🚀 Upcoming Features"):
70
 
71
  - **Real-time data retrieval for new proteins**: Currently, ProtHGT can only generate predictions for proteins that already exist in our knowledge graph. We are developing a new feature that will allow users to **predict functions for entirely new proteins starting from their sequences**. This will work by **retrieving relevant relationship data in real time from external source databases** (e.g., UniProt, STRING, and other biological repositories). The system will dynamically construct a knowledge graph for the query protein, incorporating its interactions, domains, pathways, and other biological associations before running function prediction. This approach will enable ProtHGT to analyze newly discovered or less-studied proteins even if they are not pre-annotated in our dataset.
72
  - **Expanded embedding options**: Currently, this application represents proteins using **TAPE embeddings**, which serve as the initial numerical representations of protein sequences before being processed in the heterogeneous graph model. We are working on integrating **ProtT5** and **ESM-2** as alternative initial embeddings, allowing users to choose different sequence representations that may enhance performance for specific tasks. A detailed comparison of how these embeddings influence function prediction accuracy will be included in our upcoming publication.
73
- - **Knowledge graph visualization for interpretability**: To improve model explainability, we are developing an interactive **knowledge graph visualization** feature. This will allow users to explore the biological relationships that contributed to ProtHGTs predictions for a given protein. Users will be able to inspect **protein interactions, GO annotations, domains, pathways, and other key connections** in a structured graphical format, making it easier to interpret and validate predictions.
74
 
75
  Stay tuned for updates and future publications!
76
  """)
@@ -562,78 +562,112 @@ if st.session_state.submitted:
562
  # Create visualizations in each tab
563
  for idx, protein_id in enumerate(selected_proteins):
564
  with protein_tabs[idx]:
565
- max_node_count = st.slider(
566
- "Maximum neighbors per edge type",
567
- min_value=5,
568
- max_value=50,
569
- value=10,
570
- step=5,
571
- help="Control the maximum number of neighboring nodes shown for each relationship type",
572
- key=f"slider_{protein_id}"
573
- )
574
-
575
- # Check if visualization exists for this protein
 
 
576
  viz_exists = (protein_id in st.session_state.protein_visualizations and
577
- os.path.exists(st.session_state.protein_visualizations[protein_id]['path']))
 
578
 
579
  if not viz_exists:
580
  if st.button(f"Generate Visualization", key=f"viz_{protein_id}"):
581
- # Generate visualization with selected max_node_count
582
- html_path, visualized_edges = visualize_protein_subgraph(
 
 
 
 
 
583
  st.session_state.heterodata,
584
  protein_id,
585
  st.session_state.predictions_df,
586
- limit=max_node_count
 
587
  )
588
 
589
- # Store visualization info in session state
590
- st.session_state.protein_visualizations[protein_id] = {
591
- 'path': html_path,
592
- 'edges': visualized_edges
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  }
594
  st.rerun()
595
 
596
- # If visualization exists, display it
597
  if viz_exists:
598
- viz_info = st.session_state.protein_visualizations[protein_id]
 
 
 
 
 
 
599
 
600
- # Add download button for edges
601
- formatted_edges = {}
602
- for edge_type, edges in viz_info['edges'].items():
603
- edge_type_str = f"{edge_type[0]}_{edge_type[1]}_{edge_type[2]}"
604
- formatted_edges[edge_type_str] = [
605
- {"source": edge[0][0], "target": edge[0][1], "probability": edge[1]}
606
- for edge in edges
607
- ]
608
 
609
  kg_viz_button_columns = st.columns([1, 1, 1])
610
 
611
  with kg_viz_button_columns[0]:
 
 
 
 
 
 
 
 
 
612
  st.download_button(
613
  label='Download Visualized Edges',
614
  data=json.dumps(formatted_edges, indent=2),
615
- file_name=f'{protein_id}_visualized_edges.json',
616
  mime='application/json'
617
  )
618
 
619
  with kg_viz_button_columns[1]:
620
  if st.button("Regenerate Visualization", key=f"regenerate_{protein_id}"):
621
- # Clean up old file
622
- try:
623
- os.unlink(viz_info['path'])
624
- except FileNotFoundError:
625
- pass
626
- # Remove from session state
627
- del st.session_state.protein_visualizations[protein_id]
 
 
 
 
628
  st.rerun()
629
 
 
630
  with open(viz_info['path'], 'r', encoding='utf-8') as f:
631
  html_content = f.read()
632
 
633
  st.components.v1.html(html_content, height=1200)
634
 
635
-
636
-
637
-
638
  else:
639
  st.warning("Knowledge graph visualization is only available when 10 or fewer proteins are selected.")
 
70
 
71
  - **Real-time data retrieval for new proteins**: Currently, ProtHGT can only generate predictions for proteins that already exist in our knowledge graph. We are developing a new feature that will allow users to **predict functions for entirely new proteins starting from their sequences**. This will work by **retrieving relevant relationship data in real time from external source databases** (e.g., UniProt, STRING, and other biological repositories). The system will dynamically construct a knowledge graph for the query protein, incorporating its interactions, domains, pathways, and other biological associations before running function prediction. This approach will enable ProtHGT to analyze newly discovered or less-studied proteins even if they are not pre-annotated in our dataset.
72
  - **Expanded embedding options**: Currently, this application represents proteins using **TAPE embeddings**, which serve as the initial numerical representations of protein sequences before being processed in the heterogeneous graph model. We are working on integrating **ProtT5** and **ESM-2** as alternative initial embeddings, allowing users to choose different sequence representations that may enhance performance for specific tasks. A detailed comparison of how these embeddings influence function prediction accuracy will be included in our upcoming publication.
73
+ - **Knowledge graph visualization for interpretability**: To improve model explainability, we are developing an interactive **knowledge graph visualization** feature. This will allow users to explore the biological relationships that contributed to ProtHGT's predictions for a given protein. Users will be able to inspect **protein interactions, GO annotations, domains, pathways, and other key connections** in a structured graphical format, making it easier to interpret and validate predictions.
74
 
75
  Stay tuned for updates and future publications!
76
  """)
 
562
  # Create visualizations in each tab
563
  for idx, protein_id in enumerate(selected_proteins):
564
  with protein_tabs[idx]:
565
+ col1, col2 = st.columns([3, 1])
566
+ with col1:
567
+ max_node_count = st.slider(
568
+ "Maximum neighbors per edge type",
569
+ min_value=5,
570
+ max_value=50,
571
+ value=10,
572
+ step=5,
573
+ help="Control the maximum number of neighboring nodes shown for each relationship type",
574
+ key=f"slider_{protein_id}"
575
+ )
576
+
577
+ # Check if both visualizations exist for this protein
578
  viz_exists = (protein_id in st.session_state.protein_visualizations and
579
+ 'first_degree' in st.session_state.protein_visualizations[protein_id] and
580
+ 'second_degree' in st.session_state.protein_visualizations[protein_id])
581
 
582
  if not viz_exists:
583
  if st.button(f"Generate Visualization", key=f"viz_{protein_id}"):
584
+ # Initialize the protein's visualizations if not exists
585
+ if protein_id not in st.session_state.protein_visualizations:
586
+ st.session_state.protein_visualizations[protein_id] = {}
587
+
588
+ # Generate both visualizations upfront
589
+ # First degree only
590
+ html_path_1st, edges_1st = visualize_protein_subgraph(
591
  st.session_state.heterodata,
592
  protein_id,
593
  st.session_state.predictions_df,
594
+ limit=max_node_count,
595
+ include_second_degree=False
596
  )
597
 
598
+ # With second degree
599
+ html_path_2nd, edges_2nd = visualize_protein_subgraph(
600
+ st.session_state.heterodata,
601
+ protein_id,
602
+ st.session_state.predictions_df,
603
+ limit=max_node_count,
604
+ include_second_degree=True
605
+ )
606
+
607
+ # Store both visualizations in session state
608
+ st.session_state.protein_visualizations[protein_id]['first_degree'] = {
609
+ 'path': html_path_1st,
610
+ 'edges': edges_1st
611
+ }
612
+ st.session_state.protein_visualizations[protein_id]['second_degree'] = {
613
+ 'path': html_path_2nd,
614
+ 'edges': edges_2nd
615
  }
616
  st.rerun()
617
 
618
+ # If visualization exists, show the toggle and display appropriate version
619
  if viz_exists:
620
+ with col2:
621
+ include_second_degree = st.checkbox(
622
+ "Include second-degree edges",
623
+ value=False,
624
+ key=f"second_degree_{protein_id}",
625
+ help="Show connections between neighbor nodes"
626
+ )
627
 
628
+ # Get the appropriate visualization based on checkbox
629
+ viz_type = 'second_degree' if include_second_degree else 'first_degree'
630
+ viz_info = st.session_state.protein_visualizations[protein_id][viz_type]
 
 
 
 
 
631
 
632
  kg_viz_button_columns = st.columns([1, 1, 1])
633
 
634
  with kg_viz_button_columns[0]:
635
+ # Format edges for download
636
+ formatted_edges = {}
637
+ for edge_type, edges in viz_info['edges'].items():
638
+ edge_type_str = f"{edge_type[0]}_{edge_type[1]}_{edge_type[2]}"
639
+ formatted_edges[edge_type_str] = [
640
+ {"source": edge[0][0], "target": edge[0][1], "probability": edge[1]}
641
+ for edge in edges
642
+ ]
643
+
644
  st.download_button(
645
  label='Download Visualized Edges',
646
  data=json.dumps(formatted_edges, indent=2),
647
+ file_name=f'{protein_id}_visualized_edges{"_with_2nd_degree" if include_second_degree else ""}.json',
648
  mime='application/json'
649
  )
650
 
651
  with kg_viz_button_columns[1]:
652
  if st.button("Regenerate Visualization", key=f"regenerate_{protein_id}"):
653
+ # Clean up old files
654
+ if protein_id in st.session_state.protein_visualizations:
655
+ for viz_type in ['first_degree', 'second_degree']:
656
+ if viz_type in st.session_state.protein_visualizations[protein_id]:
657
+ try:
658
+ old_path = st.session_state.protein_visualizations[protein_id][viz_type]['path']
659
+ os.unlink(old_path)
660
+ except:
661
+ pass
662
+ # Remove from session state
663
+ del st.session_state.protein_visualizations[protein_id]
664
  st.rerun()
665
 
666
+ # Display the appropriate visualization
667
  with open(viz_info['path'], 'r', encoding='utf-8') as f:
668
  html_content = f.read()
669
 
670
  st.components.v1.html(html_content, height=1200)
671
 
 
 
 
672
  else:
673
  st.warning("Knowledge graph visualization is only available when 10 or fewer proteins are selected.")
visualize_kg.py CHANGED
@@ -22,13 +22,20 @@ EDGE_LABEL_TRANSLATION = {
22
  'Orthology': 'is ortholog to',
23
  'Pathway': 'takes part in',
24
  'kegg_path_prot': 'takes part in',
 
 
 
 
25
  'protein_domain': 'has',
26
  'PPI': 'interacts with',
27
  'HPO': 'is associated with',
28
  'kegg_dis_prot': 'is related to',
29
  'Disease': 'is related to',
30
  'Drug': 'targets',
 
31
  'protein_ec': 'catalyzes',
 
 
32
  'Chembl': 'targets',
33
  ('protein_function', 'GO_term_F'): 'enables',
34
  ('protein_function', 'GO_term_P'): 'is involved in',
@@ -168,14 +175,96 @@ def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
168
 
169
  return filtered_edges
170
 
 
 
 
 
 
 
 
 
 
171
 
172
- def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
 
174
  with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
175
  name_info = json.load(file)
176
 
 
177
  protein_edges = _gather_protein_edges(data, protein_id)
178
- visualized_edges = _filter_edges(protein_id, protein_edges, prediction_df, limit)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  print(f'Edges to be visualized: {visualized_edges}')
180
 
181
  net = Network(height="600px", width="100%", directed=True, notebook=False)
@@ -259,7 +348,7 @@ def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
259
  for edge_type, edges in visualized_edges.items():
260
  source_type, relation_type, target_type = edge_type
261
 
262
- if relation_type == 'protein_function':
263
  relation_type = EDGE_LABEL_TRANSLATION[(relation_type, target_type)]
264
  else:
265
  relation_type = EDGE_LABEL_TRANSLATION[relation_type]
@@ -449,7 +538,8 @@ def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
449
 
450
  # Save graph to a protein-specific file in a temporary directory
451
  os.makedirs('temp_viz', exist_ok=True)
452
- file_path = os.path.join('temp_viz', f'{protein_id}_graph.html')
 
453
 
454
  net.save_graph(file_path)
455
 
 
22
  'Orthology': 'is ortholog to',
23
  'Pathway': 'takes part in',
24
  'kegg_path_prot': 'takes part in',
25
+ ('domain_function', 'GO_term_F'): 'enables',
26
+ ('domain_function', 'GO_term_P'): 'is involved in',
27
+ ('domain_function', 'GO_term_C'): 'localizes to',
28
+ 'function_function': 'ontological relationship',
29
  'protein_domain': 'has',
30
  'PPI': 'interacts with',
31
  'HPO': 'is associated with',
32
  'kegg_dis_prot': 'is related to',
33
  'Disease': 'is related to',
34
  'Drug': 'targets',
35
+ 'kegg_dis_path': 'modulates',
36
  'protein_ec': 'catalyzes',
37
+ 'hpodis': 'is associated with',
38
+ 'kegg_dis_drug': 'treats',
39
  'Chembl': 'targets',
40
  ('protein_function', 'GO_term_F'): 'enables',
41
  ('protein_function', 'GO_term_P'): 'is involved in',
 
175
 
176
  return filtered_edges
177
 
178
+ def _gather_neighbor_edges(data, node_id, node_type, exclude_node_id):
179
+ """Gather edges for a neighbor node, excluding edges back to the original query protein"""
180
+
181
+ node_idx = data[node_type]['id_mapping'][node_id]
182
+ reverse_id_mapping = {}
183
+ for ntype in data.node_types:
184
+ reverse_id_mapping[ntype] = {v:k for k, v in data[ntype]['id_mapping'].items()}
185
+
186
+ neighbor_edges = {}
187
 
188
+ for edge_type in data.edge_types:
189
+ if 'rev' not in edge_type[1]:
190
+ if edge_type not in neighbor_edges:
191
+ neighbor_edges[edge_type] = []
192
+
193
+ if edge_type[0] == node_type:
194
+ # Get edges where neighbor is source
195
+ edges = data[edge_type].edge_index[:, data[edge_type].edge_index[0] == node_idx]
196
+ edges = edges.T.tolist()
197
+ # Filter out edges going back to the query protein
198
+ edges = [edge for edge in edges if reverse_id_mapping[edge_type[2]][edge[1]] != exclude_node_id]
199
+ neighbor_edges[edge_type].extend(edges)
200
+
201
+ elif edge_type[2] == node_type:
202
+ # Get edges where neighbor is target
203
+ edges = data[edge_type].edge_index[:, data[edge_type].edge_index[1] == node_idx]
204
+ edges = edges.T.tolist()
205
+ # Filter out edges coming from the query protein
206
+ edges = [edge for edge in edges if reverse_id_mapping[edge_type[0]][edge[0]] != exclude_node_id]
207
+ neighbor_edges[edge_type].extend(edges)
208
+
209
+ # Map indices back to IDs
210
+ for edge_type in neighbor_edges.keys():
211
+ if neighbor_edges[edge_type]:
212
+ mapped_edges = set()
213
+ for edge in neighbor_edges[edge_type]:
214
+ source_type, _, target_type = edge_type
215
+ source_id = reverse_id_mapping[source_type][edge[0]]
216
+ target_id = reverse_id_mapping[target_type][edge[1]]
217
+ mapped_edges.add((source_id, target_id))
218
+ neighbor_edges[edge_type] = mapped_edges
219
+
220
+ return neighbor_edges
221
 
222
+ def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10, second_degree_limit=3, include_second_degree=False):
223
  with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
224
  name_info = json.load(file)
225
 
226
+ # Get the first-degree edges and filter them
227
  protein_edges = _gather_protein_edges(data, protein_id)
228
+ first_degree_edges = _filter_edges(protein_id, protein_edges, prediction_df, limit)
229
+
230
+ # Initialize all_edges with first degree edges
231
+ all_edges = first_degree_edges.copy()
232
+
233
+ if include_second_degree:
234
+ # Collect neighbor nodes from first-degree edges
235
+ neighbor_nodes = set()
236
+ for edge_type, edges in first_degree_edges.items():
237
+ source_type, _, target_type = edge_type
238
+ for edge_info in edges:
239
+ edge = edge_info[0]
240
+ source, target = edge
241
+ if source != protein_id:
242
+ neighbor_nodes.add((source, source_type))
243
+ if target != protein_id:
244
+ neighbor_nodes.add((target, target_type))
245
+
246
+ # Gather and filter second-degree edges
247
+ second_degree_edges = {}
248
+ for neighbor_id, neighbor_type in neighbor_nodes:
249
+ neighbor_edges = _gather_neighbor_edges(data, neighbor_id, neighbor_type, protein_id)
250
+ filtered_neighbor_edges = _filter_edges(neighbor_id, neighbor_edges, prediction_df, second_degree_limit)
251
+
252
+ # Merge filtered neighbor edges into second_degree_edges
253
+ for edge_type, edges in filtered_neighbor_edges.items():
254
+ if edge_type not in second_degree_edges:
255
+ second_degree_edges[edge_type] = []
256
+ second_degree_edges[edge_type].extend(edges)
257
+
258
+ # Merge first and second degree edges
259
+ for edge_type, edges in second_degree_edges.items():
260
+ if edge_type in all_edges:
261
+ all_edges[edge_type].extend(edges)
262
+ else:
263
+ all_edges[edge_type] = edges
264
+
265
+ # Update visualized_edges with all edges
266
+ visualized_edges = all_edges
267
+
268
  print(f'Edges to be visualized: {visualized_edges}')
269
 
270
  net = Network(height="600px", width="100%", directed=True, notebook=False)
 
348
  for edge_type, edges in visualized_edges.items():
349
  source_type, relation_type, target_type = edge_type
350
 
351
+ if relation_type in ['protein_function', 'domain_function']:
352
  relation_type = EDGE_LABEL_TRANSLATION[(relation_type, target_type)]
353
  else:
354
  relation_type = EDGE_LABEL_TRANSLATION[relation_type]
 
538
 
539
  # Save graph to a protein-specific file in a temporary directory
540
  os.makedirs('temp_viz', exist_ok=True)
541
+ suffix = "_with_2nd_degree" if include_second_degree else "_1st_degree"
542
+ file_path = os.path.join('temp_viz', f'{protein_id}_graph{suffix}.html')
543
 
544
  net.save_graph(file_path)
545