Erva Ulusoy committed
Commit 24c5c6a · 1 Parent(s): 4e751b2

initialize app

Files changed (3)
  1. ProtHGT_app.py +26 -0
  2. data/available_proteins.txt +0 -0
  3. run_prothgt_app.py +129 -0
ProtHGT_app.py ADDED
@@ -0,0 +1,26 @@
+import streamlit as st
+import streamlit.components.v1 as components
+import os
+import time
+import pandas as pd
+
+from run_prothgt_app import *
+
+def convert_df(df):
+    return df.to_csv(index=False).encode('utf-8')
+
+with st.sidebar:
+    st.title("ProtHGT: Heterogeneous Graph Transformers for Automated Protein Function Prediction Using Knowledge Graphs and Language Models")
+    st.write("[![publication](https://img.shields.io/badge/DOI-10.1002/pro.4988-b31b1b.svg)]() [![github-repository](https://img.shields.io/badge/GitHub-black?logo=github)](https://github.com/HUBioDataLab/ProtHGT)")
+
+    # Add protein selection
+    # You'll need to replace this with your actual data loading
+    available_proteins = get_available_proteins()  # Function to get list of proteins from your data
+    selected_protein = st.selectbox(
+        "Select or search for a protein (UniProt ID)",
+        options=available_proteins,
+        placeholder="Start typing to search...",
+    )
+
+    if selected_protein:
+        st.write(f"Selected protein: {selected_protein}")
data/available_proteins.txt ADDED
The diff for this file is too large to render. See raw diff
 
run_prothgt_app.py ADDED
@@ -0,0 +1,129 @@
+from datasets import load_dataset
+from torch_geometric.transforms import ToUndirected
+import torch
+from torch.nn import Linear
+from torch_geometric.nn import HGTConv, MLP
+import pandas as pd
+
+class ProtHGT(torch.nn.Module):
+    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
+        super().__init__()
+
+        self.lin_dict = torch.nn.ModuleDict({
+            node_type: Linear(-1, hidden_channels)
+            for node_type in data.node_types
+        })
+
+        self.convs = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads, group='sum')
+            self.convs.append(conv)
+
+        # self.left_linear = Linear(hidden_channels, hidden_channels)
+        # self.right_linear = Linear(hidden_channels, hidden_channels)
+        # self.sqrt_hd = hidden_channels**1/2
+
+        # self.mlp = MLP([2*hidden_channels, 128, 1], dropout=0.5, norm=None)
+        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)
+
+    def generate_embeddings(self, x_dict, edge_index_dict):
+        # Generate updated embeddings through the GNN layers
+        x_dict = {
+            node_type: self.lin_dict[node_type](x).relu_()
+            for node_type, x in x_dict.items()
+        }
+
+        for conv in self.convs:
+            x_dict = conv(x_dict, edge_index_dict)
+
+        return x_dict
+
+    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
+        # Get updated embeddings
+        x_dict = self.generate_embeddings(x_dict, edge_index_dict)
+
+        # Make predictions
+        row, col = tr_edge_label_index
+        z = torch.cat([x_dict["Protein"][row], x_dict[target_type][col]], dim=-1)
+
+        return self.mlp(z).view(-1), x_dict
+
+def _load_data(protein_id, go_category=None, heterodata_path=''):
+    heterodata = load_dataset(heterodata_path)
+
+    # Remove unnecessary edge types in one go
+    edge_types_to_remove = [
+        ('Protein', 'protein_function', 'GO_term_F'),
+        ('Protein', 'protein_function', 'GO_term_P'),
+        ('Protein', 'protein_function', 'GO_term_C'),
+        ('GO_term_F', 'rev_protein_function', 'Protein'),
+        ('GO_term_P', 'rev_protein_function', 'Protein'),
+        ('GO_term_C', 'rev_protein_function', 'Protein')
+    ]
+
+    for edge_type in edge_types_to_remove:
+        if edge_type in heterodata:
+            del heterodata[edge_type]
+
+    # Remove reverse edges
+    heterodata = {k: v for k, v in heterodata.items() if not isinstance(k, tuple) or 'rev' not in k[1]}
+
+    protein_index = heterodata['Protein']['id_mapping'][protein_id]
+
+    # Create edge indices more efficiently
+    categories = [go_category] if go_category else ['GO_term_F', 'GO_term_P', 'GO_term_C']
+
+    for category in categories:
+        pairs = [(protein_index, i) for i in range(len(heterodata[category]))]
+        heterodata['Protein', 'protein_function', category] = {'edge_index': pairs}
+
+    return ToUndirected(merge=False)(heterodata)
+
+def get_available_proteins(protein_list_file='data/available_proteins.txt'):
+    with open(protein_list_file, 'r') as file:
+        return [line.strip() for line in file.readlines()]
+
+def _generate_predictions(heterodata, model_path, model_config, target_type):
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = ProtHGT(heterodata, model_config['hidden_channels'], model_config['num_heads'], model_config['num_layers'], model_config['mlp_hidden_layers'], model_config['mlp_dropout'])
+    print('Loading model from', model_path)
+    model.load_state_dict(torch.load(model_path, map_location=device))
+
+    model.to(device)
+    model.eval()
+    heterodata.to(device)
+
+    with torch.no_grad():
+        predictions, _ = model(heterodata.x_dict, heterodata.edge_index_dict, heterodata[("Protein", "protein_function", target_type)].edge_label_index, target_type)
+        return predictions
+
+def _create_prediction_df(predictions, heterodata, protein_id, go_category):
+    prediction_df = pd.DataFrame({
+        'Protein': protein_id,
+        'GO_category': go_category,
+        'GO_term': heterodata[go_category]['id_mapping'].keys(),
+        'Probability': predictions.tolist()
+    })
+    prediction_df.sort_values(by='Probability', ascending=False, inplace=True)
+    prediction_df.reset_index(drop=True, inplace=True)
+    return prediction_df
+
+
+def generate_prediction_df(protein_id, heterodata_path, model_path, model_config, go_category=None):
+    heterodata = _load_data(protein_id, go_category, heterodata_path)
+
+    if go_category:
+        predictions = _generate_predictions(heterodata, model_path, model_config, go_category)
+        prediction_df = _create_prediction_df(predictions, heterodata, protein_id, go_category)
+        return prediction_df
+
+    else:
+        all_predictions = []
+        for go_category in ['GO_term_F', 'GO_term_P', 'GO_term_C']:
+            predictions = _generate_predictions(heterodata, model_path, model_config, go_category)
+            category_df = _create_prediction_df(predictions, heterodata, protein_id, go_category)
+            all_predictions.append(category_df)
+
+        return pd.concat(all_predictions, ignore_index=True)
+
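
For reference, a minimal sketch of how the functions in run_prothgt_app.py are meant to be called end to end. The paths and hyperparameter values below are illustrative placeholders (the actual dataset location, checkpoint, and configuration are not part of this commit); the `model_config` keys simply mirror the ProtHGT constructor arguments.

    # Sketch only: placeholder paths and hyperparameters, not the repository's real settings.
    from run_prothgt_app import get_available_proteins, generate_prediction_df

    model_config = {
        'hidden_channels': 64,
        'num_heads': 2,
        'num_layers': 2,
        'mlp_hidden_layers': [128, 64, 1],  # first entry matches 2 * hidden_channels, last entry is 1
        'mlp_dropout': 0.5,
    }

    available = get_available_proteins()                # reads data/available_proteins.txt
    predictions = generate_prediction_df(
        protein_id=available[0],                        # any UniProt accession from the list
        heterodata_path='path/to/heterodata',           # placeholder dataset location
        model_path='path/to/prothgt_checkpoint.pt',     # placeholder model checkpoint
        model_config=model_config,
        go_category='GO_term_F',                        # or None to score all three GO categories
    )
    print(predictions.head())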