File size: 5,301 Bytes
24c5c6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from datasets import load_dataset
from torch_geometric.transforms import ToUndirected
import torch
from torch.nn import Linear
from torch_geometric.nn import HGTConv, MLP
import pandas as pd

class ProtHGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that scores (Protein, target) node pairs.

    Every node type is first projected to ``hidden_channels`` features, the
    projected features are passed through ``num_layers`` ``HGTConv`` layers,
    and an MLP scores the concatenated embeddings of each requested pair.

    Args:
        data: Graph object exposing ``node_types`` and ``metadata()``.
        hidden_channels: Width of the per-node-type projection and of every
            ``HGTConv`` layer.
        num_heads: Number of attention heads handed to each ``HGTConv``.
        num_layers: Number of stacked ``HGTConv`` layers.
        mlp_hidden_layers: Channel list handed to the ``MLP`` scoring head.
            NOTE(review): the first entry presumably has to equal
            ``2 * hidden_channels`` because ``forward`` concatenates two
            embeddings -- confirm against the model configuration.
        mlp_dropout: Dropout value handed to the ``MLP`` scoring head.
    """

    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
        super().__init__()

        # The module-level ``Linear`` is ``torch.nn.Linear``, which cannot be
        # built with ``in_features=-1`` (it would allocate a weight with a
        # negative dimension). ``LazyLinear`` infers the input width on first
        # use (or when a state dict is loaded) and keeps the same
        # ``weight`` / ``bias`` parameter names.
        self.lin_dict = torch.nn.ModuleDict({
            node_type: torch.nn.LazyLinear(hidden_channels)
            for node_type in data.node_types
        })

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads, group='sum')
            self.convs.append(conv)

        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)

    def generate_embeddings(self, x_dict, edge_index_dict):
        """Return the per-node-type embeddings produced by the GNN layers.

        Args:
            x_dict: Mapping from node type to that type's feature tensor.
            edge_index_dict: Mapping handed unchanged to every conv layer.

        Returns:
            Mapping from node type to the embeddings after the last layer.
        """
        # Project each node type into the shared hidden space, then apply an
        # in-place ReLU on the freshly created projection output.
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        return x_dict

    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
        """Score the requested (Protein, ``target_type``) pairs.

        Args:
            x_dict: Mapping from node type to that type's feature tensor.
            edge_index_dict: Mapping handed to the conv layers.
            tr_edge_label_index: Pair of index sequences; the first indexes
                ``"Protein"`` embeddings, the second ``target_type`` ones.
            target_type: Node type of the second element of every pair.
            test: Unused; kept so existing callers keep working.

        Returns:
            A tuple ``(scores, x_dict)`` where ``scores`` is the MLP output
            flattened to one dimension and ``x_dict`` holds the embeddings.
        """
        # Get updated embeddings
        x_dict = self.generate_embeddings(x_dict, edge_index_dict)

        # Make predictions: concatenate source and target embeddings per pair.
        row, col = tr_edge_label_index
        z = torch.cat([x_dict["Protein"][row], x_dict[target_type][col]], dim=-1)

        return self.mlp(z).view(-1), x_dict

def _load_data(protein_id, go_category=None, heterodata_path=''):
    """Load the graph and attach candidate Protein -> GO-term pairs for one protein.

    Args:
        protein_id: Key looked up in ``heterodata['Protein']['id_mapping']``.
        go_category: One GO node type (e.g. ``'GO_term_F'``); when falsy, all
            three GO node types are used.
        heterodata_path: Location passed to ``load_dataset``.

    Returns:
        The result of applying ``ToUndirected(merge=False)`` to the prepared
        data.

    Raises:
        KeyError: If ``protein_id`` is missing from the Protein id mapping.
    """
    go_node_types = ['GO_term_F', 'GO_term_P', 'GO_term_C']

    heterodata = load_dataset(heterodata_path)

    # Drop the stored protein/GO annotation edges in both directions.
    forward_edges = [('Protein', 'protein_function', go_type) for go_type in go_node_types]
    backward_edges = [(go_type, 'rev_protein_function', 'Protein') for go_type in go_node_types]
    for edge_type in forward_edges + backward_edges:
        if edge_type in heterodata:
            del heterodata[edge_type]

    # Keep every entry except tuple-keyed ones whose relation name contains 'rev'.
    # NOTE(review): from here on `heterodata` is a plain dict -- confirm that
    # ToUndirected below accepts it.
    kept = {}
    for key, value in heterodata.items():
        if isinstance(key, tuple) and 'rev' in key[1]:
            continue
        kept[key] = value
    heterodata = kept

    protein_index = heterodata['Protein']['id_mapping'][protein_id]

    if go_category:
        categories = [go_category]
    else:
        categories = go_node_types

    for category in categories:
        # NOTE(review): len(heterodata[category]) is used as the number of
        # candidate terms -- confirm it equals the node count of that type.
        num_terms = len(heterodata[category])
        # Pair the query protein with every term index; stored as a list of
        # (protein_index, term_index) tuples.
        heterodata['Protein', 'protein_function', category] = {
            'edge_index': [(protein_index, term_index) for term_index in range(num_terms)]
        }

    return ToUndirected(merge=False)(heterodata)

def get_available_proteins(protein_list_file='data/available_proteins.txt'):
    """Read protein identifiers, one per line, from a text file.

    Surrounding whitespace is stripped from every line, and lines that are
    empty after stripping are skipped so that no empty identifier is returned.

    Args:
        protein_list_file: Path of the text file to read.

    Returns:
        List of identifiers in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding keeps the result independent of the platform default.
    with open(protein_list_file, 'r', encoding='utf-8') as file:
        stripped_lines = (line.strip() for line in file)
        return [protein_id for protein_id in stripped_lines if protein_id]

