Spaces:
Sleeping
Sleeping
Erva Ulusoy
commited on
Commit
·
14c3500
1
Parent(s):
85b27f1
include second degree edges (major update)
Browse files- ProtHGT_app.py +75 -41
- visualize_kg.py +94 -4
ProtHGT_app.py
CHANGED
@@ -70,7 +70,7 @@ with st.expander("🚀 Upcoming Features"):
|
|
70 |
|
71 |
- **Real-time data retrieval for new proteins**: Currently, ProtHGT can only generate predictions for proteins that already exist in our knowledge graph. We are developing a new feature that will allow users to **predict functions for entirely new proteins starting from their sequences**. This will work by **retrieving relevant relationship data in real time from external source databases** (e.g., UniProt, STRING, and other biological repositories). The system will dynamically construct a knowledge graph for the query protein, incorporating its interactions, domains, pathways, and other biological associations before running function prediction. This approach will enable ProtHGT to analyze newly discovered or less-studied proteins even if they are not pre-annotated in our dataset.
|
72 |
- **Expanded embedding options**: Currently, this application represents proteins using **TAPE embeddings**, which serve as the initial numerical representations of protein sequences before being processed in the heterogeneous graph model. We are working on integrating **ProtT5** and **ESM-2** as alternative initial embeddings, allowing users to choose different sequence representations that may enhance performance for specific tasks. A detailed comparison of how these embeddings influence function prediction accuracy will be included in our upcoming publication.
|
73 |
-
- **Knowledge graph visualization for interpretability**: To improve model explainability, we are developing an interactive **knowledge graph visualization** feature. This will allow users to explore the biological relationships that contributed to ProtHGT
|
74 |
|
75 |
Stay tuned for updates and future publications!
|
76 |
""")
|
@@ -562,78 +562,112 @@ if st.session_state.submitted:
|
|
562 |
# Create visualizations in each tab
|
563 |
for idx, protein_id in enumerate(selected_proteins):
|
564 |
with protein_tabs[idx]:
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
|
|
|
|
576 |
viz_exists = (protein_id in st.session_state.protein_visualizations and
|
577 |
-
|
|
|
578 |
|
579 |
if not viz_exists:
|
580 |
if st.button(f"Generate Visualization", key=f"viz_{protein_id}"):
|
581 |
-
#
|
582 |
-
|
|
|
|
|
|
|
|
|
|
|
583 |
st.session_state.heterodata,
|
584 |
protein_id,
|
585 |
st.session_state.predictions_df,
|
586 |
-
limit=max_node_count
|
|
|
587 |
)
|
588 |
|
589 |
-
#
|
590 |
-
|
591 |
-
|
592 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
}
|
594 |
st.rerun()
|
595 |
|
596 |
-
# If visualization exists, display
|
597 |
if viz_exists:
|
598 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
599 |
|
600 |
-
#
|
601 |
-
|
602 |
-
|
603 |
-
edge_type_str = f"{edge_type[0]}_{edge_type[1]}_{edge_type[2]}"
|
604 |
-
formatted_edges[edge_type_str] = [
|
605 |
-
{"source": edge[0][0], "target": edge[0][1], "probability": edge[1]}
|
606 |
-
for edge in edges
|
607 |
-
]
|
608 |
|
609 |
kg_viz_button_columns = st.columns([1, 1, 1])
|
610 |
|
611 |
with kg_viz_button_columns[0]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
612 |
st.download_button(
|
613 |
label='Download Visualized Edges',
|
614 |
data=json.dumps(formatted_edges, indent=2),
|
615 |
-
file_name=f'{protein_id}_visualized_edges.json',
|
616 |
mime='application/json'
|
617 |
)
|
618 |
|
619 |
with kg_viz_button_columns[1]:
|
620 |
if st.button("Regenerate Visualization", key=f"regenerate_{protein_id}"):
|
621 |
-
# Clean up old
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
|
|
|
|
|
|
|
|
628 |
st.rerun()
|
629 |
|
|
|
630 |
with open(viz_info['path'], 'r', encoding='utf-8') as f:
|
631 |
html_content = f.read()
|
632 |
|
633 |
st.components.v1.html(html_content, height=1200)
|
634 |
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
else:
|
639 |
st.warning("Knowledge graph visualization is only available when 10 or fewer proteins are selected.")
|
|
|
70 |
|
71 |
- **Real-time data retrieval for new proteins**: Currently, ProtHGT can only generate predictions for proteins that already exist in our knowledge graph. We are developing a new feature that will allow users to **predict functions for entirely new proteins starting from their sequences**. This will work by **retrieving relevant relationship data in real time from external source databases** (e.g., UniProt, STRING, and other biological repositories). The system will dynamically construct a knowledge graph for the query protein, incorporating its interactions, domains, pathways, and other biological associations before running function prediction. This approach will enable ProtHGT to analyze newly discovered or less-studied proteins even if they are not pre-annotated in our dataset.
|
72 |
- **Expanded embedding options**: Currently, this application represents proteins using **TAPE embeddings**, which serve as the initial numerical representations of protein sequences before being processed in the heterogeneous graph model. We are working on integrating **ProtT5** and **ESM-2** as alternative initial embeddings, allowing users to choose different sequence representations that may enhance performance for specific tasks. A detailed comparison of how these embeddings influence function prediction accuracy will be included in our upcoming publication.
|
73 |
+
- **Knowledge graph visualization for interpretability**: To improve model explainability, we are developing an interactive **knowledge graph visualization** feature. This will allow users to explore the biological relationships that contributed to ProtHGT's predictions for a given protein. Users will be able to inspect **protein interactions, GO annotations, domains, pathways, and other key connections** in a structured graphical format, making it easier to interpret and validate predictions.
|
74 |
|
75 |
Stay tuned for updates and future publications!
|
76 |
""")
|
|
|
562 |
# Create visualizations in each tab
|
563 |
for idx, protein_id in enumerate(selected_proteins):
|
564 |
with protein_tabs[idx]:
|
565 |
+
col1, col2 = st.columns([3, 1])
|
566 |
+
with col1:
|
567 |
+
max_node_count = st.slider(
|
568 |
+
"Maximum neighbors per edge type",
|
569 |
+
min_value=5,
|
570 |
+
max_value=50,
|
571 |
+
value=10,
|
572 |
+
step=5,
|
573 |
+
help="Control the maximum number of neighboring nodes shown for each relationship type",
|
574 |
+
key=f"slider_{protein_id}"
|
575 |
+
)
|
576 |
+
|
577 |
+
# Check if both visualizations exist for this protein
|
578 |
viz_exists = (protein_id in st.session_state.protein_visualizations and
|
579 |
+
'first_degree' in st.session_state.protein_visualizations[protein_id] and
|
580 |
+
'second_degree' in st.session_state.protein_visualizations[protein_id])
|
581 |
|
582 |
if not viz_exists:
|
583 |
if st.button(f"Generate Visualization", key=f"viz_{protein_id}"):
|
584 |
+
# Initialize the protein's visualizations if not exists
|
585 |
+
if protein_id not in st.session_state.protein_visualizations:
|
586 |
+
st.session_state.protein_visualizations[protein_id] = {}
|
587 |
+
|
588 |
+
# Generate both visualizations upfront
|
589 |
+
# First degree only
|
590 |
+
html_path_1st, edges_1st = visualize_protein_subgraph(
|
591 |
st.session_state.heterodata,
|
592 |
protein_id,
|
593 |
st.session_state.predictions_df,
|
594 |
+
limit=max_node_count,
|
595 |
+
include_second_degree=False
|
596 |
)
|
597 |
|
598 |
+
# With second degree
|
599 |
+
html_path_2nd, edges_2nd = visualize_protein_subgraph(
|
600 |
+
st.session_state.heterodata,
|
601 |
+
protein_id,
|
602 |
+
st.session_state.predictions_df,
|
603 |
+
limit=max_node_count,
|
604 |
+
include_second_degree=True
|
605 |
+
)
|
606 |
+
|
607 |
+
# Store both visualizations in session state
|
608 |
+
st.session_state.protein_visualizations[protein_id]['first_degree'] = {
|
609 |
+
'path': html_path_1st,
|
610 |
+
'edges': edges_1st
|
611 |
+
}
|
612 |
+
st.session_state.protein_visualizations[protein_id]['second_degree'] = {
|
613 |
+
'path': html_path_2nd,
|
614 |
+
'edges': edges_2nd
|
615 |
}
|
616 |
st.rerun()
|
617 |
|
618 |
+
# If visualization exists, show the toggle and display appropriate version
|
619 |
if viz_exists:
|
620 |
+
with col2:
|
621 |
+
include_second_degree = st.checkbox(
|
622 |
+
"Include second-degree edges",
|
623 |
+
value=False,
|
624 |
+
key=f"second_degree_{protein_id}",
|
625 |
+
help="Show connections between neighbor nodes"
|
626 |
+
)
|
627 |
|
628 |
+
# Get the appropriate visualization based on checkbox
|
629 |
+
viz_type = 'second_degree' if include_second_degree else 'first_degree'
|
630 |
+
viz_info = st.session_state.protein_visualizations[protein_id][viz_type]
|
|
|
|
|
|
|
|
|
|
|
631 |
|
632 |
kg_viz_button_columns = st.columns([1, 1, 1])
|
633 |
|
634 |
with kg_viz_button_columns[0]:
|
635 |
+
# Format edges for download
|
636 |
+
formatted_edges = {}
|
637 |
+
for edge_type, edges in viz_info['edges'].items():
|
638 |
+
edge_type_str = f"{edge_type[0]}_{edge_type[1]}_{edge_type[2]}"
|
639 |
+
formatted_edges[edge_type_str] = [
|
640 |
+
{"source": edge[0][0], "target": edge[0][1], "probability": edge[1]}
|
641 |
+
for edge in edges
|
642 |
+
]
|
643 |
+
|
644 |
st.download_button(
|
645 |
label='Download Visualized Edges',
|
646 |
data=json.dumps(formatted_edges, indent=2),
|
647 |
+
file_name=f'{protein_id}_visualized_edges{"_with_2nd_degree" if include_second_degree else ""}.json',
|
648 |
mime='application/json'
|
649 |
)
|
650 |
|
651 |
with kg_viz_button_columns[1]:
|
652 |
if st.button("Regenerate Visualization", key=f"regenerate_{protein_id}"):
|
653 |
+
# Clean up old files
|
654 |
+
if protein_id in st.session_state.protein_visualizations:
|
655 |
+
for viz_type in ['first_degree', 'second_degree']:
|
656 |
+
if viz_type in st.session_state.protein_visualizations[protein_id]:
|
657 |
+
try:
|
658 |
+
old_path = st.session_state.protein_visualizations[protein_id][viz_type]['path']
|
659 |
+
os.unlink(old_path)
|
660 |
+
except:
|
661 |
+
pass
|
662 |
+
# Remove from session state
|
663 |
+
del st.session_state.protein_visualizations[protein_id]
|
664 |
st.rerun()
|
665 |
|
666 |
+
# Display the appropriate visualization
|
667 |
with open(viz_info['path'], 'r', encoding='utf-8') as f:
|
668 |
html_content = f.read()
|
669 |
|
670 |
st.components.v1.html(html_content, height=1200)
|
671 |
|
|
|
|
|
|
|
672 |
else:
|
673 |
st.warning("Knowledge graph visualization is only available when 10 or fewer proteins are selected.")
|
visualize_kg.py
CHANGED
@@ -22,13 +22,20 @@ EDGE_LABEL_TRANSLATION = {
|
|
22 |
'Orthology': 'is ortholog to',
|
23 |
'Pathway': 'takes part in',
|
24 |
'kegg_path_prot': 'takes part in',
|
|
|
|
|
|
|
|
|
25 |
'protein_domain': 'has',
|
26 |
'PPI': 'interacts with',
|
27 |
'HPO': 'is associated with',
|
28 |
'kegg_dis_prot': 'is related to',
|
29 |
'Disease': 'is related to',
|
30 |
'Drug': 'targets',
|
|
|
31 |
'protein_ec': 'catalyzes',
|
|
|
|
|
32 |
'Chembl': 'targets',
|
33 |
('protein_function', 'GO_term_F'): 'enables',
|
34 |
('protein_function', 'GO_term_P'): 'is involved in',
|
@@ -168,14 +175,96 @@ def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
|
|
168 |
|
169 |
return filtered_edges
|
170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
|
|
174 |
with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
|
175 |
name_info = json.load(file)
|
176 |
|
|
|
177 |
protein_edges = _gather_protein_edges(data, protein_id)
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
print(f'Edges to be visualized: {visualized_edges}')
|
180 |
|
181 |
net = Network(height="600px", width="100%", directed=True, notebook=False)
|
@@ -259,7 +348,7 @@ def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
|
|
259 |
for edge_type, edges in visualized_edges.items():
|
260 |
source_type, relation_type, target_type = edge_type
|
261 |
|
262 |
-
if relation_type
|
263 |
relation_type = EDGE_LABEL_TRANSLATION[(relation_type, target_type)]
|
264 |
else:
|
265 |
relation_type = EDGE_LABEL_TRANSLATION[relation_type]
|
@@ -449,7 +538,8 @@ def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
|
|
449 |
|
450 |
# Save graph to a protein-specific file in a temporary directory
|
451 |
os.makedirs('temp_viz', exist_ok=True)
|
452 |
-
|
|
|
453 |
|
454 |
net.save_graph(file_path)
|
455 |
|
|
|
22 |
'Orthology': 'is ortholog to',
|
23 |
'Pathway': 'takes part in',
|
24 |
'kegg_path_prot': 'takes part in',
|
25 |
+
('domain_function', 'GO_term_F'): 'enables',
|
26 |
+
('domain_function', 'GO_term_P'): 'is involved in',
|
27 |
+
('domain_function', 'GO_term_C'): 'localizes to',
|
28 |
+
'function_function': 'ontological relationship',
|
29 |
'protein_domain': 'has',
|
30 |
'PPI': 'interacts with',
|
31 |
'HPO': 'is associated with',
|
32 |
'kegg_dis_prot': 'is related to',
|
33 |
'Disease': 'is related to',
|
34 |
'Drug': 'targets',
|
35 |
+
'kegg_dis_path': 'modulates',
|
36 |
'protein_ec': 'catalyzes',
|
37 |
+
'hpodis': 'is associated with',
|
38 |
+
'kegg_dis_drug': 'treats',
|
39 |
'Chembl': 'targets',
|
40 |
('protein_function', 'GO_term_F'): 'enables',
|
41 |
('protein_function', 'GO_term_P'): 'is involved in',
|
|
|
175 |
|
176 |
return filtered_edges
|
177 |
|
178 |
+
def _gather_neighbor_edges(data, node_id, node_type, exclude_node_id):
|
179 |
+
"""Gather edges for a neighbor node, excluding edges back to the original query protein"""
|
180 |
+
|
181 |
+
node_idx = data[node_type]['id_mapping'][node_id]
|
182 |
+
reverse_id_mapping = {}
|
183 |
+
for ntype in data.node_types:
|
184 |
+
reverse_id_mapping[ntype] = {v:k for k, v in data[ntype]['id_mapping'].items()}
|
185 |
+
|
186 |
+
neighbor_edges = {}
|
187 |
|
188 |
+
for edge_type in data.edge_types:
|
189 |
+
if 'rev' not in edge_type[1]:
|
190 |
+
if edge_type not in neighbor_edges:
|
191 |
+
neighbor_edges[edge_type] = []
|
192 |
+
|
193 |
+
if edge_type[0] == node_type:
|
194 |
+
# Get edges where neighbor is source
|
195 |
+
edges = data[edge_type].edge_index[:, data[edge_type].edge_index[0] == node_idx]
|
196 |
+
edges = edges.T.tolist()
|
197 |
+
# Filter out edges going back to the query protein
|
198 |
+
edges = [edge for edge in edges if reverse_id_mapping[edge_type[2]][edge[1]] != exclude_node_id]
|
199 |
+
neighbor_edges[edge_type].extend(edges)
|
200 |
+
|
201 |
+
elif edge_type[2] == node_type:
|
202 |
+
# Get edges where neighbor is target
|
203 |
+
edges = data[edge_type].edge_index[:, data[edge_type].edge_index[1] == node_idx]
|
204 |
+
edges = edges.T.tolist()
|
205 |
+
# Filter out edges coming from the query protein
|
206 |
+
edges = [edge for edge in edges if reverse_id_mapping[edge_type[0]][edge[0]] != exclude_node_id]
|
207 |
+
neighbor_edges[edge_type].extend(edges)
|
208 |
+
|
209 |
+
# Map indices back to IDs
|
210 |
+
for edge_type in neighbor_edges.keys():
|
211 |
+
if neighbor_edges[edge_type]:
|
212 |
+
mapped_edges = set()
|
213 |
+
for edge in neighbor_edges[edge_type]:
|
214 |
+
source_type, _, target_type = edge_type
|
215 |
+
source_id = reverse_id_mapping[source_type][edge[0]]
|
216 |
+
target_id = reverse_id_mapping[target_type][edge[1]]
|
217 |
+
mapped_edges.add((source_id, target_id))
|
218 |
+
neighbor_edges[edge_type] = mapped_edges
|
219 |
+
|
220 |
+
return neighbor_edges
|
221 |
|
222 |
+
def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10, second_degree_limit=3, include_second_degree=False):
|
223 |
with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
|
224 |
name_info = json.load(file)
|
225 |
|
226 |
+
# Get the first-degree edges and filter them
|
227 |
protein_edges = _gather_protein_edges(data, protein_id)
|
228 |
+
first_degree_edges = _filter_edges(protein_id, protein_edges, prediction_df, limit)
|
229 |
+
|
230 |
+
# Initialize all_edges with first degree edges
|
231 |
+
all_edges = first_degree_edges.copy()
|
232 |
+
|
233 |
+
if include_second_degree:
|
234 |
+
# Collect neighbor nodes from first-degree edges
|
235 |
+
neighbor_nodes = set()
|
236 |
+
for edge_type, edges in first_degree_edges.items():
|
237 |
+
source_type, _, target_type = edge_type
|
238 |
+
for edge_info in edges:
|
239 |
+
edge = edge_info[0]
|
240 |
+
source, target = edge
|
241 |
+
if source != protein_id:
|
242 |
+
neighbor_nodes.add((source, source_type))
|
243 |
+
if target != protein_id:
|
244 |
+
neighbor_nodes.add((target, target_type))
|
245 |
+
|
246 |
+
# Gather and filter second-degree edges
|
247 |
+
second_degree_edges = {}
|
248 |
+
for neighbor_id, neighbor_type in neighbor_nodes:
|
249 |
+
neighbor_edges = _gather_neighbor_edges(data, neighbor_id, neighbor_type, protein_id)
|
250 |
+
filtered_neighbor_edges = _filter_edges(neighbor_id, neighbor_edges, prediction_df, second_degree_limit)
|
251 |
+
|
252 |
+
# Merge filtered neighbor edges into second_degree_edges
|
253 |
+
for edge_type, edges in filtered_neighbor_edges.items():
|
254 |
+
if edge_type not in second_degree_edges:
|
255 |
+
second_degree_edges[edge_type] = []
|
256 |
+
second_degree_edges[edge_type].extend(edges)
|
257 |
+
|
258 |
+
# Merge first and second degree edges
|
259 |
+
for edge_type, edges in second_degree_edges.items():
|
260 |
+
if edge_type in all_edges:
|
261 |
+
all_edges[edge_type].extend(edges)
|
262 |
+
else:
|
263 |
+
all_edges[edge_type] = edges
|
264 |
+
|
265 |
+
# Update visualized_edges with all edges
|
266 |
+
visualized_edges = all_edges
|
267 |
+
|
268 |
print(f'Edges to be visualized: {visualized_edges}')
|
269 |
|
270 |
net = Network(height="600px", width="100%", directed=True, notebook=False)
|
|
|
348 |
for edge_type, edges in visualized_edges.items():
|
349 |
source_type, relation_type, target_type = edge_type
|
350 |
|
351 |
+
if relation_type in ['protein_function', 'domain_function']:
|
352 |
relation_type = EDGE_LABEL_TRANSLATION[(relation_type, target_type)]
|
353 |
else:
|
354 |
relation_type = EDGE_LABEL_TRANSLATION[relation_type]
|
|
|
538 |
|
539 |
# Save graph to a protein-specific file in a temporary directory
|
540 |
os.makedirs('temp_viz', exist_ok=True)
|
541 |
+
suffix = "_with_2nd_degree" if include_second_degree else "_1st_degree"
|
542 |
+
file_path = os.path.join('temp_viz', f'{protein_id}_graph{suffix}.html')
|
543 |
|
544 |
net.save_graph(file_path)
|
545 |
|