Spaces:
Running
Running
Update trainer.py
Browse files- trainer.py +101 -74
trainer.py
CHANGED
@@ -6,7 +6,7 @@ import torch
|
|
6 |
from utils import *
|
7 |
from models import Generator, Generator2, simple_disc
|
8 |
import torch_geometric.utils as geoutils
|
9 |
-
#import
|
10 |
import re
|
11 |
from torch_geometric.loader import DataLoader
|
12 |
from new_dataloader import DruggenDataset
|
@@ -19,7 +19,7 @@ RDLogger.DisableLog('rdApp.*')
|
|
19 |
from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
|
20 |
from training_data import load_data
|
21 |
import random
|
22 |
-
|
23 |
|
24 |
class Trainer(object):
|
25 |
|
@@ -27,6 +27,19 @@ class Trainer(object):
|
|
27 |
|
28 |
def __init__(self, config):
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
|
31 |
"""Initialize configurations."""
|
32 |
self.submodel = config.submodel
|
@@ -57,7 +70,10 @@ class Trainer(object):
|
|
57 |
|
58 |
self.inf_drugs_dataset_file = config.inf_drug_dataset_file # Drug dataset file name for the second GAN.
|
59 |
# Contains drug molecules only. (In this case AKT1 inhibitors.)
|
60 |
-
|
|
|
|
|
|
|
61 |
self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
|
62 |
|
63 |
self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
|
@@ -219,6 +235,14 @@ class Trainer(object):
|
|
219 |
self.clipping_value = config.clipping_value
|
220 |
# Miscellaneous.
|
221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
self.mode = config.mode
|
223 |
|
224 |
self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
|
@@ -398,7 +422,7 @@ class Trainer(object):
|
|
398 |
|
399 |
''' Loading the atom and bond decoders'''
|
400 |
|
401 |
-
with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
402 |
|
403 |
return pickle.load(f)
|
404 |
|
@@ -406,7 +430,7 @@ class Trainer(object):
|
|
406 |
|
407 |
''' Loading the atom and bond decoders'''
|
408 |
|
409 |
-
with open("data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
410 |
|
411 |
return pickle.load(f)
|
412 |
|
@@ -429,17 +453,17 @@ class Trainer(object):
|
|
429 |
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
|
430 |
|
431 |
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
|
432 |
-
|
433 |
|
434 |
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
435 |
-
|
436 |
|
437 |
|
438 |
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
|
439 |
-
|
440 |
|
441 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
442 |
-
|
443 |
|
444 |
|
445 |
def save_model(self, model_directory, idx,i):
|
@@ -507,16 +531,19 @@ class Trainer(object):
|
|
507 |
|
508 |
|
509 |
# protein data
|
510 |
-
full_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
511 |
-
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
512 |
|
513 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
514 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
515 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
516 |
|
517 |
-
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
518 |
-
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
519 |
-
|
|
|
|
|
|
|
520 |
# Start training.
|
521 |
|
522 |
print('Start training...')
|
@@ -577,8 +604,8 @@ class Trainer(object):
|
|
577 |
GAN2_disc_e = drugs_a_tensor
|
578 |
GAN2_disc_x = drugs_x_tensor
|
579 |
elif self.submodel == "RL":
|
580 |
-
GAN1_input_e =
|
581 |
-
GAN1_input_x =
|
582 |
GAN1_disc_e = a_tensor
|
583 |
GAN1_disc_x = x_tensor
|
584 |
GAN2_input_e = drugs_a_tensor
|
@@ -586,8 +613,8 @@ class Trainer(object):
|
|
586 |
GAN2_disc_e = drugs_a_tensor
|
587 |
GAN2_disc_x = drugs_x_tensor
|
588 |
elif self.submodel == "NoTarget":
|
589 |
-
GAN1_input_e =
|
590 |
-
GAN1_input_x =
|
591 |
GAN1_disc_e = a_tensor
|
592 |
GAN1_disc_x = x_tensor
|
593 |
|
@@ -639,9 +666,10 @@ class Trainer(object):
|
|
639 |
GAN1_input_x,
|
640 |
self.batch_size,
|
641 |
sim_reward,
|
642 |
-
self.dataset.
|
643 |
fps_r,
|
644 |
-
self.submodel
|
|
|
645 |
|
646 |
g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
|
647 |
|
@@ -659,7 +687,8 @@ class Trainer(object):
|
|
659 |
fps_r,
|
660 |
GAN2_input_e,
|
661 |
GAN2_input_x,
|
662 |
-
self.submodel
|
|
|
663 |
|
664 |
g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
|
665 |
|
@@ -695,31 +724,31 @@ class Trainer(object):
|
|
695 |
|
696 |
# Load the trained generator.
|
697 |
self.G.to(self.device)
|
698 |
-
#self.D.to(self.device)
|
699 |
self.G2.to(self.device)
|
700 |
-
#self.D2.to(self.device)
|
701 |
|
702 |
G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
|
703 |
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
704 |
-
|
705 |
-
|
|
|
706 |
|
707 |
-
|
708 |
-
drug_smiles = [line for line in open("data/akt_test.smi", 'r').read().splitlines()]
|
709 |
|
710 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
711 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
712 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
713 |
|
714 |
-
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
715 |
-
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
716 |
|
717 |
self.G.eval()
|
718 |
#self.D.eval()
|
719 |
self.G2.eval()
|
720 |
#self.D2.eval()
|
|
|
|
|
721 |
|
722 |
-
self.inf_batch_size =256
|
723 |
self.inf_dataset = DruggenDataset(self.mol_data_dir,
|
724 |
self.inf_dataset_file,
|
725 |
self.inf_raw_file,
|
@@ -753,24 +782,25 @@ class Trainer(object):
|
|
753 |
#metric_calc_mol = []
|
754 |
metric_calc_dr = []
|
755 |
date = time.time()
|
756 |
-
if not os.path.exists("experiments/inference/{}".format(self.submodel)):
|
757 |
-
os.makedirs("experiments/inference/{}".format(self.submodel))
|
758 |
with torch.inference_mode():
|
759 |
|
760 |
-
dataloader_iterator = iter(self.
|
761 |
-
|
762 |
-
for
|
|
|
763 |
try:
|
764 |
drugs = next(dataloader_iterator)
|
765 |
except StopIteration:
|
766 |
-
dataloader_iterator = iter(self.
|
767 |
drugs = next(dataloader_iterator)
|
768 |
|
769 |
# Preprocess both datasets
|
770 |
|
771 |
bulk_data = load_data(data,
|
772 |
drugs,
|
773 |
-
self.
|
774 |
self.device,
|
775 |
self.b_dim,
|
776 |
self.m_dim,
|
@@ -809,8 +839,8 @@ class Trainer(object):
|
|
809 |
GAN2_disc_e = drugs_a_tensor
|
810 |
GAN2_disc_x = drugs_x_tensor
|
811 |
elif self.submodel == "RL":
|
812 |
-
GAN1_input_e =
|
813 |
-
GAN1_input_x =
|
814 |
GAN1_disc_e = a_tensor
|
815 |
GAN1_disc_x = x_tensor
|
816 |
GAN2_input_e = drugs_a_tensor
|
@@ -818,8 +848,8 @@ class Trainer(object):
|
|
818 |
GAN2_disc_e = drugs_a_tensor
|
819 |
GAN2_disc_x = drugs_x_tensor
|
820 |
elif self.submodel == "NoTarget":
|
821 |
-
GAN1_input_e =
|
822 |
-
GAN1_input_x =
|
823 |
GAN1_disc_e = a_tensor
|
824 |
GAN1_disc_x = x_tensor
|
825 |
# =================================================================================== #
|
@@ -830,53 +860,50 @@ class Trainer(object):
|
|
830 |
self.V,
|
831 |
GAN1_input_e,
|
832 |
GAN1_input_x,
|
833 |
-
self.
|
834 |
sim_reward,
|
835 |
-
self.dataset.
|
836 |
fps_r,
|
837 |
-
self.submodel
|
|
|
838 |
|
839 |
-
_,
|
840 |
|
841 |
# =================================================================================== #
|
842 |
# 3. GAN2 Inference #
|
843 |
# =================================================================================== #
|
844 |
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
854 |
-
|
855 |
-
|
856 |
-
|
|
|
|
|
857 |
|
858 |
-
|
859 |
|
860 |
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
|
|
861 |
|
862 |
-
|
863 |
-
|
864 |
-
#inference_smiles = [Chem.MolToSmiles(line) for line in fake_mol]
|
865 |
-
|
866 |
-
|
867 |
-
|
868 |
-
print("molecule batch {} inferred".format(i))
|
869 |
-
|
870 |
-
with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
871 |
for molecules in inference_drugs:
|
872 |
|
873 |
f.write(molecules)
|
874 |
f.write("\n")
|
875 |
metric_calc_dr.append(molecules)
|
876 |
|
877 |
-
|
878 |
-
|
879 |
-
|
|
|
880 |
break
|
881 |
|
882 |
et = time.time() - start_time
|
@@ -885,8 +912,8 @@ class Trainer(object):
|
|
885 |
|
886 |
print("Metrics calculation started using MOSES.")
|
887 |
|
888 |
-
print("Validity: ", fraction_valid(
|
889 |
-
print("Uniqueness: ", fraction_unique(
|
890 |
-
print("Validity: ", novelty(
|
891 |
|
892 |
-
print("Metrics are calculated.")
|
|
|
6 |
from utils import *
|
7 |
from models import Generator, Generator2, simple_disc
|
8 |
import torch_geometric.utils as geoutils
|
9 |
+
#import wandb
|
10 |
import re
|
11 |
from torch_geometric.loader import DataLoader
|
12 |
from new_dataloader import DruggenDataset
|
|
|
19 |
from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
|
20 |
from training_data import load_data
|
21 |
import random
|
22 |
+
from tqdm import tqdm
|
23 |
|
24 |
class Trainer(object):
|
25 |
|
|
|
27 |
|
28 |
def __init__(self, config):
|
29 |
|
30 |
+
if config.set_seed:
|
31 |
+
np.random.seed(config.seed)
|
32 |
+
random.seed(config.seed)
|
33 |
+
torch.manual_seed(config.seed)
|
34 |
+
torch.cuda.manual_seed(config.seed)
|
35 |
+
|
36 |
+
torch.backends.cudnn.deterministic = True
|
37 |
+
torch.backends.cudnn.benchmark = False
|
38 |
+
|
39 |
+
os.environ["PYTHONHASHSEED"] = str(config.seed)
|
40 |
+
|
41 |
+
print(f'Using seed {config.seed}')
|
42 |
+
|
43 |
self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
|
44 |
"""Initialize configurations."""
|
45 |
self.submodel = config.submodel
|
|
|
70 |
|
71 |
self.inf_drugs_dataset_file = config.inf_drug_dataset_file # Drug dataset file name for the second GAN.
|
72 |
# Contains drug molecules only. (In this case AKT1 inhibitors.)
|
73 |
+
self.inference_iterations = config.inference_iterations
|
74 |
+
|
75 |
+
self.inf_batch_size = config.inf_batch_size
|
76 |
+
|
77 |
self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
|
78 |
|
79 |
self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
|
|
|
235 |
self.clipping_value = config.clipping_value
|
236 |
# Miscellaneous.
|
237 |
|
238 |
+
# resume training
|
239 |
+
|
240 |
+
self.resume = config.resume
|
241 |
+
self.resume_epoch = config.resume_epoch
|
242 |
+
self.resume_iter = config.resume_iter
|
243 |
+
self.resume_directory = config.resume_directory
|
244 |
+
|
245 |
+
|
246 |
self.mode = config.mode
|
247 |
|
248 |
self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
|
|
|
422 |
|
423 |
''' Loading the atom and bond decoders'''
|
424 |
|
425 |
+
with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
426 |
|
427 |
return pickle.load(f)
|
428 |
|
|
|
430 |
|
431 |
''' Loading the atom and bond decoders'''
|
432 |
|
433 |
+
with open("DrugGEN/data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
434 |
|
435 |
return pickle.load(f)
|
436 |
|
|
|
453 |
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
|
454 |
|
455 |
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
|
456 |
+
D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
|
457 |
|
458 |
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
459 |
+
self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
|
460 |
|
461 |
|
462 |
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
|
463 |
+
D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(epoch, iteration))
|
464 |
|
465 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
466 |
+
self.D2.load_state_dict(torch.load(D2_path, map_location=lambda storage, loc: storage))
|
467 |
|
468 |
|
469 |
def save_model(self, model_directory, idx,i):
|
|
|
531 |
|
532 |
|
533 |
# protein data
|
534 |
+
full_smiles = [line for line in open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()]
|
535 |
+
drug_smiles = [line for line in open("DrugGEN/data/akt_train.smi", 'r').read().splitlines()]
|
536 |
|
537 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
538 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
539 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
540 |
|
541 |
+
akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
542 |
+
akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
543 |
+
|
544 |
+
if self.resume:
|
545 |
+
self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
|
546 |
+
|
547 |
# Start training.
|
548 |
|
549 |
print('Start training...')
|
|
|
604 |
GAN2_disc_e = drugs_a_tensor
|
605 |
GAN2_disc_x = drugs_x_tensor
|
606 |
elif self.submodel == "RL":
|
607 |
+
GAN1_input_e = a_tensor
|
608 |
+
GAN1_input_x = x_tensor
|
609 |
GAN1_disc_e = a_tensor
|
610 |
GAN1_disc_x = x_tensor
|
611 |
GAN2_input_e = drugs_a_tensor
|
|
|
613 |
GAN2_disc_e = drugs_a_tensor
|
614 |
GAN2_disc_x = drugs_x_tensor
|
615 |
elif self.submodel == "NoTarget":
|
616 |
+
GAN1_input_e = a_tensor
|
617 |
+
GAN1_input_x = x_tensor
|
618 |
GAN1_disc_e = a_tensor
|
619 |
GAN1_disc_x = x_tensor
|
620 |
|
|
|
666 |
GAN1_input_x,
|
667 |
self.batch_size,
|
668 |
sim_reward,
|
669 |
+
self.dataset.matrices2mol,
|
670 |
fps_r,
|
671 |
+
self.submodel,
|
672 |
+
self.dataset_name)
|
673 |
|
674 |
g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
|
675 |
|
|
|
687 |
fps_r,
|
688 |
GAN2_input_e,
|
689 |
GAN2_input_x,
|
690 |
+
self.submodel,
|
691 |
+
self.drugs_name)
|
692 |
|
693 |
g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
|
694 |
|
|
|
724 |
|
725 |
# Load the trained generator.
|
726 |
self.G.to(self.device)
|
|
|
727 |
self.G2.to(self.device)
|
|
|
728 |
|
729 |
G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
|
730 |
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
731 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
732 |
+
G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
|
733 |
+
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
734 |
|
735 |
+
|
736 |
+
drug_smiles = [line for line in open("DrugGEN/data/akt_test.smi", 'r').read().splitlines()]
|
737 |
|
738 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
739 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
740 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
741 |
|
742 |
+
akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
743 |
+
akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
744 |
|
745 |
self.G.eval()
|
746 |
#self.D.eval()
|
747 |
self.G2.eval()
|
748 |
#self.D2.eval()
|
749 |
+
|
750 |
+
step = self.inference_iterations
|
751 |
|
|
|
752 |
self.inf_dataset = DruggenDataset(self.mol_data_dir,
|
753 |
self.inf_dataset_file,
|
754 |
self.inf_raw_file,
|
|
|
782 |
#metric_calc_mol = []
|
783 |
metric_calc_dr = []
|
784 |
date = time.time()
|
785 |
+
if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
|
786 |
+
os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
|
787 |
with torch.inference_mode():
|
788 |
|
789 |
+
dataloader_iterator = iter(self.inf_drugs_loader)
|
790 |
+
pbar = tqdm(range(self.inference_sample_num))
|
791 |
+
pbar.set_description('Inference mode for {} model started'.format(self.submodel))
|
792 |
+
for i, data in enumerate(self.inf_loader):
|
793 |
try:
|
794 |
drugs = next(dataloader_iterator)
|
795 |
except StopIteration:
|
796 |
+
dataloader_iterator = iter(self.inf_drugs_loader)
|
797 |
drugs = next(dataloader_iterator)
|
798 |
|
799 |
# Preprocess both datasets
|
800 |
|
801 |
bulk_data = load_data(data,
|
802 |
drugs,
|
803 |
+
self.inf_batch_size,
|
804 |
self.device,
|
805 |
self.b_dim,
|
806 |
self.m_dim,
|
|
|
839 |
GAN2_disc_e = drugs_a_tensor
|
840 |
GAN2_disc_x = drugs_x_tensor
|
841 |
elif self.submodel == "RL":
|
842 |
+
GAN1_input_e = a_tensor
|
843 |
+
GAN1_input_x = x_tensor
|
844 |
GAN1_disc_e = a_tensor
|
845 |
GAN1_disc_x = x_tensor
|
846 |
GAN2_input_e = drugs_a_tensor
|
|
|
848 |
GAN2_disc_e = drugs_a_tensor
|
849 |
GAN2_disc_x = drugs_x_tensor
|
850 |
elif self.submodel == "NoTarget":
|
851 |
+
GAN1_input_e = a_tensor
|
852 |
+
GAN1_input_x = x_tensor
|
853 |
GAN1_disc_e = a_tensor
|
854 |
GAN1_disc_x = x_tensor
|
855 |
# =================================================================================== #
|
|
|
860 |
self.V,
|
861 |
GAN1_input_e,
|
862 |
GAN1_input_x,
|
863 |
+
self.inf_batch_size,
|
864 |
sim_reward,
|
865 |
+
self.dataset.matrices2mol,
|
866 |
fps_r,
|
867 |
+
self.submodel,
|
868 |
+
self.dataset_name)
|
869 |
|
870 |
+
_, fake_mol_g, _, _, node, edge = generator_output
|
871 |
|
872 |
# =================================================================================== #
|
873 |
# 3. GAN2 Inference #
|
874 |
# =================================================================================== #
|
875 |
|
876 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
877 |
+
output = generator2_loss(self.G2,
|
878 |
+
self.D2,
|
879 |
+
self.V2,
|
880 |
+
edge,
|
881 |
+
node,
|
882 |
+
self.inf_batch_size,
|
883 |
+
sim_reward,
|
884 |
+
self.dataset.matrices2mol_drugs,
|
885 |
+
fps_r,
|
886 |
+
GAN2_input_e,
|
887 |
+
GAN2_input_x,
|
888 |
+
self.submodel,
|
889 |
+
self.drugs_name)
|
890 |
|
891 |
+
_, fake_mol_g, edges, nodes = output
|
892 |
|
893 |
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
894 |
+
inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
|
895 |
|
896 |
+
with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
897 |
for molecules in inference_drugs:
|
898 |
|
899 |
f.write(molecules)
|
900 |
f.write("\n")
|
901 |
metric_calc_dr.append(molecules)
|
902 |
|
903 |
+
if len(inference_drugs) > 0:
|
904 |
+
pbar.update(1)
|
905 |
+
|
906 |
+
if len(metric_calc_dr) == self.inference_sample_num:
|
907 |
break
|
908 |
|
909 |
et = time.time() - start_time
|
|
|
912 |
|
913 |
print("Metrics calculation started using MOSES.")
|
914 |
|
915 |
+
print("Validity: ", fraction_valid(metric_calc_dr), "\n")
|
916 |
+
print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")
|
917 |
+
print("Validity: ", novelty(metric_calc_dr, drug_smiles), "\n")
|
918 |
|
919 |
+
print("Metrics are calculated.")
|