Sukanyaaa commited on
Commit
82f0f49
·
1 Parent(s): 0e58a55

fix train.py

Browse files
Files changed (1) hide show
  1. train.py +611 -0
train.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import time
3
+ import json
4
+ import gradio as gr
5
+ from gradio_molecule3d import Molecule3D
6
+ import torch
7
+ from pinder.core import get_pinder_location
8
+ get_pinder_location()
9
+ from pytorch_lightning import LightningModule
10
+
11
+ import torch
12
+ import lightning.pytorch as pl
13
+ import torch.nn.functional as F
14
+
15
+ import torch.nn as nn
16
+ import torchmetrics
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch_geometric.nn import MessagePassing
20
+ from torch_geometric.nn import global_mean_pool
21
+ from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
22
+ from torch_scatter import scatter
23
+ from torch.nn import Module
24
+
25
+
26
+ import pinder.core as pinder
27
+ pinder.__version__
28
+ from torch_geometric.loader import DataLoader
29
+ from pinder.core.loader.dataset import get_geo_loader
30
+ from pinder.core import download_dataset
31
+ from pinder.core import get_index
32
+ from pinder.core import get_metadata
33
+ from pathlib import Path
34
+ import pandas as pd
35
+ from pinder.core import PinderSystem
36
+ import torch
37
+ from pinder.core.loader.dataset import PPIDataset
38
+ from pinder.core.loader.geodata import NodeRepresentation
39
+ import pickle
40
+ from pinder.core import get_index, PinderSystem
41
+ from torch_geometric.data import HeteroData
42
+ import os
43
+
44
+ from enum import Enum
45
+
46
+ import numpy as np
47
+ import torch
48
+ import lightning.pytorch as pl
49
+ from numpy.typing import NDArray
50
+ from torch_geometric.data import HeteroData
51
+
52
+ from pinder.core.index.system import PinderSystem
53
+ from pinder.core.loader.structure import Structure
54
+ from pinder.core.utils import constants as pc
55
+ from pinder.core.utils.log import setup_logger
56
+ from pinder.core.index.system import _align_monomers_with_mask
57
+ from pinder.core.loader.structure import Structure
58
+
59
+ import torch
60
+ import torch.nn as nn
61
+ import torch.nn.functional as F
62
+ from torch_geometric.nn import MessagePassing
63
+ from torch_geometric.nn import global_mean_pool
64
+ from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
65
+ from torch_scatter import scatter
66
+ from torch.nn import Module
67
+ import time
68
+ from torch_geometric.nn import global_max_pool
69
+ import copy
70
+ import inspect
71
+ import warnings
72
+ from typing import Optional, Tuple, Union
73
+
74
+ import torch
75
+ from torch import Tensor
76
+
77
+ from torch_geometric.data import Data, Dataset, HeteroData
78
+ from torch_geometric.data.feature_store import FeatureStore
79
+ from torch_geometric.data.graph_store import GraphStore
80
+ from torch_geometric.loader import (
81
+ LinkLoader,
82
+ LinkNeighborLoader,
83
+ NeighborLoader,
84
+ NodeLoader,
85
+ )
86
+ from torch_geometric.loader.dataloader import DataLoader
87
+ from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes
88
+ from torch_geometric.sampler import BaseSampler, NeighborSampler
89
+ from torch_geometric.typing import InputEdges, InputNodes
90
+
91
+ try:
92
+ from lightning.pytorch import LightningDataModule as PLLightningDataModule
93
+ no_pytorch_lightning = False
94
+ except (ImportError, ModuleNotFoundError):
95
+ PLLightningDataModule = object
96
+ no_pytorch_lightning = True
97
+
98
+ from lightning.pytorch.callbacks import ModelCheckpoint
99
+ from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
100
+ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
101
+ from torch_geometric.data.lightning.datamodule import LightningDataset
102
+ from pytorch_lightning.loggers.wandb import WandbLogger
103
def get_system(system_id: str) -> PinderSystem:
    """Construct the ``PinderSystem`` for the given system identifier."""
    system = PinderSystem(system_id)
    return system
