Sukanyaaa committed
Commit 88ce25d · 1 Parent(s): 1405a8b

fix train.py

Files changed (1):
  1. train.py (+153 -312)
train.py CHANGED
@@ -105,6 +105,9 @@ def get_system(system_id: str) -> PinderSystem:
from Bio import PDB
from Bio.PDB.PDBIO import PDBIO

+
+ # To create the dataset, we used only the PINDER dataset, following the steps below:
+
log = setup_logger(__name__)

try:
@@ -265,208 +268,46 @@ class PairedPDB(HeteroData): # type: ignore

return graph

- # To create dataset, we have used only PINDER datyaset with following steps as follows:
-
- # log = setup_logger(__name__)
-
- # try:
- # from torch_cluster import knn_graph
-
- # torch_cluster_installed = True
- # except ImportError as e:
- # log.warning(
- # "torch-cluster is not installed!"
- # "Please install the appropriate library for your pytorch installation."
- # "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
- # )
- # torch_cluster_installed = False
-
-
- # def structure2tensor(
- # atom_coordinates: NDArray[np.double] | None = None,
- # atom_types: NDArray[np.str_] | None = None,
- # element_types: NDArray[np.str_] | None = None,
- # residue_coordinates: NDArray[np.double] | None = None,
- # residue_ids: NDArray[np.int_] | None = None,
- # residue_types: NDArray[np.str_] | None = None,
- # chain_ids: NDArray[np.str_] | None = None,
- # dtype: torch.dtype = torch.float32,
- # ) -> dict[str, torch.Tensor]:
- # property_dict = {}
- # if atom_types is not None:
- # unknown_name_idx = max(pc.ALL_ATOM_POSNS.values()) + 1
- # types_array_at = np.zeros((len(atom_types), 1))
- # for i, name in enumerate(atom_types):
- # types_array_at[i] = pc.ALL_ATOM_POSNS.get(name, unknown_name_idx)
- # property_dict["atom_types"] = torch.tensor(types_array_at).type(dtype)
- # if element_types is not None:
- # types_array_ele = np.zeros((len(element_types), 1))
- # for i, name in enumerate(element_types):
- # types_array_ele[i] = pc.ELE2NUM.get(name, pc.ELE2NUM["other"])
- # property_dict["element_types"] = torch.tensor(types_array_ele).type(dtype)
- # if residue_types is not None:
- # unknown_name_idx = max(pc.AA_TO_INDEX.values()) + 1
- # types_array_res = np.zeros((len(residue_types), 1))
- # for i, name in enumerate(residue_types):
- # types_array_res[i] = pc.AA_TO_INDEX.get(name, unknown_name_idx)
- # property_dict["residue_types"] = torch.tensor(types_array_res).type(dtype)
-
- # if atom_coordinates is not None:
- # property_dict["atom_coordinates"] = torch.tensor(atom_coordinates, dtype=dtype)
-
- # if residue_coordinates is not None:
- # property_dict["residue_coordinates"] = torch.tensor(
- # residue_coordinates, dtype=dtype
- # )
- # if residue_ids is not None:
- # property_dict["residue_ids"] = torch.tensor(residue_ids, dtype=dtype)
- # if chain_ids is not None:
- # property_dict["chain_ids"] = torch.zeros(len(chain_ids), dtype=dtype)
- # property_dict["chain_ids"][chain_ids == "L"] = 1
- # return property_dict
-
-
- # class NodeRepresentation(Enum):
- # Surface = "surface"
- # Atom = "atom"
- # Residue = "residue"
-
-
- # class PairedPDB(HeteroData): # type: ignore
- # @classmethod
- # def from_tuple_system(
- # cls,
-
- # tupal: tuple = (Structure , Structure , Structure),
-
- # add_edges: bool = True,
- # k: int = 10,
-
- # ) -> PairedPDB:
- # return cls.from_structure_pair(
-
- # holo=tupal[0],
- # apo=tupal[1],
- # add_edges=add_edges,
- # k=k,
- # )
-
- # @classmethod
- # def from_structure_pair(
- # cls,
-
- # holo: Structure,
- # apo: Structure,
-
- # add_edges: bool = True,
- # k: int = 10,
- # ) -> PairedPDB:
- # graph = cls()
- # holo_calpha = holo.filter("atom_name", mask=["CA"])
- # apo_calpha = apo.filter("atom_name", mask=["CA"])
- # r_h = (holo.dataframe['chain_id'] == 'R').sum()
- # r_a = (apo.dataframe['chain_id'] == 'R').sum()
-
- # holo_r_props = structure2tensor(
- # atom_coordinates=holo.coords[:r_h],
- # atom_types=holo.atom_array.atom_name[:r_h],
- # element_types=holo.atom_array.element[:r_h],
- # residue_coordinates=holo_calpha.coords[:r_h],
- # residue_types=holo_calpha.atom_array.res_name[:r_h],
- # residue_ids=holo_calpha.atom_array.res_id[:r_h],
- # )
- # holo_l_props = structure2tensor(
- # atom_coordinates=holo.coords[r_h:],
-
- # atom_types=holo.atom_array.atom_name[r_h:],
- # element_types=holo.atom_array.element[r_h:],
- # residue_coordinates=holo_calpha.coords[r_h:],
- # residue_types=holo_calpha.atom_array.res_name[r_h:],
- # residue_ids=holo_calpha.atom_array.res_id[r_h:],
- # )
- # apo_r_props = structure2tensor(
- # atom_coordinates=apo.coords[:r_a],
- # atom_types=apo.atom_array.atom_name[:r_a],
- # element_types=apo.atom_array.element[:r_a],
- # residue_coordinates=apo_calpha.coords[:r_a],
- # residue_types=apo_calpha.atom_array.res_name[:r_a],
- # residue_ids=apo_calpha.atom_array.res_id[:r_a],
- # )
- # apo_l_props = structure2tensor(
- # atom_coordinates=apo.coords[r_a:],
- # atom_types=apo.atom_array.atom_name[r_a:],
- # element_types=apo.atom_array.element[r_a:],
- # residue_coordinates=apo_calpha.coords[r_a:],
- # residue_types=apo_calpha.atom_array.res_name[r_a:],
- # residue_ids=apo_calpha.atom_array.res_id[r_a:],
- # )
-
-
-
- # graph["ligand"].x = apo_l_props["atom_types"]
- # graph["ligand"].pos = apo_l_props["atom_coordinates"]
- # graph["receptor"].x = apo_r_props["atom_types"]
- # graph["receptor"].pos = apo_r_props["atom_coordinates"]
- # graph["ligand"].y = holo_l_props["atom_coordinates"]
- # # graph["ligand"].pos = holo_l_props["atom_coordinates"]
- # graph["receptor"].y = holo_r_props["atom_coordinates"]
- # # graph["receptor"].pos = holo_r_props["atom_coordinates"]
- # if add_edges and torch_cluster_installed:
- # graph["ligand"].edge_index = knn_graph(
- # graph["ligand"].pos, k=k
- # )
- # graph["receptor"].edge_index = knn_graph(
- # graph["receptor"].pos, k=k
- # )
- # # graph["ligand"].edge_index = knn_graph(
- # # graph["ligand"].pos, k=k
- # # )
- # # graph["receptor"].edge_index = knn_graph(
- # # graph["receptor"].pos, k=k
- # # )
-
- # return graph
-
- # index = get_index()
- # # train = index[index.split == "train"].copy()
- # # val = index[index.split == "val"].copy()
- # # test = index[index.split == "test"].copy()
- # # train_filtered = train[(train['apo_R'] == True) & (train['apo_L'] == True)].copy()
- # # val_filtered = val[(val['apo_R'] == True) & (val['apo_L'] == True)].copy()
- # # test_filtered = test[(test['apo_R'] == True) & (test['apo_L'] == True)].copy()
-
- # # train_apo = [get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
- # # monomer_types=["apo"], renumber_residues=True
- # # ) for i in range(0, 10000)]
-
- # # train_new_apo11 = [get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
- # # monomer_types=["apo"], renumber_residues=True
- # # ) for i in range(10000,10908)]
-
- # # train_new_apo12 = [get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
- # # # monomer_types=["apo"], renumber_residues=True
- # # ) for i in range(10908,11816)]
-
- # # val_new_apo1 = [get_system(val_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
- # # monomer_types=["apo"], renumber_residues=True
- # # ) for i in range(0,342)]
-
- # # test_new_apo1 = [get_system(test_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
- # # monomer_types=["apo"], renumber_residues=True
- # # ) for i in range(0,342)]
-
- # # val_apo = val_new_apo1 + train_new_apo11
- # # test_apo = test_new_apo1 + train_new_apo12
-
- # import pickle
- # # with open("train_apo.pkl", "wb") as file:
- # # pickle.dump(train_apo, file)
-
- # # with open("val_apo.pkl", "wb") as file:
- # # pickle.dump(val_apo, file)
-
- # # with open("test_apo.pkl", "wb") as file:
- # # pickle.dump(test_apo, file)
+ index = get_index()
+ train = index[index.split == "train"].copy()
+ val = index[index.split == "val"].copy()
+ test = index[index.split == "test"].copy()
+ train_filtered = train[(train['apo_R'] == True) & (train['apo_L'] == True)].copy()
+ val_filtered = val[(val['apo_R'] == True) & (val['apo_L'] == True)].copy()
+ test_filtered = test[(test['apo_R'] == True) & (test['apo_L'] == True)].copy()
+
+ train_apo = [get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
+     monomer_types=["apo"], renumber_residues=True
+ ) for i in range(0, 10000)]
+
+ train_new_apo11 = [get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
+     monomer_types=["apo"], renumber_residues=True
+ ) for i in range(10000,10908)]
+
+ train_new_apo12 = [get_system(train_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
+     # monomer_types=["apo"], renumber_residues=True
+ ) for i in range(10908,11816)]
+
+ val_new_apo1 = [get_system(val_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
+     monomer_types=["apo"], renumber_residues=True
+ ) for i in range(0,342)]
+
+ test_new_apo1 = [get_system(test_filtered.id.iloc[i]).create_masked_bound_unbound_complexes(
+     monomer_types=["apo"], renumber_residues=True
+ ) for i in range(0,342)]
+
+ val_apo = val_new_apo1 + train_new_apo11
+ test_apo = test_new_apo1 + train_new_apo12
+
+ import pickle
+ # with open("train_apo.pkl", "wb") as file:
+ # pickle.dump(train_apo, file)
+
+ # with open("val_apo.pkl", "wb") as file:
+ # pickle.dump(val_apo, file)
+
+ # with open("test_apo.pkl", "wb") as file:
+ # pickle.dump(test_apo, file)
# with open("train_apo.pkl", "rb") as file:
# train_apo = pickle.load(file)

@@ -476,136 +317,136 @@ class PairedPDB(HeteroData): # type: ignore
# with open("test_apo.pkl", "rb") as file:
# test_apo = pickle.load(file)

- # # # %%
- # train_geo = [PairedPDB.from_tuple_system(train_apo[i]) for i in range(0,len(train_apo))]
- # val_geo = [PairedPDB.from_tuple_system(val_apo[i]) for i in range(0,len(val_apo))]
- # test_geo = [PairedPDB.from_tuple_system(test_apo[i]) for i in range(0,len(test_apo))]
- # # # %%
- # # Train= []
- # # for i in range(0,len(train_geo)):
- # # data = HeteroData()
- # # data["ligand"].x = train_geo[i]["ligand"].x
- # # data['ligand'].y = train_geo[i]["ligand"].y
- # # data["ligand"].pos = train_geo[i]["ligand"].pos
- # # data["ligand","ligand"].edge_index = train_geo[i]["ligand"]
- # # data["receptor"].x = train_geo[i]["receptor"].x
- # # data['receptor'].y = train_geo[i]["receptor"].y
- # # data["receptor"].pos = train_geo[i]["receptor"].pos
- # # data["receptor","receptor"].edge_index = train_geo[i]["receptor"]
- # # #torch.save(data, f"./data/processed/train_sample_{i}.pt")
- # # Train.append(data)
-
- # from torch_geometric.data import HeteroData
- # import torch_sparse
- # from torch_geometric.edge_index import to_sparse_tensor
- # import torch
-
- # # Example of converting edge indices to SparseTensor and storing them in HeteroData
-
- # Train1 = []
- # for i in range(len(train_geo)):
- # data = HeteroData()
- # # Define ligand node features
- # data["ligand"].x = train_geo[i]["ligand"].x
- # data["ligand"].y = train_geo[i]["ligand"].y
- # data["ligand"].pos = train_geo[i]["ligand"].pos
- # # Convert ligand edge index to SparseTensor
- # ligand_edge_index = train_geo[i]["ligand"]["edge_index"]
- # data["ligand", "ligand"].edge_index = to_sparse_tensor(ligand_edge_index, sparse_sizes=(train_geo[i]["ligand"].num_nodes,)*2)
-
- # # Define receptor node features
- # data["receptor"].x = train_geo[i]["receptor"].x
- # data["receptor"].y = train_geo[i]["receptor"].y
- # data["receptor"].pos = train_geo[i]["receptor"].pos
- # # Convert receptor edge index to SparseTensor
- # receptor_edge_index = train_geo[i]["receptor"]["edge_index"]
- # data["receptor", "receptor"].edge_index = to_sparse_tensor(receptor_edge_index, sparse_sizes=(train_geo[i]["receptor"].num_nodes,)*2)
-
- # Train1.append(data)
-
-
- # # # %%
- # # Val= []
- # # for i in range(0,len(val_geo)):
- # # data = HeteroData()
- # # data["ligand"].x = val_geo[i]["ligand"].x
- # # data['ligand'].y = val_geo[i]["ligand"].y
- # # data["ligand"].pos = val_geo[i]["ligand"].pos
- # # data["ligand","ligand"].edge_index = val_geo[i]["ligand"]
- # # data["receptor"].x = val_geo[i]["receptor"].x
- # # data['receptor'].y = val_geo[i]["receptor"].y
- # # data["receptor"].pos = val_geo[i]["receptor"].pos
- # # data["receptor","receptor"].edge_index = val_geo[i]["receptor"]
- # # #torch.save(data, f"./data/processed/val_sample_{i}.pt")
- # # Val.append(data)
- # Val1 = []
- # for i in range(len(val_geo)):
- # data = HeteroData()
- # # Define ligand node features
- # data["ligand"].x = val_geo[i]["ligand"].x
- # data["ligand"].y = val_geo[i]["ligand"].y
- # data["ligand"].pos = val_geo[i]["ligand"].pos
- # # Convert ligand edge index to SparseTensor
- # ligand_edge_index = val_geo[i]["ligand"]["edge_index"]
- # data["ligand", "ligand"].edge_index = to_sparse_tensor(ligand_edge_index, sparse_sizes=(val_geo[i]["ligand"].num_nodes,)*2)
-
- # # Define receptor node features
- # data["receptor"].x = val_geo[i]["receptor"].x
- # data["receptor"].y = val_geo[i]["receptor"].y
- # data["receptor"].pos = val_geo[i]["receptor"].pos
- # # Convert receptor edge index to SparseTensor
- # receptor_edge_index = val_geo[i]["receptor"]["edge_index"]
- # data["receptor", "receptor"].edge_index = to_sparse_tensor(receptor_edge_index, sparse_sizes=(val_geo[i]["receptor"].num_nodes,)*2)
-
- # Val1.append(data)
- # # # %%
- # # Test= []
- # # for i in range(0,len(test_geo)):
- # # data = HeteroData()
- # # data["ligand"].x = test_geo[i]["ligand"].x
- # # data['ligand'].y = test_geo[i]["ligand"].y
- # # data["ligand"].pos = test_geo[i]["ligand"].pos
- # # data["ligand","ligand"].edge_index = test_geo[i]["ligand"]
- # # data["receptor"].x = test_geo[i]["receptor"].x
- # # data['receptor'].y = test_geo[i]["receptor"].y
- # # data["receptor"].pos = test_geo[i]["receptor"].pos
- # # data["receptor","receptor"].edge_index = test_geo[i]["receptor"]
- # # #torch.save(data, f"./data/processed/test_sample_{i}.pt")
- # # Test.append(data)
- # Test1 = []
- # for i in range(len(test_geo)):
- # data = HeteroData()
- # # Define ligand node features
- # data["ligand"].x = test_geo[i]["ligand"].x
- # data["ligand"].y = test_geo[i]["ligand"].y
- # data["ligand"].pos = test_geo[i]["ligand"].pos
- # # Convert ligand edge index to SparseTensor
- # ligand_edge_index = test_geo[i]["ligand"]["edge_index"]
- # data["ligand", "ligand"].edge_index = to_sparse_tensor(ligand_edge_index, sparse_sizes=(test_geo[i]["ligand"].num_nodes,)*2)
-
- # # Define receptor node features
- # data["receptor"].x = test_geo[i]["receptor"].x
- # data["receptor"].y = test_geo[i]["receptor"].y
- # data["receptor"].pos = test_geo[i]["receptor"].pos
- # # Convert receptor edge index to SparseTensor
- # receptor_edge_index = test_geo[i]["receptor"]["edge_index"]
- # data["receptor", "receptor"].edge_index = to_sparse_tensor(receptor_edge_index, sparse_sizes=(test_geo[i]["receptor"].num_nodes,)*2)
-
- # Test1.append(data)
- # # with open("Train.pkl", "wb") as file:
- # # pickle.dump(Train, file)
-
- # # with open("Val.pkl", "wb") as file:
- # # pickle.dump(Val, file)
-
- # # with open("Test.pkl", "wb") as file:
- # # pickle.dump(Test, file)
-
- # # with open("Train1.pkl", "rb") as file:
- # # Train= pickle.load(file)
-
- # # with open("Val.pkl", "rb") as file:
- # # Val = pickle.load(file)
-
- # # with open("Test.pkl", "rb") as file:
- # # Test = pickle.load(file)
+ # # %%
+ train_geo = [PairedPDB.from_tuple_system(train_apo[i]) for i in range(0,len(train_apo))]
+ val_geo = [PairedPDB.from_tuple_system(val_apo[i]) for i in range(0,len(val_apo))]
+ test_geo = [PairedPDB.from_tuple_system(test_apo[i]) for i in range(0,len(test_apo))]
+ # # %%
+ # Train= []
+ # for i in range(0,len(train_geo)):
+ # data = HeteroData()
+ # data["ligand"].x = train_geo[i]["ligand"].x
+ # data['ligand'].y = train_geo[i]["ligand"].y
+ # data["ligand"].pos = train_geo[i]["ligand"].pos
+ # data["ligand","ligand"].edge_index = train_geo[i]["ligand"]
+ # data["receptor"].x = train_geo[i]["receptor"].x
+ # data['receptor'].y = train_geo[i]["receptor"].y
+ # data["receptor"].pos = train_geo[i]["receptor"].pos
+ # data["receptor","receptor"].edge_index = train_geo[i]["receptor"]
+ # #torch.save(data, f"./data/processed/train_sample_{i}.pt")
+ # Train.append(data)
+
+ from torch_geometric.data import HeteroData
+ import torch_sparse
+ from torch_geometric.edge_index import to_sparse_tensor
+ import torch
+
+ # Example of converting edge indices to SparseTensor and storing them in HeteroData
+
+ Train1 = []
+ for i in range(len(train_geo)):
+     data = HeteroData()
+     # Define ligand node features
+     data["ligand"].x = train_geo[i]["ligand"].x
+     data["ligand"].y = train_geo[i]["ligand"].y
+     data["ligand"].pos = train_geo[i]["ligand"].pos
+     # Convert ligand edge index to SparseTensor
+     ligand_edge_index = train_geo[i]["ligand"]["edge_index"]
+     data["ligand", "ligand"].edge_index = to_sparse_tensor(ligand_edge_index, sparse_sizes=(train_geo[i]["ligand"].num_nodes,)*2)
+
+     # Define receptor node features
+     data["receptor"].x = train_geo[i]["receptor"].x
+     data["receptor"].y = train_geo[i]["receptor"].y
+     data["receptor"].pos = train_geo[i]["receptor"].pos
+     # Convert receptor edge index to SparseTensor
+     receptor_edge_index = train_geo[i]["receptor"]["edge_index"]
+     data["receptor", "receptor"].edge_index = to_sparse_tensor(receptor_edge_index, sparse_sizes=(train_geo[i]["receptor"].num_nodes,)*2)
+
+     Train1.append(data)
+
+
+ # # %%
+ # Val= []
+ # for i in range(0,len(val_geo)):
+ # data = HeteroData()
+ # data["ligand"].x = val_geo[i]["ligand"].x
+ # data['ligand'].y = val_geo[i]["ligand"].y
+ # data["ligand"].pos = val_geo[i]["ligand"].pos
+ # data["ligand","ligand"].edge_index = val_geo[i]["ligand"]
+ # data["receptor"].x = val_geo[i]["receptor"].x
+ # data['receptor'].y = val_geo[i]["receptor"].y
+ # data["receptor"].pos = val_geo[i]["receptor"].pos
+ # data["receptor","receptor"].edge_index = val_geo[i]["receptor"]
+ # #torch.save(data, f"./data/processed/val_sample_{i}.pt")
+ # Val.append(data)
+ Val1 = []
+ for i in range(len(val_geo)):
+     data = HeteroData()
+     # Define ligand node features
+     data["ligand"].x = val_geo[i]["ligand"].x
+     data["ligand"].y = val_geo[i]["ligand"].y
+     data["ligand"].pos = val_geo[i]["ligand"].pos
+     # Convert ligand edge index to SparseTensor
+     ligand_edge_index = val_geo[i]["ligand"]["edge_index"]
+     data["ligand", "ligand"].edge_index = to_sparse_tensor(ligand_edge_index, sparse_sizes=(val_geo[i]["ligand"].num_nodes,)*2)
+
+     # Define receptor node features
+     data["receptor"].x = val_geo[i]["receptor"].x
+     data["receptor"].y = val_geo[i]["receptor"].y
+     data["receptor"].pos = val_geo[i]["receptor"].pos
+     # Convert receptor edge index to SparseTensor
+     receptor_edge_index = val_geo[i]["receptor"]["edge_index"]
+     data["receptor", "receptor"].edge_index = to_sparse_tensor(receptor_edge_index, sparse_sizes=(val_geo[i]["receptor"].num_nodes,)*2)
+
+     Val1.append(data)
+ # # %%
+ # Test= []
+ # for i in range(0,len(test_geo)):
+ # data = HeteroData()
+ # data["ligand"].x = test_geo[i]["ligand"].x
+ # data['ligand'].y = test_geo[i]["ligand"].y
+ # data["ligand"].pos = test_geo[i]["ligand"].pos
+ # data["ligand","ligand"].edge_index = test_geo[i]["ligand"]
+ # data["receptor"].x = test_geo[i]["receptor"].x
+ # data['receptor'].y = test_geo[i]["receptor"].y
+ # data["receptor"].pos = test_geo[i]["receptor"].pos
+ # data["receptor","receptor"].edge_index = test_geo[i]["receptor"]
+ # #torch.save(data, f"./data/processed/test_sample_{i}.pt")
+ # Test.append(data)
+ Test1 = []
+ for i in range(len(test_geo)):
+     data = HeteroData()
+     # Define ligand node features
+     data["ligand"].x = test_geo[i]["ligand"].x
+     data["ligand"].y = test_geo[i]["ligand"].y
+     data["ligand"].pos = test_geo[i]["ligand"].pos
+     # Convert ligand edge index to SparseTensor
+     ligand_edge_index = test_geo[i]["ligand"]["edge_index"]
+     data["ligand", "ligand"].edge_index = to_sparse_tensor(ligand_edge_index, sparse_sizes=(test_geo[i]["ligand"].num_nodes,)*2)
+
+     # Define receptor node features
+     data["receptor"].x = test_geo[i]["receptor"].x
+     data["receptor"].y = test_geo[i]["receptor"].y
+     data["receptor"].pos = test_geo[i]["receptor"].pos
+     # Convert receptor edge index to SparseTensor
+     receptor_edge_index = test_geo[i]["receptor"]["edge_index"]
+     data["receptor", "receptor"].edge_index = to_sparse_tensor(receptor_edge_index, sparse_sizes=(test_geo[i]["receptor"].num_nodes,)*2)
+
+     Test1.append(data)
+ # with open("Train.pkl", "wb") as file:
+ # pickle.dump(Train, file)
+
+ # with open("Val.pkl", "wb") as file:
+ # pickle.dump(Val, file)
+
+ # with open("Test.pkl", "wb") as file:
+ # pickle.dump(Test, file)
+
+ # with open("Train1.pkl", "rb") as file:
+ # Train= pickle.load(file)
+
+ # with open("Val.pkl", "rb") as file:
+ # Val = pickle.load(file)
+
+ # with open("Test.pkl", "rb") as file:
+ # Test = pickle.load(file)
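Note on the added conversion loops: they import `to_sparse_tensor` from `torch_geometric.edge_index`, and whether that helper exists depends on the installed PyTorch Geometric version. As a hedged sketch only (not part of train.py), the same conversion can be written with `torch_sparse.SparseTensor.from_edge_index`; the helper name `to_hetero_sparse` below is hypothetical, and `g` is assumed to be one of the `PairedPDB` graphs whose node stores already carry a kNN `edge_index`.

# Hypothetical helper, not part of the commit: mirrors the Train1/Val1/Test1 loops
# using torch_sparse.SparseTensor.from_edge_index instead of to_sparse_tensor.
from torch_geometric.data import HeteroData
from torch_sparse import SparseTensor

def to_hetero_sparse(g):
    data = HeteroData()
    for node_type in ("ligand", "receptor"):
        data[node_type].x = g[node_type].x
        data[node_type].y = g[node_type].y
        data[node_type].pos = g[node_type].pos
        n = g[node_type].num_nodes
        # Build an n x n sparse adjacency from the [2, num_edges] kNN edge_index.
        data[node_type, node_type].edge_index = SparseTensor.from_edge_index(
            g[node_type].edge_index, sparse_sizes=(n, n)
        )
    return data

Under those assumptions, `Train1 = [to_hetero_sparse(g) for g in train_geo]` would reproduce the first of the three loops above.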
 
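Once built, `Train1`, `Val1`, and `Test1` are plain Python lists of `HeteroData` objects. A minimal usage sketch, assuming the lists fit in memory and that the stored adjacencies collate under PyTorch Geometric's default settings (the batch size here is arbitrary):

# Sketch only: wrap the processed lists in PyG DataLoaders for a training loop.
from torch_geometric.loader import DataLoader

train_loader = DataLoader(Train1, batch_size=2, shuffle=True)
val_loader = DataLoader(Val1, batch_size=2)

for batch in train_loader:
    # Each batch is a merged HeteroData with "ligand" and "receptor" node stores.
    print(batch["ligand"].x.shape, batch["receptor"].pos.shape)
    break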