Spaces:
Sleeping
Sleeping
Erva Ulusoy
commited on
Commit
·
e6abfd3
1
Parent(s):
7f3941f
fixed bugs on new edge index assignments
Browse files- run_prothgt_app.py +14 -14
run_prothgt_app.py
CHANGED
@@ -6,6 +6,7 @@ import yaml
|
|
6 |
import os
|
7 |
from datasets import load_dataset
|
8 |
import gdown
|
|
|
9 |
|
10 |
class ProtHGT(torch.nn.Module):
|
11 |
def __init__(self, data,hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
|
@@ -45,23 +46,22 @@ class ProtHGT(torch.nn.Module):
|
|
45 |
|
46 |
return self.mlp(z).view(-1), x_dict
|
47 |
|
48 |
-
def _load_data(heterodata, protein_ids, go_category
|
49 |
"""Process the loaded heterodata for specific proteins and GO categories."""
|
50 |
# Get protein indices for all input proteins
|
51 |
protein_indices = [heterodata['Protein']['id_mapping'][pid] for pid in protein_ids]
|
52 |
-
|
53 |
-
# Create edge indices for prediction
|
54 |
-
categories = [go_category] if go_category else ['GO_term_F', 'GO_term_P', 'GO_term_C']
|
55 |
-
|
56 |
-
for category in categories:
|
57 |
-
# Create pairs for all proteins with all GO terms
|
58 |
-
n_terms = len(heterodata[category]['id_mapping'])
|
59 |
-
protein_indices_repeated = torch.tensor(protein_indices).repeat_interleave(n_terms)
|
60 |
-
term_indices = torch.arange(n_terms).repeat(len(protein_indices))
|
61 |
-
|
62 |
-
edge_index = torch.stack([protein_indices_repeated, term_indices])
|
63 |
-
heterodata.edge_index_dict[('Protein', 'protein_function', category)] = edge_index
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
return heterodata
|
66 |
|
67 |
def get_available_proteins(protein_list_file='data/available_proteins.txt'):
|
@@ -169,7 +169,7 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
|
|
169 |
print(f'Generating predictions for {go_cat}...')
|
170 |
|
171 |
# Process data for current GO category
|
172 |
-
processed_data = _load_data(heterodata, protein_ids, go_cat)
|
173 |
|
174 |
# Load model config
|
175 |
with open(model_config_path, 'r') as file:
|
|
|
6 |
import os
|
7 |
from datasets import load_dataset
|
8 |
import gdown
|
9 |
+
import copy
|
10 |
|
11 |
class ProtHGT(torch.nn.Module):
|
12 |
def __init__(self, data,hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
|
|
|
46 |
|
47 |
return self.mlp(z).view(-1), x_dict
|
48 |
|
49 |
+
def _load_data(heterodata, protein_ids, go_category):
|
50 |
"""Process the loaded heterodata for specific proteins and GO categories."""
|
51 |
# Get protein indices for all input proteins
|
52 |
protein_indices = [heterodata['Protein']['id_mapping'][pid] for pid in protein_ids]
|
53 |
+
n_terms = len(heterodata[go_category]['id_mapping'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
+
all_edges = []
|
56 |
+
for protein_idx in protein_indices:
|
57 |
+
for term_idx in range(n_terms):
|
58 |
+
all_edges.append([protein_idx, term_idx])
|
59 |
+
|
60 |
+
edge_index = torch.tensor(all_edges).t()
|
61 |
+
|
62 |
+
heterodata[('Protein', 'protein_function', go_category)].edge_index = edge_index
|
63 |
+
heterodata[(go_category, 'rev_protein_function', 'Protein')].edge_index = torch.stack([edge_index[1], edge_index[0]])
|
64 |
+
|
65 |
return heterodata
|
66 |
|
67 |
def get_available_proteins(protein_list_file='data/available_proteins.txt'):
|
|
|
169 |
print(f'Generating predictions for {go_cat}...')
|
170 |
|
171 |
# Process data for current GO category
|
172 |
+
processed_data = _load_data(copy.deepcopy(heterodata), protein_ids, go_cat)
|
173 |
|
174 |
# Load model config
|
175 |
with open(model_config_path, 'r') as file:
|