mgyigit committed on
Commit
c724df9
·
verified ·
1 Parent(s): f41dc00

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +45 -22
inference.py CHANGED
@@ -136,7 +136,7 @@ class Inference(object):
136
  """Restore the trained generator and discriminator."""
137
  print('Loading the model...')
138
  G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
139
- self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage, weights_only=False))
140
 
141
  def inference(self):
142
  # Load the trained generator.
@@ -170,7 +170,9 @@ class Inference(object):
170
  uniqueness_calc = []
171
  real_smiles_snn = []
172
  nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
173
- generated_smiles = []
 
 
174
  val_counter = 0
175
  none_counter = 0
176
 
@@ -179,6 +181,7 @@ class Inference(object):
179
  pbar = tqdm(range(self.sample_num))
180
  pbar.set_description('Inference mode for {} model started'.format(self.submodel))
181
  for i, data in enumerate(self.inf_loader):
 
182
  val_counter += 1
183
  # Preprocess dataset
184
  _, a_tensor, x_tensor = load_molecules(
@@ -206,13 +209,14 @@ class Inference(object):
206
  inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
207
 
208
  for molecules in inference_drugs:
209
- if molecules is None:
210
- none_counter += 1
211
 
212
  for molecules in inference_drugs:
213
  if molecules is not None:
214
- molecules = molecules.replace("*", "C")
215
- generated_smiles.append(molecules)
 
216
  uniqueness_calc.append(molecules)
217
  nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
218
  pbar.update(1)
@@ -223,21 +227,30 @@ class Inference(object):
223
  if generation_number == self.sample_num or none_counter == self.sample_num:
224
  break
225
 
 
 
226
  if not self.disable_correction:
227
- correct = smi_correct(self.submodel, "experiments/inference/{}".format(self.submodel))
228
- gen_smi = correct.correct_smiles_list(generated_smiles)
 
229
  else:
230
- gen_smi = generated_smiles
231
-
 
232
  et = time.time() - start_time
233
 
234
  gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
235
  real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
 
236
 
 
 
237
  if not self.disable_correction:
238
  val = round(len(gen_smi)/self.sample_num, 3)
 
239
  else:
240
  val = round(fraction_valid(gen_smi), 3)
 
241
 
242
  uniq = round(fraction_unique(gen_smi), 3)
243
  nov = round(novelty(gen_smi, chembl_smiles), 3)
@@ -251,23 +264,33 @@ class Inference(object):
251
  qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
252
  sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
255
  "uniqueness": [uniq], "novelty": [nov],
256
  "novelty_test": [nov_test], "drug_novelty": [drug_nov],
257
  "max_len": [max_len], "mean_atom_type": [mean_atom],
258
  "snn_chembl": [snn_chembl], "snn_drug": [snn_drug],
259
  "IntDiv": [int_div], "qed": [qed], "sa": [sa]})
260
-
261
- # Write generated SMILES to a temporary file for app.py to use
262
- temp_file = f'{self.submodel}_denovo_mols.smi'
263
- with open(temp_file, 'w') as f:
264
- f.write("SMILES\n")
265
- for smiles in gen_smi:
266
- f.write(f"{smiles}\n")
267
-
268
- return model_res
269
-
270
-
271
  if __name__=="__main__":
272
  parser = argparse.ArgumentParser()
273
 
@@ -300,4 +323,4 @@ if __name__=="__main__":
300
 
301
  config = parser.parse_args()
302
  inference = Inference(config)
303
- inference.inference()
 
136
  """Restore the trained generator and discriminator."""
137
  print('Loading the model...')
138
  G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
139
+ self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
140
 
141
  def inference(self):
142
  # Load the trained generator.
 
170
  uniqueness_calc = []
171
  real_smiles_snn = []
172
  nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
173
+ f = open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w")
174
+ f.write("SMILES")
175
+ f.write("\n")
176
  val_counter = 0
177
  none_counter = 0
178
 
 
181
  pbar = tqdm(range(self.sample_num))
182
  pbar.set_description('Inference mode for {} model started'.format(self.submodel))
183
  for i, data in enumerate(self.inf_loader):
184
+
185
  val_counter += 1
186
  # Preprocess dataset
187
  _, a_tensor, x_tensor = load_molecules(
 
209
  inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
210
 
211
  for molecules in inference_drugs:
212
+ if molecules is None:
213
+ none_counter += 1
214
 
215
  for molecules in inference_drugs:
216
  if molecules is not None:
217
+ molecules = molecules.replace("*", "C")
218
+ f.write(molecules)
219
+ f.write("\n")
220
  uniqueness_calc.append(molecules)
221
  nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
222
  pbar.update(1)
 
227
  if generation_number == self.sample_num or none_counter == self.sample_num:
228
  break
229
 
230
+ f.close()
231
+ print("Inference completed, starting metrics calculation.")
232
  if not self.disable_correction:
233
+ corrected = correct.correct("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
234
+ gen_smi = corrected["SMILES"].tolist()
235
+
236
  else:
237
+ gen_smi = pd.read_csv("experiments/inference/{}/inference_drugs.txt".format(self.submodel))["SMILES"].tolist()
238
+
239
+
240
  et = time.time() - start_time
241
 
242
  gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
243
  real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
244
+ print("Inference mode is lasted for {:.2f} seconds".format(et))
245
 
246
+ print("Metrics calculation started using MOSES.")
247
+
248
  if not self.disable_correction:
249
  val = round(len(gen_smi)/self.sample_num, 3)
250
+ print("Validity: ", val, "\n")
251
  else:
252
  val = round(fraction_valid(gen_smi), 3)
253
+ print("Validity: ", val, "\n")
254
 
255
  uniq = round(fraction_unique(gen_smi), 3)
256
  nov = round(novelty(gen_smi, chembl_smiles), 3)
 
264
  qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
265
  sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
266
 
267
+ print("Uniqueness: ", uniq, "\n")
268
+ print("Novelty: ", nov, "\n")
269
+ print("Novelty_test: ", nov_test, "\n")
270
+ print("Drug_novelty: ", drug_nov, "\n")
271
+ print("max_len: ", max_len, "\n")
272
+ print("mean_atom_type: ", mean_atom, "\n")
273
+ print("snn_chembl: ", snn_chembl, "\n")
274
+ print("snn_drug: ", snn_drug, "\n")
275
+ print("IntDiv: ", int_div, "\n")
276
+ print("QED: ", qed, "\n")
277
+ print("SA: ", sa, "\n")
278
+
279
+ print("Metrics are calculated.")
280
  model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
281
  "uniqueness": [uniq], "novelty": [nov],
282
  "novelty_test": [nov_test], "drug_novelty": [drug_nov],
283
  "max_len": [max_len], "mean_atom_type": [mean_atom],
284
  "snn_chembl": [snn_chembl], "snn_drug": [snn_drug],
285
  "IntDiv": [int_div], "qed": [qed], "sa": [sa]})
286
+ search_res = pd.concat([search_res, model_res], axis=0)
287
+ os.remove("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
288
+ search_res.to_csv("experiments/inference/{}/inference_results.csv".format(self.submodel), index=False)
289
+ generatedsmiles = pd.DataFrame({"SMILES": gen_smi})
290
+ generatedsmiles.to_csv("experiments/inference/{}/inference_drugs.csv".format(self.submodel), index=False)
291
+
292
+ return model_res
293
+
 
 
 
294
  if __name__=="__main__":
295
  parser = argparse.ArgumentParser()
296
 
 
323
 
324
  config = parser.parse_args()
325
  inference = Inference(config)
326
+ inference.inference()