105
+ from Bio import PDB
106
+ from Bio.PDB.PDBIO import PDBIO
107
+
108
+ log = setup_logger(__name__)
109
+
110
# Optional dependency: knn_graph is only defined when torch-cluster imports
# cleanly. Code that builds k-NN edges must check torch_cluster_installed first.
try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    # Adjacent string literals are concatenated verbatim, so each sentence
    # carries its own trailing space to keep the logged message readable.
    log.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False
121
+
122
+
123
+ def structure2tensor(
124
+ atom_coordinates: NDArray[np.double] | None = None,
125
+ atom_types: NDArray[np.str_] | None = None,
126
+ element_types: NDArray[np.str_] | None = None,
127
+ residue_coordinates: NDArray[np.double] | None = None,
128
+ residue_ids: NDArray[np.int_] | None = None,
129
+ residue_types: NDArray[np.str_] | None = None,
130
+ chain_ids: NDArray[np.str_] | None = None,
131
+ dtype: torch.dtype = torch.float32,
132
+ ) -> dict[str, torch.Tensor]:
133
+ property_dict = {}
134
+ if atom_types is not None:
135
+ unknown_name_idx = max(pc.ALL_ATOM_POSNS.values()) + 1
136
+ types_array_at = np.zeros((len(atom_types), 1))
137
+ for i, name in enumerate(atom_types):
138
+ types_array_at[i] = pc.ALL_ATOM_POSNS.get(name, unknown_name_idx)
139
+ property_dict["atom_types"] = torch.tensor(types_array_at).type(dtype)
140
+ if element_types is not None:
141
+ types_array_ele = np.zeros((len(element_types), 1))
142
+ for i, name in enumerate(element_types):
143
+ types_array_ele[i] = pc.ELE2NUM.get(name, pc.ELE2NUM["other"])
144
+ property_dict["element_types"] = torch.tensor(types_array_ele).type(dtype)
145
+ if residue_types is not None:
146
+ unknown_name_idx = max(pc.AA_TO_INDEX.values()) + 1
147
+ types_array_res = np.zeros((len(residue_types), 1))
148
+ for i, name in enumerate(residue_types):
149
+ types_array_res[i] = pc.AA_TO_INDEX.get(name, unknown_name_idx)
150
+ property_dict["residue_types"] = torch.tensor(types_array_res).type(dtype)
151
+
152
+ if atom_coordinates is not None:
153
+ property_dict["atom_coordinates"] = torch.tensor(atom_coordinates, dtype=dtype)
154
+
155
+ if residue_coordinates is not None:
156
+ property_dict["residue_coordinates"] = torch.tensor(
157
+ residue_coordinates, dtype=dtype
158
+ )
159
+ if residue_ids is not None:
160
+ property_dict["residue_ids"] = torch.tensor(residue_ids, dtype=dtype)
161
+ if chain_ids is not None:
162
+ property_dict["chain_ids"] = torch.zeros(len(chain_ids), dtype=dtype)
163
+ property_dict["chain_ids"][chain_ids == "L"] = 1
164
+ return property_dict
165
+
166
+
167
class NodeRepresentation(Enum):
    """Node representation levels, each mapped to its lowercase string value.

    NOTE(review): this definition shadows the ``NodeRepresentation`` imported
    earlier in the file from ``pinder.core.loader.geodata``.
    """

    Surface = "surface"
    Atom = "atom"
    Residue = "residue"