def _generate_predictions(heterodata, model_path, model_config, target_type):
    """Run a saved ProtHGT model on the Protein -> ``target_type`` pairs.

    Args:
        heterodata: Graph object handed to ``ProtHGT`` and to the model call;
            it must expose ``x_dict``, ``edge_index_dict``, ``to`` and
            indexing by edge type.
        model_path: Path of the saved state dict.
        model_config: Mapping with the keys ``hidden_channels``,
            ``num_heads``, ``num_layers``, ``mlp_hidden_layers`` and
            ``mlp_dropout``.
        target_type: Node type of the pair targets.

    Returns:
        The first element of the model output (the flattened scores).
    """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    model = ProtHGT(
        heterodata,
        model_config['hidden_channels'],
        model_config['num_heads'],
        model_config['num_layers'],
        model_config['mlp_hidden_layers'],
        model_config['mlp_dropout'],
    )
    print('Loading model from', model_path)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    heterodata.to(device)

    edge_type = ("Protein", "protein_function", target_type)
    with torch.no_grad():
        node_features = heterodata.x_dict
        edge_indices = heterodata.edge_index_dict
        # NOTE(review): reads `edge_label_index` from the edge store -- confirm
        # the loader populates that attribute for this edge type.
        pair_index = heterodata[edge_type].edge_label_index
        predictions, _ = model(node_features, edge_indices, pair_index, target_type)
    return predictions

def _create_prediction_df(predictions, heterodata, protein_id, go_category):
    prediction_df = pd.DataFrame({
        'Protein': protein_id,
        'GO_category': go_category,
        'GO_term': heterodata[go_category]['id_mapping'].keys(),
        'Probability': predictions.tolist()
    })
    prediction_df.sort_values(by='Probability', ascending=False, inplace=True)
    prediction_df.reset_index(drop=True, inplace=True)
    return prediction_df


def generate_prediction_df(protein_id, heterodata_path, model_path, model_config, go_category=None):
    """Predict GO-term scores for one protein and return them as a DataFrame.

    Args:
        protein_id: Identifier of the query protein.
        heterodata_path: Location of the graph data, passed to ``_load_data``.
        model_path: Path of the saved model state dict.
        model_config: Model hyper-parameter mapping passed to
            ``_generate_predictions``.
        go_category: A single GO node type to predict for; when falsy,
            predictions are made for ``GO_term_F``, ``GO_term_P`` and
            ``GO_term_C`` in that order.

    Returns:
        The prediction table for ``go_category``, or the per-category tables
        concatenated with a fresh index when no category is given.
    """
    heterodata = _load_data(protein_id, go_category, heterodata_path)

    def _predict(category):
        # Score one category and tabulate the result.
        scores = _generate_predictions(heterodata, model_path, model_config, category)
        return _create_prediction_df(scores, heterodata, protein_id, category)

    if go_category:
        return _predict(go_category)

    frames = [_predict(category) for category in ['GO_term_F', 'GO_term_P', 'GO_term_C']]
    return pd.concat(frames, ignore_index=True)