Erva Ulusoy commited on
Commit
8aa6c67
·
1 Parent(s): 673a3cf

updated data load function

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. run_prothgt_app.py +45 -26
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  pandas
2
  torch_geometric
3
- torch
 
 
1
  pandas
2
  torch_geometric
3
+ torch
4
+ gdown
run_prothgt_app.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
5
  import yaml
6
  import os
7
  from datasets import load_dataset
 
8
 
9
  class ProtHGT(torch.nn.Module):
10
  def __init__(self, data,hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
@@ -44,25 +45,8 @@ class ProtHGT(torch.nn.Module):
44
 
45
  return self.mlp(z).view(-1), x_dict
46
 
47
- def _load_data(protein_ids, go_category=None):
48
-
49
- # heterodata = load_dataset('HUBioDataLab/ProtHGT-KG', data_files="prothgt-kg.pt")
50
- heterodata = torch.load('data/prothgt-kg.pt')
51
- print('Loading data...')
52
- # Remove unnecessary edge types in one go
53
- edge_types_to_remove = [
54
- ('Protein', 'protein_function', 'GO_term_F'),
55
- ('Protein', 'protein_function', 'GO_term_P'),
56
- ('Protein', 'protein_function', 'GO_term_C'),
57
- ('GO_term_F', 'rev_protein_function', 'Protein'),
58
- ('GO_term_P', 'rev_protein_function', 'Protein'),
59
- ('GO_term_C', 'rev_protein_function', 'Protein')
60
- ]
61
-
62
- for edge_type in edge_types_to_remove:
63
- if edge_type in heterodata.edge_index_dict:
64
- del heterodata.edge_index_dict[edge_type]
65
-
66
  # Get protein indices for all input proteins
67
  protein_indices = [heterodata['Protein']['id_mapping'][pid] for pid in protein_ids]
68
 
@@ -136,20 +120,53 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
136
  if isinstance(protein_ids, str):
137
  protein_ids = [protein_ids]
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
140
  for go_cat, model_config_path, model_path in zip(go_category, model_config_paths, model_paths):
141
  print(f'Generating predictions for {go_cat}...')
142
 
143
- # Load data
144
- heterodata = _load_data(protein_ids, go_cat)
145
 
146
- # Load model configuration
147
  with open(model_config_path, 'r') as file:
148
  model_config = yaml.safe_load(file)
149
 
150
  # Initialize model with configuration
151
  model = ProtHGT(
152
- heterodata,
153
  hidden_channels=model_config['hidden_channels'][0],
154
  num_heads=model_config['num_heads'],
155
  num_layers=model_config['num_layers'],
@@ -162,16 +179,18 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
162
  print(f'Loaded model weights from {model_path}')
163
 
164
  # Generate predictions
165
- predictions = _generate_predictions(heterodata, model, go_cat)
166
- prediction_df = _create_prediction_df(predictions, heterodata, protein_ids, go_cat)
167
  all_predictions.append(prediction_df)
168
 
169
  # Clean up memory
170
- del heterodata
171
  del model
172
  del predictions
173
  torch.cuda.empty_cache() # Clear CUDA cache if using GPU
174
 
 
 
175
  # Combine all predictions
176
  final_df = pd.concat(all_predictions, ignore_index=True)
177
 
 
5
  import yaml
6
  import os
7
  from datasets import load_dataset
8
+ import gdown
9
 
10
  class ProtHGT(torch.nn.Module):
11
  def __init__(self, data,hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
 
45
 
46
  return self.mlp(z).view(-1), x_dict
47
 
48
+ def _load_data(heterodata, protein_ids, go_category=None):
49
+ """Process the loaded heterodata for specific proteins and GO categories."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # Get protein indices for all input proteins
51
  protein_indices = [heterodata['Protein']['id_mapping'][pid] for pid in protein_ids]
52
 
 
120
  if isinstance(protein_ids, str):
121
  protein_ids = [protein_ids]
122
 
123
+ # Load dataset once
124
+ # heterodata = load_dataset('HUBioDataLab/ProtHGT-KG', data_files="prothgt-kg.json.gz")
125
+ print('Loading data...')
126
+ file_id = "18u1o2sm8YjMo9joFw4Ilwvg0-rUU0PXK"
127
+ output = "data/prothgt-kg.pt"
128
+
129
+ url = f"https://drive.google.com/uc?id={file_id}"
130
+ print(f"Downloading file from {url}...")
131
+ try:
132
+ gdown.download(url, output, quiet=False)
133
+ print(f"File downloaded to {output}")
134
+ except Exception as e:
135
+ print(f"Error downloading file: {e}")
136
+ raise
137
+
138
+ heterodata = torch.load(output)
139
+ print(heterodata.edge_types)
140
+
141
+ # Remove unnecessary edge types
142
+ edge_types_to_remove = [
143
+ ('Protein', 'protein_function', 'GO_term_F'),
144
+ ('Protein', 'protein_function', 'GO_term_P'),
145
+ ('Protein', 'protein_function', 'GO_term_C'),
146
+ ('GO_term_F', 'rev_protein_function', 'Protein'),
147
+ ('GO_term_P', 'rev_protein_function', 'Protein'),
148
+ ('GO_term_C', 'rev_protein_function', 'Protein')
149
+ ]
150
+
151
+ for edge_type in edge_types_to_remove:
152
+ if edge_type in heterodata.edge_index_dict:
153
+ del heterodata.edge_index_dict[edge_type]
154
+
155
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
156
+
157
  for go_cat, model_config_path, model_path in zip(go_category, model_config_paths, model_paths):
158
  print(f'Generating predictions for {go_cat}...')
159
 
160
+ # Process data for current GO category
161
+ processed_data = _load_data(heterodata, protein_ids, go_cat)
162
 
163
+ # Load model config
164
  with open(model_config_path, 'r') as file:
165
  model_config = yaml.safe_load(file)
166
 
167
  # Initialize model with configuration
168
  model = ProtHGT(
169
+ processed_data,
170
  hidden_channels=model_config['hidden_channels'][0],
171
  num_heads=model_config['num_heads'],
172
  num_layers=model_config['num_layers'],
 
179
  print(f'Loaded model weights from {model_path}')
180
 
181
  # Generate predictions
182
+ predictions = _generate_predictions(processed_data, model, go_cat)
183
+ prediction_df = _create_prediction_df(predictions, processed_data, protein_ids, go_cat)
184
  all_predictions.append(prediction_df)
185
 
186
  # Clean up memory
187
+ del processed_data
188
  del model
189
  del predictions
190
  torch.cuda.empty_cache() # Clear CUDA cache if using GPU
191
 
192
+ del heterodata
193
+
194
  # Combine all predictions
195
  final_df = pd.concat(all_predictions, ignore_index=True)
196