osbm commited on
Commit
6a59579
·
1 Parent(s): 07759b4

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +101 -74
trainer.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  from utils import *
7
  from models import Generator, Generator2, simple_disc
8
  import torch_geometric.utils as geoutils
9
- #import #wandb
10
  import re
11
  from torch_geometric.loader import DataLoader
12
  from new_dataloader import DruggenDataset
@@ -19,7 +19,7 @@ RDLogger.DisableLog('rdApp.*')
19
  from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
20
  from training_data import load_data
21
  import random
22
-
23
 
24
  class Trainer(object):
25
 
@@ -27,6 +27,19 @@ class Trainer(object):
27
 
28
  def __init__(self, config):
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
31
  """Initialize configurations."""
32
  self.submodel = config.submodel
@@ -57,7 +70,10 @@ class Trainer(object):
57
 
58
  self.inf_drugs_dataset_file = config.inf_drug_dataset_file # Drug dataset file name for the second GAN.
59
  # Contains drug molecules only. (In this case AKT1 inhibitors.)
60
-
 
 
 
61
  self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
62
 
63
  self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
@@ -219,6 +235,14 @@ class Trainer(object):
219
  self.clipping_value = config.clipping_value
220
  # Miscellaneous.
221
 
 
 
 
 
 
 
 
 
222
  self.mode = config.mode
223
 
224
  self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
@@ -398,7 +422,7 @@ class Trainer(object):
398
 
399
  ''' Loading the atom and bond decoders'''
400
 
401
- with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
402
 
403
  return pickle.load(f)
404
 
@@ -406,7 +430,7 @@ class Trainer(object):
406
 
407
  ''' Loading the atom and bond decoders'''
408
 
409
- with open("data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
410
 
411
  return pickle.load(f)
412
 
@@ -429,17 +453,17 @@ class Trainer(object):
429
  print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
430
 
431
  G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
432
- #D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
433
 
434
  self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
435
- #self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
436
 
437
 
438
  G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
439
- #D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(epoch, iteration))
440
 
441
  self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
442
- #self.D2.load_state_dict(torch.load(D2_path, map_location=lambda storage, loc: storage))
443
 
444
 
445
  def save_model(self, model_directory, idx,i):
@@ -507,16 +531,19 @@ class Trainer(object):
507
 
508
 
509
  # protein data
510
- full_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
511
- drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
512
 
513
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
514
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
515
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
516
 
517
- akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
518
- akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
519
-
 
 
 
520
  # Start training.
521
 
522
  print('Start training...')
@@ -577,8 +604,8 @@ class Trainer(object):
577
  GAN2_disc_e = drugs_a_tensor
578
  GAN2_disc_x = drugs_x_tensor
579
  elif self.submodel == "RL":
580
- GAN1_input_e = z_edge
581
- GAN1_input_x = z_node
582
  GAN1_disc_e = a_tensor
583
  GAN1_disc_x = x_tensor
584
  GAN2_input_e = drugs_a_tensor
@@ -586,8 +613,8 @@ class Trainer(object):
586
  GAN2_disc_e = drugs_a_tensor
587
  GAN2_disc_x = drugs_x_tensor
588
  elif self.submodel == "NoTarget":
589
- GAN1_input_e = z_edge
590
- GAN1_input_x = z_node
591
  GAN1_disc_e = a_tensor
592
  GAN1_disc_x = x_tensor
593
 
@@ -639,9 +666,10 @@ class Trainer(object):
639
  GAN1_input_x,
640
  self.batch_size,
641
  sim_reward,
642
- self.dataset.matrices2mol_drugs,
643
  fps_r,
644
- self.submodel)
 
645
 
646
  g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
647
 
@@ -659,7 +687,8 @@ class Trainer(object):
659
  fps_r,
660
  GAN2_input_e,
661
  GAN2_input_x,
662
- self.submodel)
 
663
 
664
  g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
665
 
@@ -695,31 +724,31 @@ class Trainer(object):
695
 
696
  # Load the trained generator.
697
  self.G.to(self.device)
698
- #self.D.to(self.device)
699
  self.G2.to(self.device)
700
- #self.D2.to(self.device)
701
 
702
  G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
703
  self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
704
- G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
705
- self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
 
706
 
