osbm commited on
Commit
0280c50
·
1 Parent(s): 61675a4

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +12 -12
trainer.py CHANGED
@@ -422,7 +422,7 @@ class Trainer(object):
422
 
423
  ''' Loading the atom and bond decoders'''
424
 
425
- with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
426
 
427
  return pickle.load(f)
428
 
@@ -430,7 +430,7 @@ class Trainer(object):
430
 
431
  ''' Loading the atom and bond decoders'''
432
 
433
- with open("DrugGEN/data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
434
 
435
  return pickle.load(f)
436
 
@@ -531,15 +531,15 @@ class Trainer(object):
531
 
532
 
533
  # protein data
534
- full_smiles = [line for line in open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()]
535
- drug_smiles = [line for line in open("DrugGEN/data/akt_train.smi", 'r').read().splitlines()]
536
 
537
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
538
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
539
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
540
 
541
- akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
542
- akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
543
 
544
  if self.resume:
545
  self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
@@ -733,14 +733,14 @@ class Trainer(object):
733
  self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
734
 
735
 
736
- drug_smiles = [line for line in open("DrugGEN/data/akt_test.smi", 'r').read().splitlines()]
737
 
738
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
739
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
740
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
741
 
742
- akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
743
- akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
744
 
745
  self.G.eval()
746
  #self.D.eval()
@@ -782,8 +782,8 @@ class Trainer(object):
782
  #metric_calc_mol = []
783
  metric_calc_dr = []
784
  date = time.time()
785
- if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
786
- os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
787
  with torch.inference_mode():
788
 
789
  dataloader_iterator = iter(self.inf_drugs_loader)
@@ -893,7 +893,7 @@ class Trainer(object):
893
  inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
894
  inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
895
 
896
- with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
897
  for molecules in inference_drugs:
898
 
899
  f.write(molecules)
 
422
 
423
  ''' Loading the atom and bond decoders'''
424
 
425
+ with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
426
 
427
  return pickle.load(f)
428
 
 
430
 
431
  ''' Loading the atom and bond decoders'''
432
 
433
+ with open("data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
434
 
435
  return pickle.load(f)
436
 
 
531
 
532
 
533
  # protein data
534
+ full_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
535
+ drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
536
 
537
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
538
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
539
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
540
 
541
+ akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
542
+ akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
543
 
544
  if self.resume:
545
  self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
 
733
  self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
734
 
735
 
736
+ drug_smiles = [line for line in open("data/akt_test.smi", 'r').read().splitlines()]
737
 
738
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
739
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
740
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
741
 
742
+ akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
743
+ akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
744
 
745
  self.G.eval()
746
  #self.D.eval()
 
782
  #metric_calc_mol = []
783
  metric_calc_dr = []
784
  date = time.time()
785
+ if not os.path.exists("experiments/inference/{}".format(self.submodel)):
786
+ os.makedirs("experiments/inference/{}".format(self.submodel))
787
  with torch.inference_mode():
788
 
789
  dataloader_iterator = iter(self.inf_drugs_loader)
 
893
  inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
894
  inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
895
 
896
+ with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
897
  for molecules in inference_drugs:
898
 
899
  f.write(molecules)