# ProtHGT / run_prothgt_app.py
# Last commit: 24c5c6a by Erva Ulusoy ("initialize app")
from datasets import load_dataset
from torch_geometric.transforms import ToUndirected
import torch
from torch.nn import Linear
from torch_geometric.nn import HGTConv, MLP
import pandas as pd
class ProtHGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that scores (Protein, target node) pairs.

    Node features of every type are projected to ``hidden_channels``, refined
    by a stack of ``HGTConv`` layers, and each candidate pair is scored by an
    MLP applied to the concatenated Protein and target-node embeddings.
    """

    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
        """Build the layers from the graph schema.

        Args:
            data: heterogeneous graph; ``data.node_types`` and
                ``data.metadata()`` determine which projections and relation
                parameters are created.
            hidden_channels: embedding width shared by projections and convs.
            num_heads: attention heads per ``HGTConv`` layer.
            num_layers: number of stacked ``HGTConv`` layers.
            mlp_hidden_layers: channel list passed to ``MLP``. ``forward``
                feeds it the concatenation of two ``hidden_channels``-wide
                embeddings, so the first entry must be ``2 * hidden_channels``.
            mlp_dropout: dropout rate inside the scoring MLP.
        """
        super().__init__()
        # torch.nn.Linear cannot take in_features=-1 (it would try to allocate
        # a tensor with a negative dimension). PyG's Linear treats -1 as
        # "infer lazily", so each node type keeps its own input width and
        # load_state_dict can size the weights from the checkpoint.
        from torch_geometric.nn import Linear as PyGLinear

        self.lin_dict = torch.nn.ModuleDict({
            node_type: PyGLinear(-1, hidden_channels)
            for node_type in data.node_types
        })
        self.convs = torch.nn.ModuleList([
            HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads, group='sum')
            for _ in range(num_layers)
        ])
        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)

    def generate_embeddings(self, x_dict, edge_index_dict):
        """Return updated embeddings for every node type.

        Each type's features pass through its own projection followed by an
        in-place ReLU, then through every ``HGTConv`` layer in order.
        """
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return x_dict

    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
        """Score Protein -> ``target_type`` pairs.

        Args:
            x_dict: node features keyed by node type.
            edge_index_dict: message-passing edges keyed by edge type.
            tr_edge_label_index: pairs to score as a two-row index
                (row 0: Protein node indices, row 1: ``target_type`` indices).
            target_type: node type of the second element of each pair.
            test: unused; kept for call-site compatibility.

        Returns:
            Tuple ``(scores, x_dict)``: the MLP output flattened to 1-D and the
            updated embeddings of every node type.
        """
        x_dict = self.generate_embeddings(x_dict, edge_index_dict)
        row, col = tr_edge_label_index
        z = torch.cat([x_dict["Protein"][row], x_dict[target_type][col]], dim=-1)
        return self.mlp(z).view(-1), x_dict
def _load_data(protein_id, go_category=None, heterodata_path=''):
    """Build the inference graph linking one protein to every candidate GO term.

    Known Protein<->GO annotation edges and every reverse relation are removed.
    Edges from ``protein_id`` to each term of the requested categories are then
    added, and reverse relations are regenerated with ``ToUndirected``.

    Args:
        protein_id: key of ``heterodata['Protein']['id_mapping']``.
        go_category: ``'GO_term_F'``, ``'GO_term_P'`` or ``'GO_term_C'``;
            any falsy value selects all three.
        heterodata_path: path of the serialized heterogeneous graph.

    Returns:
        The graph after ``ToUndirected(merge=False)``. For each selected
        category, ``('Protein', 'protein_function', category)`` carries the
        candidate pairs as both ``edge_index`` and ``edge_label_index``.

    Raises:
        KeyError: if ``protein_id`` is not in the Protein ``id_mapping``.
    """
    # Every consumer needs a PyG HeteroData (.node_types, .metadata(),
    # .x_dict, .to()); datasets.load_dataset returns an HF Dataset instead.
    # NOTE(review): assumes the file was written with torch.save(heterodata);
    # confirm. weights_only=False is needed to unpickle a HeteroData, which
    # also means arbitrary pickled code can run -- only load trusted files.
    heterodata = torch.load(heterodata_path, weights_only=False)

    # Drop existing annotations so known labels are not part of the graph.
    annotation_types = [
        ('Protein', 'protein_function', 'GO_term_F'),
        ('Protein', 'protein_function', 'GO_term_P'),
        ('Protein', 'protein_function', 'GO_term_C'),
        ('GO_term_F', 'rev_protein_function', 'Protein'),
        ('GO_term_P', 'rev_protein_function', 'Protein'),
        ('GO_term_C', 'rev_protein_function', 'Protein'),
    ]
    for edge_type in annotation_types:
        if edge_type in heterodata.edge_types:
            del heterodata[edge_type]

    # Drop reverse relations in place so the object stays a HeteroData.
    # ToUndirected below regenerates them; leaving them would also produce
    # 'rev_rev_*' relations.
    for edge_type in list(heterodata.edge_types):
        if 'rev' in edge_type[1]:
            del heterodata[edge_type]

    protein_index = int(heterodata['Protein']['id_mapping'][protein_id])

    categories = [go_category] if go_category else ['GO_term_F', 'GO_term_P', 'GO_term_C']
    for category in categories:
        # One candidate per GO-term index. Size by id_mapping, not by
        # len(node_store) (which counts stored attributes), so the scores stay
        # aligned with the term labels.
        num_terms = len(heterodata[category]['id_mapping'])
        candidate_index = torch.stack([
            torch.full((num_terms,), protein_index, dtype=torch.long),
            torch.arange(num_terms, dtype=torch.long),
        ])
        store = heterodata['Protein', 'protein_function', category]
        # NOTE(review): the candidate pairs also serve as message-passing
        # edges; confirm this matches how the model was trained.
        store.edge_index = candidate_index
        # The pairs to be scored are read from edge_label_index.
        store.edge_label_index = candidate_index

    # NOTE(review): when go_category is set, only that category's
    # protein_function relation exists, so heterodata.metadata() differs from
    # the all-category graph -- confirm the checkpoint still loads.
    return ToUndirected(merge=False)(heterodata)
def get_available_proteins(protein_list_file='data/available_proteins.txt'):
    """Return the protein IDs listed one per line in ``protein_list_file``.

    Surrounding whitespace is stripped and blank lines are skipped, so a
    trailing or empty line never yields an empty-string ID. The file is read
    as UTF-8 regardless of the platform's default encoding.
    """
    with open(protein_list_file, 'r', encoding='utf-8') as file:
        stripped = (line.strip() for line in file)
        return [protein_id for protein_id in stripped if protein_id]
def _generate_predictions(heterodata, model_path, model_config, target_type):
    """Score every candidate (Protein, ``target_type``) pair with a trained ProtHGT.

    Args:
        heterodata: prepared graph. Its ``('Protein', 'protein_function',
            target_type)`` edge store supplies the pairs to score, taken from
            ``edge_label_index`` or, if that is absent, from ``edge_index``.
        model_path: path of a ``state_dict`` checkpoint for ``ProtHGT``.
        model_config: dict with keys ``hidden_channels``, ``num_heads``,
            ``num_layers``, ``mlp_hidden_layers`` and ``mlp_dropout``.
        target_type: node type of the GO terms to score against.

    Returns:
        1-D CPU tensor with one value in [0, 1] per candidate pair, in
        candidate-edge order.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ProtHGT(
        heterodata,
        model_config['hidden_channels'],
        model_config['num_heads'],
        model_config['num_layers'],
        model_config['mlp_hidden_layers'],
        model_config['mlp_dropout'],
    )
    print('Loading model from', model_path)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    heterodata = heterodata.to(device)

    edge_store = heterodata['Protein', 'protein_function', target_type]
    # Accept pairs stored either as edge_label_index or only as edge_index.
    if 'edge_label_index' in edge_store:
        label_index = edge_store['edge_label_index']
    else:
        label_index = edge_store['edge_index']

    with torch.no_grad():
        scores, _ = model(heterodata.x_dict, heterodata.edge_index_dict, label_index, target_type)
    # PyG's MLP ends in a plain linear layer (plain_last=True by default), so
    # the scores are unbounded -- presumably logits. The result is reported as
    # a probability, so squash with sigmoid; it is monotonic, so rankings are
    # unchanged.
    return torch.sigmoid(scores).cpu()
