Spaces:
Running
Running
File size: 5,301 Bytes
24c5c6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from datasets import load_dataset
from torch_geometric.transforms import ToUndirected
import torch
from torch.nn import Linear
from torch_geometric.nn import HGTConv, MLP
import pandas as pd
class ProtHGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that scores links from 'Protein' nodes.

    Each node type is projected to ``hidden_channels`` by its own linear
    layer, the projected features are refined by ``num_layers`` HGTConv
    layers, and a (protein, target) pair is scored by an MLP applied to the
    concatenation of the two node embeddings.
    """

    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
        """Build the model for the node/edge types of *data*.

        Args:
            data: Graph whose ``node_types`` get one input projection each and
                whose ``metadata()`` defines the HGTConv relations.
            hidden_channels: Width of every node embedding.
            num_heads: Attention heads per HGTConv layer.
            num_layers: Number of stacked HGTConv layers.
            mlp_hidden_layers: Channel list for the scoring ``MLP``. Its first
                entry must equal ``2 * hidden_channels`` because ``forward``
                concatenates two embeddings.
            mlp_dropout: Dropout probability inside the scoring MLP.
        """
        super().__init__()
        # Raw feature sizes differ per node type and are not known here, so
        # the projections infer their input width on first use.
        # `torch.nn.Linear(-1, ...)` cannot do that: it tries to allocate a
        # weight with a negative dimension and raises. LazyLinear keeps the
        # same `weight`/`bias` parameter names, so saved checkpoints still load.
        self.lin_dict = torch.nn.ModuleDict({
            node_type: torch.nn.LazyLinear(hidden_channels)
            for node_type in data.node_types
        })
        # NOTE(review): newer PyG releases removed HGTConv's `group` argument;
        # confirm it is accepted by the pinned torch_geometric version.
        self.convs = torch.nn.ModuleList(
            HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads, group='sum')
            for _ in range(num_layers)
        )
        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)

    def generate_embeddings(self, x_dict, edge_index_dict):
        """Return updated node embeddings, keyed by node type.

        Args:
            x_dict: Mapping node type -> raw node feature tensor.
            edge_index_dict: Mapping edge type -> edge index tensor.
        """
        # Per-type input projection followed by an in-place ReLU.
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return x_dict

    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
        """Score (protein, target) pairs.

        Args:
            x_dict: Mapping node type -> raw node feature tensor.
            edge_index_dict: Mapping edge type -> edge index tensor used for
                message passing.
            tr_edge_label_index: Pairs to score. Unpacked into source rows
                (indices into 'Protein') and target columns (indices into
                *target_type*), i.e. a ``[2, num_pairs]`` index tensor.
            target_type: Node type of the link targets.
            test: Unused; kept so existing callers keep working.

        Returns:
            ``(scores, x_dict)``: a flat tensor with one raw MLP output per
            pair, and the updated embeddings.
        """
        x_dict = self.generate_embeddings(x_dict, edge_index_dict)
        row, col = tr_edge_label_index
        z = torch.cat([x_dict["Protein"][row], x_dict[target_type][col]], dim=-1)
        return self.mlp(z).view(-1), x_dict
def _load_data(protein_id, go_category=None, heterodata_path=''):
    """Build the inference graph linking one protein to every candidate GO term.

    The known protein -> GO annotation edges and all reverse relations are
    removed. The protein is then connected to every term of each requested
    GO category, and ``ToUndirected(merge=False)`` re-creates the reverse
    (``rev_*``) relations.

    Args:
        protein_id: Key into ``heterodata['Protein']['id_mapping']``.
        go_category: One of ``'GO_term_F'``, ``'GO_term_P'``, ``'GO_term_C'``;
            when falsy, all three categories receive candidate edges.
        heterodata_path: Passed unchanged to ``load_dataset``.

    Returns:
        The processed graph. For every requested category, the
        ``('Protein', 'protein_function', category)`` store holds the candidate
        pairs as a ``[2, num_terms]`` long tensor in both ``edge_index`` and
        ``edge_label_index``. Column ``i`` pairs the protein with the ``i``-th
        value of that category's ``id_mapping``, in the mapping's iteration
        order.
    """
    # NOTE(review): `datasets.load_dataset` returns a Hugging Face Dataset /
    # DatasetDict, but everything below needs a PyG `HeteroData`
    # (`edge_types`, `del data[edge_type]`, `ToUndirected`). Confirm how the
    # graph is actually stored and loaded.
    heterodata = load_dataset(heterodata_path)

    all_categories = ['GO_term_F', 'GO_term_P', 'GO_term_C']
    annotation_edges = {('Protein', 'protein_function', c) for c in all_categories}

    # Drop two kinds of relation:
    # - the known protein->GO annotations, so only the candidate edges built
    #   below link the protein to GO terms;
    # - every reverse relation, because ToUndirected adds a 'rev_<rel>' copy of
    #   each remaining relation and would otherwise create 'rev_rev_*' ones.
    # Iterate over a snapshot because deleting mutates `edge_types`.
    for edge_type in list(heterodata.edge_types):
        if edge_type in annotation_edges or 'rev' in edge_type[1]:
            del heterodata[edge_type]

    protein_index = heterodata['Protein']['id_mapping'][protein_id]

    categories = [go_category] if go_category else all_categories
    for category in categories:
        # One candidate edge per GO term. The node indices come from the
        # id_mapping values, not len(store), which counts stored attributes.
        # This keeps column order aligned with the id_mapping key order that
        # labels the predictions.
        term_index = torch.tensor(
            list(heterodata[category]['id_mapping'].values()), dtype=torch.long
        )
        protein_row = torch.full_like(term_index, protein_index)
        heterodata['Protein', 'protein_function', category].edge_index = torch.stack(
            [protein_row, term_index]
        )

    heterodata = ToUndirected(merge=False)(heterodata)

    # The model scores the pairs in `edge_label_index`; here those are exactly
    # the candidate edges. Set after ToUndirected so only the forward relation
    # carries it.
    for category in categories:
        store = heterodata['Protein', 'protein_function', category]
        store.edge_label_index = store.edge_index
    return heterodata
def get_available_proteins(protein_list_file='data/available_proteins.txt'):
    """Return the protein IDs listed one per line in *protein_list_file*.

    Surrounding whitespace is stripped from each line. Blank lines are
    skipped, so a stray empty line does not appear as an "available" protein
    ID of ``''``. The file is read as UTF-8 rather than the platform default
    encoding.
    """
    with open(protein_list_file, 'r', encoding='utf-8') as handle:
        stripped = (line.strip() for line in handle)
        return [protein_id for protein_id in stripped if protein_id]
def _generate_predictions(heterodata, model_path, model_config, target_type):
    """Score every candidate (Protein, target_type) pair with a trained ProtHGT.

    Args:
        heterodata: Graph holding a ``('Protein', 'protein_function',
            target_type)`` store with the pairs to score. ``edge_label_index``
            is used when present, otherwise ``edge_index``. Note: this graph
            is moved to the selected device in place.
        model_path: Path of the saved ``state_dict``.
        model_config: Mapping with ``hidden_channels``, ``num_heads``,
            ``num_layers``, ``mlp_hidden_layers`` and ``mlp_dropout``.
        target_type: Node type of the link targets (e.g. ``'GO_term_F'``).

    Returns:
        1-D CPU tensor of scores in ``[0, 1]``, one per pair, in pair order.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ProtHGT(
        heterodata,
        model_config['hidden_channels'],
        model_config['num_heads'],
        model_config['num_layers'],
        model_config['mlp_hidden_layers'],
        model_config['mlp_dropout'],
    )
    print('Loading model from', model_path)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    heterodata.to(device)

    edge_store = heterodata[("Protein", "protein_function", target_type)]
    # Prefer explicit label pairs; fall back to the candidate edges themselves.
    pairs = edge_store.edge_label_index if 'edge_label_index' in edge_store else edge_store.edge_index

    with torch.no_grad():
        logits, _ = model(heterodata.x_dict, heterodata.edge_index_dict, pairs, target_type)
    # PyG's MLP leaves its last layer linear by default (plain_last=True), so
    # the model emits unbounded logits. Callers report these values as
    # probabilities, so squash them. Sigmoid is monotonic: the ranking is
    # unchanged.
    return torch.sigmoid(logits).cpu()