707
- print(G_path)
708
- drug_smiles = [line for line in open("data/akt_test.smi", 'r').read().splitlines()]
709
 
710
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
711
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
712
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
713
 
714
- akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
715
- akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
716
 
717
  self.G.eval()
718
  #self.D.eval()
719
  self.G2.eval()
720
  #self.D2.eval()
 
 
721
 
722
- self.inf_batch_size =256
723
  self.inf_dataset = DruggenDataset(self.mol_data_dir,
724
  self.inf_dataset_file,
725
  self.inf_raw_file,
@@ -753,24 +782,25 @@ class Trainer(object):
753
  #metric_calc_mol = []
754
  metric_calc_dr = []
755
  date = time.time()
756
- if not os.path.exists("experiments/inference/{}".format(self.submodel)):
757
- os.makedirs("experiments/inference/{}".format(self.submodel))
758
  with torch.inference_mode():
759
 
760
- dataloader_iterator = iter(self.drugs_loader)
761
-
762
- for i, data in enumerate(self.loader):
 
763
  try:
764
  drugs = next(dataloader_iterator)
765
  except StopIteration:
766
- dataloader_iterator = iter(self.drugs_loader)
767
  drugs = next(dataloader_iterator)
768
 
769
  # Preprocess both dataset
770
 
771
  bulk_data = load_data(data,
772
  drugs,
773
- self.batch_size,
774
  self.device,
775
  self.b_dim,
776
  self.m_dim,
@@ -809,8 +839,8 @@ class Trainer(object):
809
  GAN2_disc_e = drugs_a_tensor
810
  GAN2_disc_x = drugs_x_tensor
811
  elif self.submodel == "RL":
812
- GAN1_input_e = z_edge
813
- GAN1_input_x = z_node
814
  GAN1_disc_e = a_tensor
815
  GAN1_disc_x = x_tensor
816
  GAN2_input_e = drugs_a_tensor
@@ -818,8 +848,8 @@ class Trainer(object):
818
  GAN2_disc_e = drugs_a_tensor
819
  GAN2_disc_x = drugs_x_tensor
820
  elif self.submodel == "NoTarget":
821
- GAN1_input_e = z_edge
822
- GAN1_input_x = z_node
823
  GAN1_disc_e = a_tensor
824
  GAN1_disc_x = x_tensor
825
  # =================================================================================== #
@@ -830,53 +860,50 @@ class Trainer(object):
830
  self.V,
831
  GAN1_input_e,
832
  GAN1_input_x,
833
- self.batch_size,
834
  sim_reward,
835
- self.dataset.matrices2mol_drugs,
836
  fps_r,
837
- self.submodel)
 
838
 
839
- _, fake_mol, _, _, node, edge = generator_output
840
 
841
  # =================================================================================== #
842
  # 3. GAN2 Inference #
843
  # =================================================================================== #
844
 
845
- output = generator2_loss(self.G2,
846
- self.D2,
847
- self.V2,
848
- edge,
849
- node,
850
- self.batch_size,
851
- sim_reward,
852
- self.dataset.matrices2mol_drugs,
853
- fps_r,
854
- GAN2_input_e,
855
- GAN2_input_x,
856
- self.submodel)
 
 
857
 
858
- _, fake_mol_g, _, _ = output
859
 
860
  inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
 
861
 
862
-
863
-
864
- #inference_smiles = [Chem.MolToSmiles(line) for line in fake_mol]
865
-
866
-
867
-
868
- print("molecule batch {} inferred".format(i))
869
-
870
- with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
871
  for molecules in inference_drugs:
872
 
873
  f.write(molecules)
874
  f.write("\n")
875
  metric_calc_dr.append(molecules)
876
 
877
-
878
-
879
- if i == 120:
 
880
  break
881
 
882
  et = time.time() - start_time
@@ -885,8 +912,8 @@ class Trainer(object):
885
 
886
  print("Metrics calculation started using MOSES.")
887
 
888
- print("Validity: ", fraction_valid(inference_drugs), "\n")
889
- print("Uniqueness: ", fraction_unique(inference_drugs), "\n")
890
- print("Validity: ", novelty(inference_drugs, drug_smiles), "\n")
891
 
892
- print("Metrics are calculated.")
 
6
  from utils import *
7
  from models import Generator, Generator2, simple_disc
8
  import torch_geometric.utils as geoutils
9
+ #import wandb
10
  import re
11
  from torch_geometric.loader import DataLoader
12
  from new_dataloader import DruggenDataset
 
19
  from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
20
  from training_data import load_data
21
  import random
22
+ from tqdm import tqdm
23
 
24
  class Trainer(object):
25
 
 
27
 
28
  def __init__(self, config):
29
 
30
+ if config.set_seed:
31
+ np.random.seed(config.seed)
32
+ random.seed(config.seed)
33
+ torch.manual_seed(config.seed)
34
+ torch.cuda.manual_seed(config.seed)
35
+
36
+ torch.backends.cudnn.deterministic = True
37
+ torch.backends.cudnn.benchmark = False
38
+
39
+ os.environ["PYTHONHASHSEED"] = str(config.seed)
40
+
41
+ print(f'Using seed {config.seed}')
42
+
43
  self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
44
  """Initialize configurations."""
45
  self.submodel = config.submodel
 
70
 
71
  self.inf_drugs_dataset_file = config.inf_drug_dataset_file # Drug dataset file name for the second GAN.
72
  # Contains drug molecules only. (In this case AKT1 inhibitors.)
73
+ self.inference_iterations = config.inference_iterations
74
+
75
+ self.inf_batch_size = config.inf_batch_size
76
+
77
  self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
78
 
79
  self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
 
235
  self.clipping_value = config.clipping_value
236
  # Miscellaneous.
237
 
238
+ # resume training
239
+
240
+ self.resume = config.resume
241
+ self.resume_epoch = config.resume_epoch
242
+ self.resume_iter = config.resume_iter
243
+ self.resume_directory = config.resume_directory
244
+
245
+
246
  self.mode = config.mode
247
 
248
  self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
 
422
 
423
  ''' Loading the atom and bond decoders'''
424
 
425
+ with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
426
 
427
  return pickle.load(f)
428
 
 
430
 
431
  ''' Loading the atom and bond decoders'''
432
 
433
+ with open("DrugGEN/data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
434
 
435
  return pickle.load(f)
436
 
 
453
  print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
454
 
455
  G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
456
+ D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
457
 
458
  self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
459
+ self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
460
 
461
 
462
  G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
463
+ D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(epoch, iteration))
464
 
465
  self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
466
+ self.D2.load_state_dict(torch.load(D2_path, map_location=lambda storage, loc: storage))
467
 
468
 
469
  def save_model(self, model_directory, idx,i):
 
531
 
532
 
533
  # protein data
534
+ full_smiles = [line for line in open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()]
535
+ drug_smiles = [line for line in open("DrugGEN/data/akt_train.smi", 'r').read().splitlines()]
536
 
537
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
538
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
539
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
540
 
541
+ akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
542
+ akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
543
+
544
+ if self.resume:
545
+ self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
546
+
547
  # Start training.
548
 
549
  print('Start training...')
 
604
  GAN2_disc_e = drugs_a_tensor
605
  GAN2_disc_x = drugs_x_tensor
606
  elif self.submodel == "RL":
607
+ GAN1_input_e = a_tensor
608
+ GAN1_input_x = x_tensor
609
  GAN1_disc_e = a_tensor
610
  GAN1_disc_x = x_tensor
611
  GAN2_input_e = drugs_a_tensor
 
613
  GAN2_disc_e = drugs_a_tensor
614
  GAN2_disc_x = drugs_x_tensor
615
  elif self.submodel == "NoTarget":
616
+ GAN1_input_e = a_tensor
617
+ GAN1_input_x = x_tensor
618
  GAN1_disc_e = a_tensor
619
  GAN1_disc_x = x_tensor
620
 
 
666
  GAN1_input_x,
667
  self.batch_size,
668
  sim_reward,
669
+ self.dataset.matrices2mol,
670
  fps_r,
671
+ self.submodel,
672
+ self.dataset_name)
673
 
