mgyigit commited on
Commit
f0ae72a
·
1 Parent(s): 486ddc0

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +10 -12
trainer.py CHANGED
@@ -581,10 +581,10 @@ class Trainer(object):
581
  drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
582
 
583
  if self.submodel == "CrossLoss":
584
- GAN1_input_e = drugs_a_tensor
585
- GAN1_input_x = drugs_x_tensor
586
- GAN1_disc_e = a_tensor
587
- GAN1_disc_x = x_tensor
588
  elif self.submodel == "Ligand":
589
  GAN1_input_e = a_tensor
590
  GAN1_input_x = x_tensor
@@ -737,11 +737,13 @@ class Trainer(object):
737
  drug_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
738
  else:
739
  drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
740
-
741
- drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
742
- drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
743
- fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
744
 
 
 
 
 
 
 
745
  akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
746
  akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
747
 
@@ -819,10 +821,6 @@ class Trainer(object):
819
  GAN1_input_x = x_tensor
820
  GAN1_disc_e = drugs_a_tensor
821
  GAN1_disc_x = drugs_x_tensor
822
- GAN2_input_e = drugs_a_tensor
823
- GAN2_input_x = drugs_x_tensor
824
- GAN2_disc_e = a_tensor
825
- GAN2_disc_x = x_tensor
826
  elif self.submodel == "Ligand":
827
  GAN1_input_e = a_tensor
828
  GAN1_input_x = x_tensor
 
581
  drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
582
 
583
  if self.submodel == "CrossLoss":
584
+ GAN1_input_e = a_tensor
585
+ GAN1_input_x = x_tensor
586
+ GAN1_disc_e = drugs_a_tensor
587
+ GAN1_disc_x = drugs_x_tensor
588
  elif self.submodel == "Ligand":
589
  GAN1_input_e = a_tensor
590
  GAN1_input_x = x_tensor
 
737
  drug_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
738
  else:
739
  drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
 
 
 
 
740
 
741
+ if self.submodel == "RL":
742
+ drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
743
+ drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
744
+ fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
745
+ else:
746
+ fps_r = None
747
  akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
748
  akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
749
 
 
821
  GAN1_input_x = x_tensor
822
  GAN1_disc_e = drugs_a_tensor
823
  GAN1_disc_x = drugs_x_tensor
 
 
 
 
824
  elif self.submodel == "Ligand":
825
  GAN1_input_e = a_tensor
826
  GAN1_input_x = x_tensor