Spaces:
Running
Running
from pyvis.network import Network | |
import os | |
# Hex colour for each node type; every type's background and border use the same value.
# Iterated (in this order) when building the per-group colours of the pyvis network.
NODE_TYPE_COLORS = dict(
    Disease='#079dbb',
    HPO='#58d0e8',
    Drug='#815ac0',
    Compound='#d2b7e5',
    Domain='#6bbf59',
    GO_term_P='#ff8800',
    GO_term_F='#ffaa00',
    GO_term_C='#ffc300',
    Pathway='#720026',
    kegg_Pathway='#720026',  # intentionally shares the 'Pathway' colour
    EC_number='#ce4257',
    Protein='#3aa6a4',
)
# Human-readable verb shown on each drawn edge.
# Most relations are keyed by relation name alone. GO annotation edges ('protein_function')
# are keyed by (relation, GO term node type), because the verb depends on the GO aspect.
EDGE_LABEL_TRANSLATION = {
    **dict(
        Orthology='is ortholog to',
        Pathway='takes part in',
        kegg_path_prot='takes part in',
        protein_domain='has',
        PPI='interacts with',
        HPO='is associated with',
        kegg_dis_prot='is related to',
        Disease='is related to',
        Drug='targets',
        protein_ec='catalyzes',
        Chembl='targets',
    ),
    **{
        ('protein_function', go_node_type): verb
        for go_node_type, verb in (
            ('GO_term_F', 'enables'),
            ('GO_term_P', 'is involved in'),
            ('GO_term_C', 'localizes to'),
        )
    },
}
# Maps the GO aspect name used in the prediction table's 'GO_category' column
# to the GO term node type used in the graph.
GO_CATEGORY_MAPPING = dict(zip(
    ('Biological Process', 'Molecular Function', 'Cellular Component'),
    ('GO_term_P', 'GO_term_F', 'GO_term_C'),
))
def _gather_protein_edges(data, protein_id):
    """Collect every forward edge incident to ``protein_id``, keyed by edge type.

    Args:
        data: heterogeneous graph object (presumably a PyG ``HeteroData`` — verify). Must expose
            ``node_types``, ``edge_types``, ``data[node_type]['id_mapping']`` (original ID -> node
            index) and ``data[edge_type].edge_index`` (row 0 = source indices, row 1 = target indices).
        protein_id: original ID of the protein of interest.

    Returns:
        dict mapping each edge type ``(source_type, relation, target_type)`` whose relation name does
        not contain ``'rev'`` to a list of unique ``(source_id, target_id)`` tuples expressed in
        original IDs. The list is ordered by first appearance in ``edge_index``, so it is reproducible
        across runs. Edge types that do not involve ``'Protein'`` map to an empty list.
        When both endpoint types are ``'Protein'``, only edges with this protein as the source are
        collected.

    Raises:
        KeyError: if ``protein_id`` is not in ``data['Protein']['id_mapping']``.
    """
    protein_idx = data['Protein']['id_mapping'][protein_id]

    # index -> original ID, built lazily only for node types that actually appear on a matched edge.
    reverse_id_mapping = {}

    def _to_original_id(node_type, node_idx):
        if node_type not in reverse_id_mapping:
            reverse_id_mapping[node_type] = {
                v: k for k, v in data[node_type]['id_mapping'].items()
            }
        return reverse_id_mapping[node_type][node_idx]

    protein_edges = {}
    print(f'Gathering edges for {protein_id}...')
    for edge_type in data.edge_types:
        source_type, relation, target_type = edge_type
        if 'rev' in relation:
            # Skip reverse relations (presumably added when the graph was made undirected)
            # so each link is reported once.
            continue
        if source_type == 'Protein':
            side = 0  # match the protein on the source row
        elif target_type == 'Protein':
            side = 1  # match the protein on the target row
        else:
            protein_edges[edge_type] = []
            continue

        print(f'Gathering edges for {edge_type}...')
        edge_index = data[edge_type].edge_index
        matched = edge_index[:, edge_index[side] == protein_idx].T.tolist()
        # dict.fromkeys de-duplicates while keeping first-seen order. A set of string IDs would
        # iterate in a per-process random order, making downstream truncation
        # (list(edges)[:limit]) pick different edges on every run.
        protein_edges[edge_type] = list(dict.fromkeys(
            (_to_original_id(source_type, src), _to_original_id(target_type, dst))
            for src, dst in matched
        ))
    return protein_edges