674
  g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
675
 
 
687
  fps_r,
688
  GAN2_input_e,
689
  GAN2_input_x,
690
+ self.submodel,
691
+ self.drugs_name)
692
 
693
  g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
694
 
 
724
 
725
  # Load the trained generator.
726
  self.G.to(self.device)
 
727
  self.G2.to(self.device)
 
728
 
729
  G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
730
  self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
731
+ if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
732
+ G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
733
+ self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
734
 
735
+
736
+ drug_smiles = [line for line in open("DrugGEN/data/akt_test.smi", 'r').read().splitlines()]
737
 
738
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
739
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
740
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
741
 
742
+ akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
743
+ akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
744
 
745
  self.G.eval()
746
  #self.D.eval()
747
  self.G2.eval()
748
  #self.D2.eval()
749
+
750
+ step = self.inference_iterations
751
 
 
752
  self.inf_dataset = DruggenDataset(self.mol_data_dir,
753
  self.inf_dataset_file,
754
  self.inf_raw_file,
 
782
  #metric_calc_mol = []
783
  metric_calc_dr = []
784
  date = time.time()
785
+ if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
786
+ os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
787
  with torch.inference_mode():
788
 
789
+ dataloader_iterator = iter(self.inf_drugs_loader)
790
+ pbar = tqdm(range(self.inference_sample_num))
791
+ pbar.set_description('Inference mode for {} model started'.format(self.submodel))
792
+ for i, data in enumerate(self.inf_loader):
793
  try:
794
  drugs = next(dataloader_iterator)
795
  except StopIteration:
796
+ dataloader_iterator = iter(self.inf_drugs_loader)
797
  drugs = next(dataloader_iterator)
798
 
799
  # Preprocess both dataset
800
 
801
  bulk_data = load_data(data,
802
  drugs,
803
+ self.inf_batch_size,
804
  self.device,
805
  self.b_dim,
806
  self.m_dim,
 
839
  GAN2_disc_e = drugs_a_tensor
840
  GAN2_disc_x = drugs_x_tensor
841
  elif self.submodel == "RL":
842
+ GAN1_input_e = a_tensor
843
+ GAN1_input_x = x_tensor
844
  GAN1_disc_e = a_tensor
845
  GAN1_disc_x = x_tensor
846
  GAN2_input_e = drugs_a_tensor
 
848
  GAN2_disc_e = drugs_a_tensor
849
  GAN2_disc_x = drugs_x_tensor
850
  elif self.submodel == "NoTarget":
851
+ GAN1_input_e = a_tensor
852
+ GAN1_input_x = x_tensor
853
  GAN1_disc_e = a_tensor
854
  GAN1_disc_x = x_tensor
855
  # =================================================================================== #
 
860
  self.V,
861
  GAN1_input_e,
862
  GAN1_input_x,
863
+ self.inf_batch_size,
864
  sim_reward,
865
+ self.dataset.matrices2mol,
866
  fps_r,
867
+ self.submodel,
868
+ self.dataset_name)
869
 
870
+ _, fake_mol_g, _, _, node, edge = generator_output
871
 
872
  # =================================================================================== #
873
  # 3. GAN2 Inference #
874
  # =================================================================================== #
875
 
876
+ if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
877
+ output = generator2_loss(self.G2,
878
+ self.D2,
879
+ self.V2,
880
+ edge,
881
+ node,
882
+ self.inf_batch_size,
883
+ sim_reward,
884
+ self.dataset.matrices2mol_drugs,
885
+ fps_r,
886
+ GAN2_input_e,
887
+ GAN2_input_x,
888
+ self.submodel,
889
+ self.drugs_name)
890
 
891
+ _, fake_mol_g, edges, nodes = output
892
 
893
  inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
894
+ inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
895
 
896
+ with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
 
 
 
 
 
 
 
 
897
  for molecules in inference_drugs:
898
 
899
  f.write(molecules)
900
  f.write("\n")
901
  metric_calc_dr.append(molecules)
902
 
903
+ if len(inference_drugs) > 0:
904
+ pbar.update(1)
905
+
906
+ if len(metric_calc_dr) == self.inference_sample_num:
907
  break
908
 
909
  et = time.time() - start_time
 
912
 
913
  print("Metrics calculation started using MOSES.")
914
 
915
+ print("Validity: ", fraction_valid(metric_calc_dr), "\n")
916
+ print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")
917
+ print("Validity: ", novelty(metric_calc_dr, drug_smiles), "\n")
918
 
919
+ print("Metrics are calculated.")