osbm commited on
Commit
bcafa7d
·
1 Parent(s): 0450a37

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +12 -12
trainer.py CHANGED
@@ -398,7 +398,7 @@ class Trainer(object):
398
 
399
  ''' Loading the atom and bond decoders'''
400
 
401
- with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
402
 
403
  return pickle.load(f)
404
 
@@ -406,7 +406,7 @@ class Trainer(object):
406
 
407
  ''' Loading the atom and bond decoders'''
408
 
409
- with open("DrugGEN/data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
410
 
411
  return pickle.load(f)
412
 
@@ -507,15 +507,15 @@ class Trainer(object):
507
 
508
 
509
  # protein data
510
- full_smiles = [line for line in open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()]
511
- drug_smiles = [line for line in open("DrugGEN/data/akt_train.smi", 'r').read().splitlines()]
512
 
513
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
514
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
515
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
516
 
517
- akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
518
- akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
519
 
520
  # Start training.
521
 
@@ -705,14 +705,14 @@ class Trainer(object):
705
  self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
706
 
707
 
708
- drug_smiles = [line for line in open("DrugGEN/data/akt_test.smi", 'r').read().splitlines()]
709
 
710
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
711
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
712
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
713
 
714
- akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
715
- akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
716
 
717
  self.G.eval()
718
  #self.D.eval()
@@ -753,8 +753,8 @@ class Trainer(object):
753
  #metric_calc_mol = []
754
  metric_calc_dr = []
755
  date = time.time()
756
- if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
757
- os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
758
  with torch.inference_mode():
759
 
760
  dataloader_iterator = iter(self.drugs_loader)
@@ -867,7 +867,7 @@ class Trainer(object):
867
 
868
  print("molecule batch {} inferred".format(i))
869
 
870
- with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
871
  for molecules in inference_drugs:
872
 
873
  f.write(molecules)
 
398
 
399
  ''' Loading the atom and bond decoders'''
400
 
401
+ with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
402
 
403
  return pickle.load(f)
404
 
 
406
 
407
  ''' Loading the atom and bond decoders'''
408
 
409
+ with open("data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
410
 
411
  return pickle.load(f)
412
 
 
507
 
508
 
509
  # protein data
510
+ full_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
511
+ drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
512
 
513
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
514
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
515
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
516
 
517
+ akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
518
+ akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
519
 
520
  # Start training.
521
 
 
705
  self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
706
 
707
 
708
+ drug_smiles = [line for line in open("data/akt_test.smi", 'r').read().splitlines()]
709
 
710
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
711
  drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
712
  fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
713
 
714
+ akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
715
+ akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
716
 
717
  self.G.eval()
718
  #self.D.eval()
 
753
  #metric_calc_mol = []
754
  metric_calc_dr = []
755
  date = time.time()
756
+ if not os.path.exists("experiments/inference/{}".format(self.submodel)):
757
+ os.makedirs("experiments/inference/{}".format(self.submodel))
758
  with torch.inference_mode():
759
 
760
  dataloader_iterator = iter(self.drugs_loader)
 
867
 
868
  print("molecule batch {} inferred".format(i))
869
 
870
+ with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
871
  for molecules in inference_drugs:
872
 
873
  f.write(molecules)