def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
    """Choose which edges of ``protein_id`` to draw and annotate each with prediction info.

    Args:
        protein_id: ID of the centre protein. Matched against ``prediction_df['UniProt_ID']`` and
            used as the source of predicted GO edges.
        protein_edges: mapping ``edge_type -> iterable of (source_id, target_id)``, as returned by
            ``_gather_protein_edges``. ``None`` or empty entries are skipped.
        prediction_df: pandas DataFrame with columns ``UniProt_ID``, ``GO_category`` (a key of
            ``GO_CATEGORY_MAPPING``), ``GO_ID`` and ``Probability``. Rows whose category is not in
            ``GO_CATEGORY_MAPPING`` are ignored instead of raising.
        limit: maximum number of edges kept per edge type; negative values are treated as 0.

    Returns:
        dict ``edge_type -> list of (edge, probability, is_ground_truth)``:
          * non-GO edge types: ``(edge, None, True)`` for up to ``limit`` graph edges;
          * GO edge types with predictions for this protein: the ``limit`` highest-probability
            predictions as ``((protein_id, GO_ID), Probability, edge_in_graph)``. Graph edges that
            are not among these predictions are not included.
          * GO edge types without predictions: ``(edge, 'no_pred', True)`` for up to ``limit``
            graph edges.
    """
    limit = max(limit, 0)
    go_category_reverse_mapping = {v: k for k, v in GO_CATEGORY_MAPPING.items()}
    # Restrict to this protein once, instead of re-masking the full table for every edge type.
    protein_predictions = prediction_df[prediction_df['UniProt_ID'] == protein_id]

    filtered_edges = {}
    for edge_type, edges in protein_edges.items():
        if edges is None or len(edges) == 0:
            continue
        target_type = edge_type[2]

        if not target_type.startswith('GO_term'):
            # Non-GO edges come straight from the graph: keep up to `limit` of them.
            filtered_edges[edge_type] = [(edge, None, True) for edge in list(edges)[:limit]]
            continue

        category = go_category_reverse_mapping.get(target_type)
        category_predictions = None
        if category is not None:
            category_predictions = protein_predictions[
                protein_predictions['GO_category'] == category
            ]

        if category_predictions is None or category_predictions.empty:
            # No model output for this GO aspect: show known annotations, flagged 'no_pred'.
            filtered_edges[edge_type] = [(edge, 'no_pred', True) for edge in list(edges)[:limit]]
            continue

        top_predictions = category_predictions.sort_values(
            by='Probability', ascending=False
        ).head(limit)
        known_edges = set(edges)  # O(1) membership checks below
        filtered_edges[edge_type] = [
            ((protein_id, term), prob, (protein_id, term) in known_edges)
            for term, prob in zip(top_predictions['GO_ID'], top_predictions['Probability'])
        ]
    return filtered_edges
def _relation_label(relation_type, target_type):
    """Return the display verb for a relation, falling back to the raw relation name.

    'protein_function' edges are looked up by (relation, target node type) because their verb
    depends on the GO aspect. All other relations are looked up by name alone. Unknown relations
    fall back to their raw name instead of raising KeyError.
    """
    if relation_type == 'protein_function':
        key = (relation_type, target_type)
    else:
        key = relation_type
    return EDGE_LABEL_TRANSLATION.get(key, relation_type)


def _edge_style(relation_label, probability, is_ground_truth):
    """Return the (label, colour) pair for one drawn edge.

    Colour coding:
      * ``probability is None`` (non-GO graph edge): grey;
      * ``'no_pred'`` (GO annotation with no model prediction): blue;
      * numeric probability: purple if the predicted edge exists in the graph, red otherwise.
    """
    if probability is None:
        return relation_label, '#666666'
    if probability == 'no_pred':
        return f"{relation_label} (P=Not generated)", '#219ebc'
    color = '#8338ec' if is_ground_truth else '#c1121f'
    return f"{relation_label} (P={probability:.2f})", color