171
+
172
+
173
class PairedPDB(HeteroData):  # type: ignore
    """Heterogeneous graph with ``"ligand"`` and ``"receptor"`` node stores.

    For each node type, ``x`` holds the encoded atom names of the ``apo``
    structure, ``pos`` the ``apo`` atom coordinates, and ``y`` the ``holo``
    atom coordinates. When edges are requested and torch-cluster is
    available, a k-NN ``edge_index`` built from ``pos`` is stored on each
    node store.
    """

    @classmethod
    def from_tuple_system(
        cls,
        tupal: tuple = (Structure, Structure, Structure),
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build a graph from a tuple whose first two items are ``(holo, apo)``.

        Only ``tupal[0]`` and ``tupal[1]`` are read; any further items are
        ignored.

        NOTE(review): the default value holds the ``Structure`` class itself,
        not instances, so callers are presumably expected to always pass a
        real tuple.

        Args:
            tupal: Sequence with the holo structure at index 0 and the apo
                structure at index 1.
            add_edges: Forwarded to ``from_structure_pair``.
            k: Forwarded to ``from_structure_pair``.
        """
        holo, apo = tupal[0], tupal[1]
        return cls.from_structure_pair(holo=holo, apo=apo, add_edges=add_edges, k=k)

    @classmethod
    def from_structure_pair(
        cls,
        holo: Structure,
        apo: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build the paired graph from a holo and an apo structure.

        Args:
            holo: Structure whose atom coordinates become the ``y`` targets.
            apo: Structure whose atom names / coordinates become ``x`` / ``pos``.
            add_edges: When true (and torch-cluster imported successfully),
                store a k-NN ``edge_index`` on each node store.
            k: Number of neighbours passed to ``knn_graph``.

        Returns:
            The populated ``PairedPDB`` graph.
        """
        graph = cls()

        # Count of atoms on chain "R" in each structure. The slices below
        # treat the first n atoms as the receptor and the remainder as the
        # ligand -- assumes chain "R" atoms precede all others; TODO confirm.
        n_holo_r = (holo.dataframe["chain_id"] == "R").sum()
        n_apo_r = (apo.dataframe["chain_id"] == "R").sum()

        # Only atom names and atom coordinates are stored on the graph, so
        # only those are converted; residue-level tensors are not needed.
        apo_names = apo.atom_array.atom_name
        apo_r_props = structure2tensor(
            atom_coordinates=apo.coords[:n_apo_r],
            atom_types=apo_names[:n_apo_r],
        )
        apo_l_props = structure2tensor(
            atom_coordinates=apo.coords[n_apo_r:],
            atom_types=apo_names[n_apo_r:],
        )
        holo_r_props = structure2tensor(atom_coordinates=holo.coords[:n_holo_r])
        holo_l_props = structure2tensor(atom_coordinates=holo.coords[n_holo_r:])

        graph["ligand"].x = apo_l_props["atom_types"]
        graph["ligand"].pos = apo_l_props["atom_coordinates"]
        graph["receptor"].x = apo_r_props["atom_types"]
        graph["receptor"].pos = apo_r_props["atom_coordinates"]
        graph["ligand"].y = holo_l_props["atom_coordinates"]
        graph["receptor"].y = holo_r_props["atom_coordinates"]

        if add_edges and torch_cluster_installed:
            # Edges connect atoms within one molecule only, based on apo positions.
            for node_type in ("ligand", "receptor"):
                graph[node_type].edge_index = knn_graph(graph[node_type].pos, k=k)

        return graph
267
+
268
+ # To create the dataset, we used only the PINDER dataset, following these steps:
269
+
270
+ # log = setup_logger(__name__)
271
+
272
+ # try:
273
+ # from torch_cluster import knn_graph
274
+
275
+ # torch_cluster_installed = True
276
+ # except ImportError as e:
277
+ # log.warning(
278
+ # "torch-cluster is not installed!"
279
+ # "Please install the appropriate library for your pytorch installation."
280
+ # "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
281
+ # )
282
+ # torch_cluster_installed = False
283
+
284
+
285
+ # def structure2tensor(
286
+ # atom_coordinates: NDArray[np.double] | None = None,
287
+ # atom_types: NDArray[np.str_] | None = None,
288
+ # element_types: NDArray[np.str_] | None = None,
289
+ # residue_coordinates: NDArray[np.double] | None = None,
290
+ # residue_ids: NDArray[np.int_] | None = None,
291
+ # residue_types: NDArray[np.str_] | None = None,
292
+ # chain_ids: NDArray[np.str_] | None = None,
293
+ # dtype: torch.dtype = torch.float32,
294
+ # ) -> dict[str, torch.Tensor]:
295
+ # property_dict = {}
296
+ # if atom_types is not None:
297
+ # unknown_name_idx = max(pc.ALL_ATOM_POSNS.values()) + 1
298
+ # types_array_at = np.zeros((len(atom_types), 1))
299
+ # for i, name in enumerate(atom_types):
300
+ # types_array_at[i] = pc.ALL_ATOM_POSNS.get(name, unknown_name_idx)
301
+ # property_dict["atom_types"] = torch.tensor(types_array_at).type(dtype)
302
+ # if element_types is not None:
303
+ # types_array_ele = np.zeros((len(element_types), 1))
304
+ # for i, name in enumerate(element_types):
305
+ # types_array_ele[i] = pc.ELE2NUM.get(name, pc.ELE2NUM["other"])
306
+ # property_dict["element_types"] = torch.tensor(types_array_ele).type(dtype)
307
+ # if residue_types is not None:
308
+ # unknown_name_idx = max(pc.AA_TO_INDEX.values()) + 1
309
+ # types_array_res = np.zeros((len(residue_types), 1))
310
+ # for i, name in enumerate(residue_types):
311
+ # types_array_res[i] = pc.AA_TO_INDEX.get(name, unknown_name_idx)
312
+ # property_dict["residue_types"] = torch.tensor(types_array_res).type(dtype)
313
+
314
+ # if atom_coordinates is not None:
315
+ # property_dict["atom_coordinates"] = torch.tensor(atom_coordinates, dtype=dtype)
316
+
317
+ # if residue_coordinates is not None:
318
+ # property_dict["residue_coordinates"] = torch.tensor(
319
+ # residue_coordinates, dtype=dtype
320
+ # )
321
+ # if residue_ids is not None:
322
+ # property_dict["residue_ids"] = torch.tensor(residue_ids, dtype=dtype)
323
+ # if chain_ids is not None:
324
+ # property_dict["chain_ids"] = torch.zeros(len(chain_ids), dtype=dtype)
325
+ # property_dict["chain_ids"][chain_ids == "L"] = 1
326
+ # return property_dict
327
+
328
+
329
+ # class NodeRepresentation(Enum):
330
+ # Surface = "surface"
331
+ # Atom = "atom"
332
+ # Residue = "residue"
333
+
334
+
335
+ # class PairedPDB(HeteroData): # type: ignore
336
+ # @classmethod
337
+ # def from_tuple_system(
338
+ # cls,
339
+
340
+ # tupal: tuple = (Structure , Structure , Structure),
341
+
342
+ # add_edges: bool = True,
343
+ # k: int = 10,
344
+
345
+ # ) -> PairedPDB:
346
+ # return cls.from_structure_pair(
347
+
348
+ # holo=tupal[0],
349
+ # apo=tupal[1],
350
+ # add_edges=add_edges,
351
+ # k=k,
352
+ # )
353
+
354
+ # @classmethod
355
+ # def from_structure_pair(
356
+ # cls,
357
+
358
+ # holo: Structure,
359
+ # apo: Structure,
360
+
361
+ # add_edges: bool = True,
362
+ # k: int = 10,
363
+ # ) -> PairedPDB:
364
+ # graph = cls()
365
+ # holo_calpha = holo.filter("atom_name", mask=["CA"])
366
+ # apo_calpha = apo.filter("atom_name", mask=["CA"])
367
+ # r_h = (holo.dataframe['chain_id'] == 'R').sum()
368
+ # r_a = (apo.dataframe['chain_id'] == 'R').sum()
369
+
370
+ # holo_r_props = structure2tensor(
371
+ # atom_coordinates=holo.coords[:r_h],
372
+ # atom_types=holo.atom_array.atom_name[:r_h],
373
+ # element_types=holo.atom_array.element[:r_h],
374
+ # residue_coordinates=holo_calpha.coords[:r_h],
375
+ # residue_types=holo_calpha.atom_array.res_name[:r_h],
376
+ # residue_ids=holo_calpha.atom_array.res_id[:r_h],
377
+ # )
378
+ # holo_l_props = structure2tensor(
379
+ # atom_coordinates=holo.coords[r_h:],
380
+
381
+ # atom_types=holo.atom_array.atom_name[r_h:],
382
+ # element_types=holo.atom_array.element[r_h:],
383
+ # residue_coordinates=holo_calpha.coords[r_h:],
384
+ # residue_types=holo_calpha.atom_array.res_name[r_h:],
385
+ # residue_ids=holo_calpha.atom_array.res_id[r_h:],
386
+ # )
387
+ # apo_r_props = structure2tensor(
388
+ # atom_coordinates=apo.coords[:r_a],
389
+ # atom_types=apo.atom_array.atom_name[:r_a],
390
+ # element_types=apo.atom_array.element[:r_a],
391
+ # residue_coordinates=apo_calpha.coords[:r_a],
392
+ # residue_types=apo_calpha.atom_array.res_name[:r_a],
393
+ # residue_ids=apo_calpha.atom_array.res_id[:r_a],
394
+ # )
395
+ # apo_l_props = structure2tensor(
396
+ # atom_coordinates=apo.coords[r_a:],
397
+ # atom_types=apo.atom_array.atom_name[r_a:],
398
+ # element_types=apo.atom_array.element[r_a:],
399
+ # residue_coordinates=apo_calpha.coords[r_a:],
400
+ # residue_types=apo_calpha.atom_array.res_name[r_a:],
401
+ # residue_ids=apo_calpha.atom_array.res_id[r_a:],
402
+ # )
403
+
404
+
405
+
406
+ # graph["ligand"].x = apo_l_props["atom_types"]
407
+ # graph["ligand"].pos = apo_l_props["atom_coordinates"]
408
+ # graph["receptor"].x = apo_r_props["atom_types"]
409
+ # graph["receptor"].pos = apo_r_props["atom_coordinates"]
410
+ # graph["ligand"].y = holo_l_props["atom_coordinates"]
411
+ # # graph["ligand"].pos = holo_l_props["atom_coordinates"]
412
+ # graph["receptor"].y = holo_r_props["atom_coordinates"]
413
+ # # graph["receptor"].pos = holo_r_props["atom_coordinates"]
414
+ # if add_edges and torch_cluster_installed:
415
+ # graph["ligand"].edge_index = knn_graph(
416
+ # graph["ligand"].pos, k=k
417
+ # )
418
+ # graph["receptor"].edge_index = knn_graph(
419
+ # graph["receptor"].pos, k=k
420
+ # )
421
+ # # graph["ligand"].edge_index = knn_graph(
422
+ # # graph["ligand"].pos, k=k
423
+ # # )
424
+ # # graph["receptor"].edge_index = knn_graph(
425
+ # # graph["receptor"].pos, k=k
426
+ # # )
427
+
428
+ # return graph
429
+
430
+ # index = get_index()
431
+ # # train = index[index.split == "train"].copy()
432
+ # # val = index[index.split == "val"].copy()
433
+ # # test = index[index.split == "test"].copy()
434
+ # # train_filtered = train[(train['apo_R'] == True) & (train['apo_L'] == True)].copy()
435
+ # # val_filtered = val[(val['apo_R'] == True) & (val['apo_L'] == True)].copy()
436
+ # # test_filtered = test[(test['apo_R'] == True) & (test['apo_L'] == True)].copy()
437
+
438
+ # # train_apo = [get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
439
+ # # monomer_types=["apo"], renumber_residues=True
440
+ # # ) for i in range(0, 10000)]
441
+
442
+ # # train_new_apo11 = [get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
443
+ # # monomer_types=["apo"], renumber_residues=True
444
+ # # ) for i in range(10000,10908)]
445
+
446
+ # # train_new_apo12 = [get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
447
+ # # # monomer_types=["apo"], renumber_residues=True
448
+ # # ) for i in range(10908,11816)]
449
+
450
+ # # val_new_apo1 = [get_system(val_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
451
+ # # monomer_types=["apo"], renumber_residues=True
452
+ # # ) for i in range(0,342)]
453
+
454
+ # # test_new_apo1 = [get_system(test_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
455
+ # # monomer_types=["apo"], renumber_residues=True
456
+ # # ) for i in range(0,342)]
457
+
458
+ # # val_apo = val_new_apo1 + train_new_apo11
459
+ # # test_apo = test_new_apo1 + train_new_apo12
460
+
461
+ # import pickle
462
+ # # with open("train_apo.pkl", "wb") as file:
463
+ # # pickle.dump(train_apo, file)
464
+
465
+ # # with open("val_apo.pkl", "wb") as file:
466
+ # # pickle.dump(val_apo, file)
467
+
468
+ # # with open("test_apo.pkl", "wb") as file:
469
+ # # pickle.dump(test_apo, file)
470
+ # with open("train_apo.pkl", "rb") as file:
471
+ # train_apo = pickle.load(file)
472
+
473
+ # with open("val_apo.pkl", "rb") as file:
474
+ # val_apo = pickle.load(file)
475
+
476
+ # with open("test_apo.pkl", "rb") as file:
477
+ # test_apo = pickle.load(file)
478
+
479
+ # # # %%
480
+ # train_geo = [PairedPDB.from_tuple_system(train_apo[i]) for i in range(0,len(train_apo))]
481
+ # val_geo = [PairedPDB.from_tuple_system(val_apo[i]) for i in range(0,len(val_apo))]
482
+ # test_geo = [PairedPDB.from_tuple_system(test_apo[i]) for i in range(0,len(test_apo))]
483
+ # # # %%
484
+ # # Train= []
485
+ # # for i in range(0,len(train_geo)):
486
+ # # data = HeteroData()
487
+ # # data["ligand"].x = train_geo[i]["ligand"].x
488
+ # # data['ligand'].y = train_geo[i]["ligand"].y
489
+ # # data["ligand"].pos = train_geo[i]["ligand"].pos
490
+ # # data["ligand","ligand"].edge_index = train_geo[i]["ligand"]
491
+ # # data["receptor"].x = train_geo[i]["receptor"].x
492
+ # # data['receptor'].y = train_geo[i]["receptor"].y
493
+ # # data["receptor"].pos = train_geo[i]["receptor"].pos
494
+ # # data["receptor","receptor"].edge_index = train_geo[i]["receptor"]
495
+ # # #torch.save(data, f"./data/processed/train_sample_{i}.pt")
496
+ # # Train.append(data)
497
+
498
+ # from torch_geometric.data import HeteroData
499
+ # import torch_sparse
500
+ # from torch_geometric.edge_index import to_sparse_tensor
501
+ # import torch
502
+
503
+ # # Example of converting edge indices to SparseTensor and storing them in HeteroData
504
+
505
+ # Train1 = []
506
+ # for i in range(len(train_geo)):
507
+ # data = HeteroData()
508
+ # # Define ligand node features
509
+ # data["ligand"].x = train_geo[i]["ligand"].x
510
+ # data["ligand"].y = train_geo[i]["ligand"].y
511
+ # data["ligand"].pos = train_geo[i]["ligand"].pos
512
+ # # Convert ligand edge index to SparseTensor
513
+ # ligand_edge_index = train_geo[i]["ligand"]["edge_index"]
514
+ # data["ligand", "ligand"].edge_index = to_sparse_tensor(ligand_edge_index, sparse_sizes=(train_geo[i]["ligand"].num_nodes,)*2)
515
+
516
+ # # Define receptor node features
517
+ # data["receptor"].x = train_geo[i]["receptor"].x
518
+ # data["receptor"].y = train_geo[i]["receptor"].y
519
+ # data["receptor"].pos = train_geo[i]["receptor"].pos
520
+ # # Convert receptor edge index to SparseTensor
521
+ # receptor_edge_index = train_geo[i]["receptor"]["edge_index"]
522
+ # data["receptor", "receptor"].edge_index = to_sparse_tensor(receptor_edge_index, sparse_sizes=(train_geo[i]["receptor"].num_nodes,)*2)
523
+
524
+ # Train1.append(data)
525
+
526
+
527
+ # # # %%
528
+ # # Val= []
529
+ # # for i in range(0,len(val_geo)):
530
+ # # data = HeteroData()
531
+ # # data["ligand"].x = val_geo[i]["ligand"].x
532
+ # # data['ligand'].y = val_geo[i]["ligand"].y
533
+ # # data["ligand"].pos = val_geo[i]["ligand"].pos
534
+ # # data["ligand","ligand"].edge_index = val_geo[i]["ligand"]
535
+ # # data["receptor"].x = val_geo[i]["receptor"].x
536
+ # # data['receptor'].y = val_geo[i]["receptor"].y
537
+ # # data["receptor"].pos = val_geo[i]["receptor"].pos
538
+ # # data["receptor","receptor"].edge_index = val_geo[i]["receptor"]
539
+ # # #torch.save(data, f"./data/processed/val_sample_{i}.pt")
540
+ # # Val.append(data)
541
+ # Val1 = []
542
+ # for i in range(len(val_geo)):
543
+ # data = HeteroData()
544
+ # # Define ligand node features
545
+ # data["ligand"].x = val_geo[i]["ligand"].x
546
+ # data["ligand"].y = val_geo[i]["ligand"].y
547
+ # data["ligand"].pos = val_geo[i]["ligand"].pos
548
+ # # Convert ligand edge index to SparseTensor
549
+ # ligand_edge_index = val_geo[i]["ligand"]["edge_index"]
550
+ # data["ligand", "ligand"].edge_index = to_sparse_tensor(ligand_edge_index, sparse_sizes=(val_geo[i]["ligand"].num_nodes,)*2)
551
+
552
+ # # Define receptor node features
553
+ # data["receptor"].x = val_geo[i]["receptor"].x
554
+ # data["receptor"].y = val_geo[i]["receptor"].y
555
+ # data["receptor"].pos = val_geo[i]["receptor"].pos
556
+ # # Convert receptor edge index to SparseTensor
557
+ # receptor_edge_index = val_geo[i]["receptor"]["edge_index"]
558
+ # data["receptor", "receptor"].edge_index = to_sparse_tensor(receptor_edge_index, sparse_sizes=(val_geo[i]["receptor"].num_nodes,)*2)
559
+
560
+ # Val1.append(data)
561
+ # # # %%
562
+ # # Test= []
563
+ # # for i in range(0,len(test_geo)):
564
+ # # data = HeteroData()
565
+ # # data["ligand"].x = test_geo[i]["ligand"].x
566
+ # # data['ligand'].y = test_geo[i]["ligand"].y
567
+ # # data["ligand"].pos = test_geo[i]["ligand"].pos
568
+ # # data["ligand","ligand"].edge_index = test_geo[i]["ligand"]
569
+ # # data["receptor"].x = test_geo[i]["receptor"].x
570
+ # # data['receptor'].y = test_geo[i]["receptor"].y
571
+ # # data["receptor"].pos = test_geo[i]["receptor"].pos
572
+ # # data["receptor","receptor"].edge_index = test_geo[i]["receptor"]
573
+ # # #torch.save(data, f"./data/processed/test_sample_{i}.pt")
574
+ # # Test.append(data)
575
+ # Test1 = []
576
+ # for i in range(len(test_geo)):
577
+ # data = HeteroData()
578
+ # # Define ligand node features
579
+ # data["ligand"].x = test_geo[i]["ligand"].x
580
+ # data["ligand"].y = test_geo[i]["ligand"].y
581
+ # data["ligand"].pos = test_geo[i]["ligand"].pos
582
+ # # Convert ligand edge index to SparseTensor
583
+ # ligand_edge_index = test_geo[i]["ligand"]["edge_index"]
584
+ # data["ligand", "ligand"].edge_index = to_sparse_tensor(ligand_edge_index, sparse_sizes=(test_geo[i]["ligand"].num_nodes,)*2)
585
+
586
+ # # Define receptor node features
587
+ # data["receptor"].x = test_geo[i]["receptor"].x
588
+ # data["receptor"].y = test_geo[i]["receptor"].y
589
+ # data["receptor"].pos = test_geo[i]["receptor"].pos
590
+ # # Convert receptor edge index to SparseTensor
591
+ # receptor_edge_index = test_geo[i]["receptor"]["edge_index"]
592
+ # data["receptor", "receptor"].edge_index = to_sparse_tensor(receptor_edge_index, sparse_sizes=(test_geo[i]["receptor"].num_nodes,)*2)
593
+
594
+ # Test1.append(data)
595
+ # # with open("Train.pkl", "wb") as file:
596
+ # # pickle.dump(Train, file)
597
+
598
+ # # with open("Val.pkl", "wb") as file:
599
+ # # pickle.dump(Val, file)
600
+
601
+ # # with open("Test.pkl", "wb") as file:
602
+ # # pickle.dump(Test, file)
603
+
604
+ # # with open("Train1.pkl", "rb") as file:
605
+ # # Train= pickle.load(file)
606
+
607
+ # # with open("Val.pkl", "rb") as file:
608
+ # # Val = pickle.load(file)
609
+
610
+ # # with open("Test.pkl", "rb") as file:
611
+ # # Test = pickle.load(file)