Spaces:
Running
Running
Update trainer.py
Browse files- trainer.py +12 -12
trainer.py
CHANGED
@@ -422,7 +422,7 @@ class Trainer(object):
|
|
422 |
|
423 |
''' Loading the atom and bond decoders'''
|
424 |
|
425 |
-
with open("
|
426 |
|
427 |
return pickle.load(f)
|
428 |
|
@@ -430,7 +430,7 @@ class Trainer(object):
|
|
430 |
|
431 |
''' Loading the atom and bond decoders'''
|
432 |
|
433 |
-
with open("
|
434 |
|
435 |
return pickle.load(f)
|
436 |
|
@@ -531,15 +531,15 @@ class Trainer(object):
|
|
531 |
|
532 |
|
533 |
# protein data
|
534 |
-
full_smiles = [line for line in open("
|
535 |
-
drug_smiles = [line for line in open("
|
536 |
|
537 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
538 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
539 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
540 |
|
541 |
-
akt1_human_adj = torch.load("
|
542 |
-
akt1_human_annot = torch.load("
|
543 |
|
544 |
if self.resume:
|
545 |
self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
|
@@ -733,14 +733,14 @@ class Trainer(object):
|
|
733 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
734 |
|
735 |
|
736 |
-
drug_smiles = [line for line in open("
|
737 |
|
738 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
739 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
740 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
741 |
|
742 |
-
akt1_human_adj = torch.load("
|
743 |
-
akt1_human_annot = torch.load("
|
744 |
|
745 |
self.G.eval()
|
746 |
#self.D.eval()
|
@@ -782,8 +782,8 @@ class Trainer(object):
|
|
782 |
#metric_calc_mol = []
|
783 |
metric_calc_dr = []
|
784 |
date = time.time()
|
785 |
-
if not os.path.exists("
|
786 |
-
os.makedirs("
|
787 |
with torch.inference_mode():
|
788 |
|
789 |
dataloader_iterator = iter(self.inf_drugs_loader)
|
@@ -893,7 +893,7 @@ class Trainer(object):
|
|
893 |
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
894 |
inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
|
895 |
|
896 |
-
with open("
|
897 |
for molecules in inference_drugs:
|
898 |
|
899 |
f.write(molecules)
|
|
|
422 |
|
423 |
''' Loading the atom and bond decoders'''
|
424 |
|
425 |
+
with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
426 |
|
427 |
return pickle.load(f)
|
428 |
|
|
|
430 |
|
431 |
''' Loading the atom and bond decoders'''
|
432 |
|
433 |
+
with open("data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
434 |
|
435 |
return pickle.load(f)
|
436 |
|
|
|
531 |
|
532 |
|
533 |
# protein data
|
534 |
+
full_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
535 |
+
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
536 |
|
537 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
538 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
539 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
540 |
|
541 |
+
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
542 |
+
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
543 |
|
544 |
if self.resume:
|
545 |
self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
|
|
|
733 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
734 |
|
735 |
|
736 |
+
drug_smiles = [line for line in open("data/akt_test.smi", 'r').read().splitlines()]
|
737 |
|
738 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
739 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
740 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
741 |
|
742 |
+
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
743 |
+
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
744 |
|
745 |
self.G.eval()
|
746 |
#self.D.eval()
|
|
|
782 |
#metric_calc_mol = []
|
783 |
metric_calc_dr = []
|
784 |
date = time.time()
|
785 |
+
if not os.path.exists("experiments/inference/{}".format(self.submodel)):
|
786 |
+
os.makedirs("experiments/inference/{}".format(self.submodel))
|
787 |
with torch.inference_mode():
|
788 |
|
789 |
dataloader_iterator = iter(self.inf_drugs_loader)
|
|
|
893 |
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
894 |
inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
|
895 |
|
896 |
+
with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
897 |
for molecules in inference_drugs:
|
898 |
|
899 |
f.write(molecules)
|