def _create_prediction_df(predictions, heterodata, protein_id, go_category):
prediction_df = pd.DataFrame({
'Protein': protein_id,
'GO_category': go_category,
'GO_term': heterodata[go_category]['id_mapping'].keys(),
'Probability': predictions.tolist()
})
prediction_df.sort_values(by='Probability', ascending=False, inplace=True)
prediction_df.reset_index(drop=True, inplace=True)
return prediction_df
def generate_prediction_df(protein_id, heterodata_path, model_path, model_config, go_category=None):
    """Predict GO-term scores for one protein and return them as a DataFrame.

    With *go_category* set, only that category is scored and its ranked table
    is returned. Otherwise ``'GO_term_F'``, ``'GO_term_P'`` and ``'GO_term_C'``
    are each scored with the same model path/config. Their ranked tables are
    stacked in that order under a fresh 0..n-1 index.
    """
    heterodata = _load_data(protein_id, go_category, heterodata_path)

    def _predict_category(category):
        # Score all candidate terms of one category and tabulate the result.
        scores = _generate_predictions(heterodata, model_path, model_config, category)
        return _create_prediction_df(scores, heterodata, protein_id, category)

    if go_category:
        return _predict_category(go_category)

    per_category = [
        _predict_category(category)
        for category in ('GO_term_F', 'GO_term_P', 'GO_term_C')
    ]
    return pd.concat(per_category, ignore_index=True)
|