def _network_options():
    """Return the vis.js options dict (physics, layout, interaction, configurator, node groups)."""
    groups = {
        node_type: {"color": {"background": color, "border": color}}
        for node_type, color in NODE_TYPE_COLORS.items()
    }
    return {
        "physics": {
            "enabled": True,
            "barnesHut": {
                "gravitationalConstant": -1000,
                "springLength": 250,
                "springConstant": 0.001,
                "damping": 0.09,
                "avoidOverlap": 0,
            },
            "forceAtlas2Based": {
                "gravitationalConstant": -50,
                "centralGravity": 0.01,
                "springLength": 100,
                "springConstant": 0.08,
                "damping": 0.4,
                "avoidOverlap": 0,
            },
            "solver": "barnesHut",
            "stabilization": {
                "enabled": True,
                "iterations": 1000,
                "updateInterval": 25,
            },
        },
        "layout": {
            "improvedLayout": True,
            "hierarchical": {"enabled": False},
        },
        "interaction": {
            "hover": True,
            "navigationButtons": True,
            "multiselect": True,
        },
        "configure": {
            "enabled": True,
            "filter": ["physics", "layout", "manipulation"],
            "showButton": True,
        },
        "groups": groups,
    }


def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
    """Render the neighbourhood of ``protein_id`` to an interactive pyvis HTML file.

    Gathers the protein's edges from ``data``, keeps at most ``limit`` per edge type (GO edges
    annotated with predictions from ``prediction_df``) and draws them around a highlighted
    centre node.

    Args:
        data: heterogeneous graph passed to ``_gather_protein_edges``.
        protein_id: original ID of the protein to centre the graph on.
        prediction_df: prediction table passed to ``_filter_edges``.
        limit: maximum number of edges drawn per edge type.

    Returns:
        ``(file_path, visualized_edges)``: ``file_path`` is the HTML file written under
        ``temp_viz/``; ``visualized_edges`` is the ``_filter_edges`` result that was drawn.
    """
    import json

    protein_edges = _gather_protein_edges(data, protein_id)
    visualized_edges = _filter_edges(protein_id, protein_edges, prediction_df, limit)
    print(f'Edges to be visualized: {visualized_edges}')

    net = Network(height="600px", width="100%", directed=True, notebook=False)
    net.set_options(json.dumps(_network_options()))

    # Every node is keyed by str(id); normalise the centre node the same way so edges attach
    # to it instead of creating a second, unstyled copy for non-str IDs.
    center_id = str(protein_id)
    net.add_node(center_id,
                 label=f"Protein: {protein_id}",
                 color={'background': 'white', 'border': '#c1121f'},
                 borderWidth=4,
                 shape="dot",
                 font={'color': '#000000', 'size': 15},
                 group='Protein',
                 size=30,
                 mass=2.5)
    added_nodes = {center_id}

    for (source_type, relation_type, target_type), edges in visualized_edges.items():
        relation_label = _relation_label(relation_type, target_type)
        for (source, target), probability, is_ground_truth in edges:
            source_str, target_str = str(source), str(target)
            for node_id, node_type in ((source_str, source_type), (target_str, target_type)):
                if node_id in added_nodes:
                    continue
                net.add_node(node_id,
                             label=node_id,
                             shape="dot",
                             font={'color': '#000000', 'size': 12},
                             title=f"{node_type}: {node_id}",
                             group=node_type,
                             size=15,
                             mass=1.5)
                added_nodes.add(node_id)

            edge_label, edge_color = _edge_style(relation_label, probability, is_ground_truth)
            net.add_edge(source_str, target_str,
                         label=edge_label,
                         font={'size': 0},  # label hidden on canvas; shown via hover title
                         color=edge_color,
                         title=edge_label,
                         length=200,
                         smooth={'type': 'curvedCW', 'roundness': 0.1})

    # One HTML file per protein under a local scratch directory.
    os.makedirs('temp_viz', exist_ok=True)
    file_path = os.path.join('temp_viz', f'{protein_id}_graph.html')
    net.save_graph(file_path)
    return file_path, visualized_edges