Erva Ulusoy committed on
Commit
e6abfd3
·
1 Parent(s): 7f3941f

fixed bugs on new edge index assignments

Browse files
Files changed (1) hide show
  1. run_prothgt_app.py +14 -14
run_prothgt_app.py CHANGED
@@ -6,6 +6,7 @@ import yaml
6
  import os
7
  from datasets import load_dataset
8
  import gdown
 
9
 
10
  class ProtHGT(torch.nn.Module):
11
  def __init__(self, data,hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
@@ -45,23 +46,22 @@ class ProtHGT(torch.nn.Module):
45
 
46
  return self.mlp(z).view(-1), x_dict
47
 
48
- def _load_data(heterodata, protein_ids, go_category=None):
49
  """Process the loaded heterodata for specific proteins and GO categories."""
50
  # Get protein indices for all input proteins
51
  protein_indices = [heterodata['Protein']['id_mapping'][pid] for pid in protein_ids]
52
-
53
- # Create edge indices for prediction
54
- categories = [go_category] if go_category else ['GO_term_F', 'GO_term_P', 'GO_term_C']
55
-
56
- for category in categories:
57
- # Create pairs for all proteins with all GO terms
58
- n_terms = len(heterodata[category]['id_mapping'])
59
- protein_indices_repeated = torch.tensor(protein_indices).repeat_interleave(n_terms)
60
- term_indices = torch.arange(n_terms).repeat(len(protein_indices))
61
-
62
- edge_index = torch.stack([protein_indices_repeated, term_indices])
63
- heterodata.edge_index_dict[('Protein', 'protein_function', category)] = edge_index
64
 
 
 
 
 
 
 
 
 
 
 
65
  return heterodata
66
 
67
  def get_available_proteins(protein_list_file='data/available_proteins.txt'):
@@ -169,7 +169,7 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
169
  print(f'Generating predictions for {go_cat}...')
170
 
171
  # Process data for current GO category
172
- processed_data = _load_data(heterodata, protein_ids, go_cat)
173
 
174
  # Load model config
175
  with open(model_config_path, 'r') as file:
 
6
  import os
7
  from datasets import load_dataset
8
  import gdown
9
+ import copy
10
 
11
  class ProtHGT(torch.nn.Module):
12
  def __init__(self, data,hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
 
46
 
47
  return self.mlp(z).view(-1), x_dict
48
 
49
+ def _load_data(heterodata, protein_ids, go_category):
50
  """Process the loaded heterodata for specific proteins and GO categories."""
51
  # Get protein indices for all input proteins
52
  protein_indices = [heterodata['Protein']['id_mapping'][pid] for pid in protein_ids]
53
+ n_terms = len(heterodata[go_category]['id_mapping'])
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ all_edges = []
56
+ for protein_idx in protein_indices:
57
+ for term_idx in range(n_terms):
58
+ all_edges.append([protein_idx, term_idx])
59
+
60
+ edge_index = torch.tensor(all_edges).t()
61
+
62
+ heterodata[('Protein', 'protein_function', go_category)].edge_index = edge_index
63
+ heterodata[(go_category, 'rev_protein_function', 'Protein')].edge_index = torch.stack([edge_index[1], edge_index[0]])
64
+
65
  return heterodata
66
 
67
  def get_available_proteins(protein_list_file='data/available_proteins.txt'):
 
169
  print(f'Generating predictions for {go_cat}...')
170
 
171
  # Process data for current GO category
172
+ processed_data = _load_data(copy.deepcopy(heterodata), protein_ids, go_cat)
173
 
174
  # Load model config
175
  with open(model_config_path, 'r') as file: