mgyigit commited on
Commit
5dabd12
·
1 Parent(s): 2952e53

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +6 -4
trainer.py CHANGED
@@ -732,11 +732,12 @@ class Trainer(object):
732
  G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
733
  self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
734
 
735
-
 
736
  if self.submodel == "NoTarget":
737
- drug_smiles = [line for line in open("data/chembl_test.smi", 'r').read().splitlines()]
738
  else:
739
- drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
740
 
741
  if self.submodel == "RL":
742
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
@@ -916,7 +917,8 @@ class Trainer(object):
916
  "Runtime (seconds)": round(et, 2),
917
  "Validity": f"{fraction_valid(metric_calc_dr)*100:.2f}%",
918
  "Uniqueness": f"{fraction_unique(metric_calc_dr)*100:.2f}%",
919
- "Novelty": f"{novelty(metric_calc_dr, drug_smiles)*100:.2f}%",
 
920
  }
921
  # print("Validity: ", fraction_valid(metric_calc_dr), "\n")
922
  # print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")
 
732
  G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
733
  self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
734
 
735
+
736
+ smiles_test = [line for line in open("data/chembl_test.smi", 'r').read().splitlines()]
737
  if self.submodel == "NoTarget":
738
+ smiles_train = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
739
  else:
740
+ smiles_train = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
741
 
742
  if self.submodel == "RL":
743
  drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
 
917
  "Runtime (seconds)": round(et, 2),
918
  "Validity": f"{fraction_valid(metric_calc_dr)*100:.2f}%",
919
  "Uniqueness": f"{fraction_unique(metric_calc_dr)*100:.2f}%",
920
+ "Novelty Train": f"{novelty(metric_calc_dr, smiles_train)*100:.2f}%",
921
+ "Novelty Test": f"{novelty(metric_calc_dr, smiles_test)*100:.2f}%"
922
  }
923
  # print("Validity: ", fraction_valid(metric_calc_dr), "\n")
924
  # print("Uniqueness: ", fraction_unique(metric_calc_dr), "\n")