Spaces:
Running
Running
Update trainer.py
Browse files- trainer.py +10 -12
trainer.py
CHANGED
@@ -581,10 +581,10 @@ class Trainer(object):
|
|
581 |
drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
|
582 |
|
583 |
if self.submodel == "CrossLoss":
|
584 |
-
GAN1_input_e =
|
585 |
-
GAN1_input_x =
|
586 |
-
GAN1_disc_e =
|
587 |
-
GAN1_disc_x =
|
588 |
elif self.submodel == "Ligand":
|
589 |
GAN1_input_e = a_tensor
|
590 |
GAN1_input_x = x_tensor
|
@@ -737,11 +737,13 @@ class Trainer(object):
|
|
737 |
drug_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
738 |
else:
|
739 |
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
740 |
-
|
741 |
-
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
742 |
-
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
743 |
-
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
744 |
|
|
|
|
|
|
|
|
|
|
|
|
|
745 |
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
746 |
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
747 |
|
@@ -819,10 +821,6 @@ class Trainer(object):
|
|
819 |
GAN1_input_x = x_tensor
|
820 |
GAN1_disc_e = drugs_a_tensor
|
821 |
GAN1_disc_x = drugs_x_tensor
|
822 |
-
GAN2_input_e = drugs_a_tensor
|
823 |
-
GAN2_input_x = drugs_x_tensor
|
824 |
-
GAN2_disc_e = a_tensor
|
825 |
-
GAN2_disc_x = x_tensor
|
826 |
elif self.submodel == "Ligand":
|
827 |
GAN1_input_e = a_tensor
|
828 |
GAN1_input_x = x_tensor
|
|
|
581 |
drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
|
582 |
|
583 |
if self.submodel == "CrossLoss":
|
584 |
+
GAN1_input_e = a_tensor
|
585 |
+
GAN1_input_x = x_tensor
|
586 |
+
GAN1_disc_e = drugs_a_tensor
|
587 |
+
GAN1_disc_x = drugs_x_tensor
|
588 |
elif self.submodel == "Ligand":
|
589 |
GAN1_input_e = a_tensor
|
590 |
GAN1_input_x = x_tensor
|
|
|
737 |
drug_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
738 |
else:
|
739 |
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
|
|
|
|
|
|
|
|
740 |
|
741 |
+
if self.submodel == "RL":
|
742 |
+
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
743 |
+
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
744 |
+
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
745 |
+
else:
|
746 |
+
fps_r = None
|
747 |
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
748 |
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
749 |
|
|
|
821 |
GAN1_input_x = x_tensor
|
822 |
GAN1_disc_e = drugs_a_tensor
|
823 |
GAN1_disc_x = drugs_x_tensor
|
|
|
|
|
|
|
|
|
824 |
elif self.submodel == "Ligand":
|
825 |
GAN1_input_e = a_tensor
|
826 |
GAN1_input_x = x_tensor
|