Spaces:
Running
Running
Update inference.py
Browse files- inference.py +45 -22
inference.py
CHANGED
@@ -136,7 +136,7 @@ class Inference(object):
|
|
136 |
"""Restore the trained generator and discriminator."""
|
137 |
print('Loading the model...')
|
138 |
G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
|
139 |
-
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage
|
140 |
|
141 |
def inference(self):
|
142 |
# Load the trained generator.
|
@@ -170,7 +170,9 @@ class Inference(object):
|
|
170 |
uniqueness_calc = []
|
171 |
real_smiles_snn = []
|
172 |
nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
|
173 |
-
|
|
|
|
|
174 |
val_counter = 0
|
175 |
none_counter = 0
|
176 |
|
@@ -179,6 +181,7 @@ class Inference(object):
|
|
179 |
pbar = tqdm(range(self.sample_num))
|
180 |
pbar.set_description('Inference mode for {} model started'.format(self.submodel))
|
181 |
for i, data in enumerate(self.inf_loader):
|
|
|
182 |
val_counter += 1
|
183 |
# Preprocess dataset
|
184 |
_, a_tensor, x_tensor = load_molecules(
|
@@ -206,13 +209,14 @@ class Inference(object):
|
|
206 |
inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
|
207 |
|
208 |
for molecules in inference_drugs:
|
209 |
-
|
210 |
-
|
211 |
|
212 |
for molecules in inference_drugs:
|
213 |
if molecules is not None:
|
214 |
-
molecules = molecules.replace("*", "C")
|
215 |
-
|
|
|
216 |
uniqueness_calc.append(molecules)
|
217 |
nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
|
218 |
pbar.update(1)
|
@@ -223,21 +227,30 @@ class Inference(object):
|
|
223 |
if generation_number == self.sample_num or none_counter == self.sample_num:
|
224 |
break
|
225 |
|
|
|
|
|
226 |
if not self.disable_correction:
|
227 |
-
|
228 |
-
gen_smi =
|
|
|
229 |
else:
|
230 |
-
gen_smi =
|
231 |
-
|
|
|
232 |
et = time.time() - start_time
|
233 |
|
234 |
gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
|
235 |
real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
|
|
|
236 |
|
|
|
|
|
237 |
if not self.disable_correction:
|
238 |
val = round(len(gen_smi)/self.sample_num, 3)
|
|
|
239 |
else:
|
240 |
val = round(fraction_valid(gen_smi), 3)
|
|
|
241 |
|
242 |
uniq = round(fraction_unique(gen_smi), 3)
|
243 |
nov = round(novelty(gen_smi, chembl_smiles), 3)
|
@@ -251,23 +264,33 @@ class Inference(object):
|
|
251 |
qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
|
252 |
sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
|
253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
|
255 |
"uniqueness": [uniq], "novelty": [nov],
|
256 |
"novelty_test": [nov_test], "drug_novelty": [drug_nov],
|
257 |
"max_len": [max_len], "mean_atom_type": [mean_atom],
|
258 |
"snn_chembl": [snn_chembl], "snn_drug": [snn_drug],
|
259 |
"IntDiv": [int_div], "qed": [qed], "sa": [sa]})
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
return model_res
|
269 |
-
|
270 |
-
|
271 |
if __name__=="__main__":
|
272 |
parser = argparse.ArgumentParser()
|
273 |
|
@@ -300,4 +323,4 @@ if __name__=="__main__":
|
|
300 |
|
301 |
config = parser.parse_args()
|
302 |
inference = Inference(config)
|
303 |
-
inference.inference()
|
|
|
136 |
"""Restore the trained generator and discriminator."""
|
137 |
print('Loading the model...')
|
138 |
G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
|
139 |
+
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
140 |
|
141 |
def inference(self):
|
142 |
# Load the trained generator.
|
|
|
170 |
uniqueness_calc = []
|
171 |
real_smiles_snn = []
|
172 |
nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
|
173 |
+
f = open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w")
|
174 |
+
f.write("SMILES")
|
175 |
+
f.write("\n")
|
176 |
val_counter = 0
|
177 |
none_counter = 0
|
178 |
|
|
|
181 |
pbar = tqdm(range(self.sample_num))
|
182 |
pbar.set_description('Inference mode for {} model started'.format(self.submodel))
|
183 |
for i, data in enumerate(self.inf_loader):
|
184 |
+
|
185 |
val_counter += 1
|
186 |
# Preprocess dataset
|
187 |
_, a_tensor, x_tensor = load_molecules(
|
|
|
209 |
inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
|
210 |
|
211 |
for molecules in inference_drugs:
|
212 |
+
if molecules is None:
|
213 |
+
none_counter += 1
|
214 |
|
215 |
for molecules in inference_drugs:
|
216 |
if molecules is not None:
|
217 |
+
molecules = molecules.replace("*", "C")
|
218 |
+
f.write(molecules)
|
219 |
+
f.write("\n")
|
220 |
uniqueness_calc.append(molecules)
|
221 |
nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
|
222 |
pbar.update(1)
|
|
|
227 |
if generation_number == self.sample_num or none_counter == self.sample_num:
|
228 |
break
|
229 |
|
230 |
+
f.close()
|
231 |
+
print("Inference completed, starting metrics calculation.")
|
232 |
if not self.disable_correction:
|
233 |
+
corrected = correct.correct("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
|
234 |
+
gen_smi = corrected["SMILES"].tolist()
|
235 |
+
|
236 |
else:
|
237 |
+
gen_smi = pd.read_csv("experiments/inference/{}/inference_drugs.txt".format(self.submodel))["SMILES"].tolist()
|
238 |
+
|
239 |
+
|
240 |
et = time.time() - start_time
|
241 |
|
242 |
gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
|
243 |
real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
|
244 |
+
print("Inference mode is lasted for {:.2f} seconds".format(et))
|
245 |
|
246 |
+
print("Metrics calculation started using MOSES.")
|
247 |
+
|
248 |
if not self.disable_correction:
|
249 |
val = round(len(gen_smi)/self.sample_num, 3)
|
250 |
+
print("Validity: ", val, "\n")
|
251 |
else:
|
252 |
val = round(fraction_valid(gen_smi), 3)
|
253 |
+
print("Validity: ", val, "\n")
|
254 |
|
255 |
uniq = round(fraction_unique(gen_smi), 3)
|
256 |
nov = round(novelty(gen_smi, chembl_smiles), 3)
|
|
|
264 |
qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
|
265 |
sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
|
266 |
|
267 |
+
print("Uniqueness: ", uniq, "\n")
|
268 |
+
print("Novelty: ", nov, "\n")
|
269 |
+
print("Novelty_test: ", nov_test, "\n")
|
270 |
+
print("Drug_novelty: ", drug_nov, "\n")
|
271 |
+
print("max_len: ", max_len, "\n")
|
272 |
+
print("mean_atom_type: ", mean_atom, "\n")
|
273 |
+
print("snn_chembl: ", snn_chembl, "\n")
|
274 |
+
print("snn_drug: ", snn_drug, "\n")
|
275 |
+
print("IntDiv: ", int_div, "\n")
|
276 |
+
print("QED: ", qed, "\n")
|
277 |
+
print("SA: ", sa, "\n")
|
278 |
+
|
279 |
+
print("Metrics are calculated.")
|
280 |
model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
|
281 |
"uniqueness": [uniq], "novelty": [nov],
|
282 |
"novelty_test": [nov_test], "drug_novelty": [drug_nov],
|
283 |
"max_len": [max_len], "mean_atom_type": [mean_atom],
|
284 |
"snn_chembl": [snn_chembl], "snn_drug": [snn_drug],
|
285 |
"IntDiv": [int_div], "qed": [qed], "sa": [sa]})
|
286 |
+
search_res = pd.concat([search_res, model_res], axis=0)
|
287 |
+
os.remove("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
|
288 |
+
search_res.to_csv("experiments/inference/{}/inference_results.csv".format(self.submodel), index=False)
|
289 |
+
generatedsmiles = pd.DataFrame({"SMILES": gen_smi})
|
290 |
+
generatedsmiles.to_csv("experiments/inference/{}/inference_drugs.csv".format(self.submodel), index=False)
|
291 |
+
|
292 |
+
return model_res
|
293 |
+
|
|
|
|
|
|
|
294 |
if __name__=="__main__":
|
295 |
parser = argparse.ArgumentParser()
|
296 |
|
|
|
323 |
|
324 |
config = parser.parse_args()
|
325 |
inference = Inference(config)
|
326 |
+
inference.inference()
|