Spaces:
Running
Running
Update trainer.py
Browse files- trainer.py +12 -12
trainer.py
CHANGED
@@ -398,7 +398,7 @@ class Trainer(object):
|
|
398 |
|
399 |
''' Loading the atom and bond decoders'''
|
400 |
|
401 |
-
with open("
|
402 |
|
403 |
return pickle.load(f)
|
404 |
|
@@ -406,7 +406,7 @@ class Trainer(object):
|
|
406 |
|
407 |
''' Loading the atom and bond decoders'''
|
408 |
|
409 |
-
with open("
|
410 |
|
411 |
return pickle.load(f)
|
412 |
|
@@ -507,15 +507,15 @@ class Trainer(object):
|
|
507 |
|
508 |
|
509 |
# protein data
|
510 |
-
full_smiles = [line for line in open("
|
511 |
-
drug_smiles = [line for line in open("
|
512 |
|
513 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
514 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
515 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
516 |
|
517 |
-
akt1_human_adj = torch.load("
|
518 |
-
akt1_human_annot = torch.load("
|
519 |
|
520 |
# Start training.
|
521 |
|
@@ -705,14 +705,14 @@ class Trainer(object):
|
|
705 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
706 |
|
707 |
|
708 |
-
drug_smiles = [line for line in open("
|
709 |
|
710 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
711 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
712 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
713 |
|
714 |
-
akt1_human_adj = torch.load("
|
715 |
-
akt1_human_annot = torch.load("
|
716 |
|
717 |
self.G.eval()
|
718 |
#self.D.eval()
|
@@ -753,8 +753,8 @@ class Trainer(object):
|
|
753 |
#metric_calc_mol = []
|
754 |
metric_calc_dr = []
|
755 |
date = time.time()
|
756 |
-
if not os.path.exists("
|
757 |
-
os.makedirs("
|
758 |
with torch.inference_mode():
|
759 |
|
760 |
dataloader_iterator = iter(self.drugs_loader)
|
@@ -867,7 +867,7 @@ class Trainer(object):
|
|
867 |
|
868 |
print("molecule batch {} inferred".format(i))
|
869 |
|
870 |
-
with open("
|
871 |
for molecules in inference_drugs:
|
872 |
|
873 |
f.write(molecules)
|
|
|
398 |
|
399 |
''' Loading the atom and bond decoders'''
|
400 |
|
401 |
+
with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
402 |
|
403 |
return pickle.load(f)
|
404 |
|
|
|
406 |
|
407 |
''' Loading the atom and bond decoders'''
|
408 |
|
409 |
+
with open("data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
410 |
|
411 |
return pickle.load(f)
|
412 |
|
|
|
507 |
|
508 |
|
509 |
# protein data
|
510 |
+
full_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
511 |
+
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
512 |
|
513 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
514 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
515 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
516 |
|
517 |
+
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
518 |
+
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
519 |
|
520 |
# Start training.
|
521 |
|
|
|
705 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
706 |
|
707 |
|
708 |
+
drug_smiles = [line for line in open("data/akt_test.smi", 'r').read().splitlines()]
|
709 |
|
710 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
711 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
712 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
713 |
|
714 |
+
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
715 |
+
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
716 |
|
717 |
self.G.eval()
|
718 |
#self.D.eval()
|
|
|
753 |
#metric_calc_mol = []
|
754 |
metric_calc_dr = []
|
755 |
date = time.time()
|
756 |
+
if not os.path.exists("experiments/inference/{}".format(self.submodel)):
|
757 |
+
os.makedirs("experiments/inference/{}".format(self.submodel))
|
758 |
with torch.inference_mode():
|
759 |
|
760 |
dataloader_iterator = iter(self.drugs_loader)
|
|
|
867 |
|
868 |
print("molecule batch {} inferred".format(i))
|
869 |
|
870 |
+
with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
871 |
for molecules in inference_drugs:
|
872 |
|
873 |
f.write(molecules)
|