def _create_prediction_df(predictions, heterodata, protein_id, go_category):
    """Tabulate per-term predictions for one protein, highest probability first.

    Args:
        predictions: one value per GO term, where element ``i`` belongs to the
            term whose ``id_mapping`` index is ``i``. Either a tensor (anything
            with ``.tolist()``) or a plain sequence.
        heterodata: graph (or mapping) where ``heterodata[go_category]['id_mapping']``
            maps GO-term identifiers to node indices.
        protein_id: value repeated in the ``Protein`` column.
        go_category: GO node type; repeated in the ``GO_category`` column.

    Returns:
        DataFrame with columns ``Protein``, ``GO_category``, ``GO_term`` and
        ``Probability``, sorted by ``Probability`` descending (ties keep index
        order) and re-indexed 0..n-1.

    Raises:
        ValueError: if the number of predictions differs from the number of
            GO terms (raised by the DataFrame constructor).
    """
    id_mapping = heterodata[go_category]['id_mapping']
    # Order labels by node index rather than dict insertion order, so the
    # label in position i is always the term whose index is i.
    go_terms = sorted(id_mapping, key=id_mapping.get)
    scores = predictions.tolist() if hasattr(predictions, 'tolist') else list(predictions)
    prediction_df = pd.DataFrame({
        'Protein': protein_id,
        'GO_category': go_category,
        'GO_term': go_terms,
        'Probability': scores,
    })
    # mergesort is stable, so tied probabilities keep a deterministic order.
    prediction_df = prediction_df.sort_values(by='Probability', ascending=False, kind='mergesort')
    return prediction_df.reset_index(drop=True)
def generate_prediction_df(protein_id, heterodata_path, model_path, model_config, go_category=None):
    """Predict GO-term scores for one protein and return them as a DataFrame.

    The graph is loaded once. If ``go_category`` is given, only that category
    is scored. Otherwise the molecular-function, biological-process and
    cellular-component categories are scored in turn and their tables are
    stacked with a fresh 0..n-1 index.
    """
    graph = _load_data(protein_id, go_category, heterodata_path)

    if go_category:
        scores = _generate_predictions(graph, model_path, model_config, go_category)
        return _create_prediction_df(scores, graph, protein_id, go_category)

    per_category = [
        _create_prediction_df(
            _generate_predictions(graph, model_path, model_config, category),
            graph,
            protein_id,
            category,
        )
        for category in ('GO_term_F', 'GO_term_P', 'GO_term_C')
    ]
    return pd.concat(per_category, ignore_index=True)