MHN-React / app.py
phseidl's picture
Update app.py
ff0e005
raw
history blame
4.72 kB
import os
import pickle

import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from rdkit.Chem import rdChemReactions as Reaction
from rdkit.Chem.Draw import rdMolDraw2D

from mhnreact.inspect import list_models, load_clf
from ssretro_template import ssretro, ssretro_custom
def _split_top_level_commas(text: str) -> list:
    """Split *text* on commas that are not enclosed in square brackets."""
    parts = []
    current = []
    depth = 0
    for ch in text:
        if ch == '[':
            depth += 1
        elif ch == ']' and depth > 0:
            depth -= 1
        if ch == ',' and depth == 0:
            parts.append(''.join(current))
            current = []
        else:
            current.append(ch)
    parts.append(''.join(current))
    return parts


def custom_template_file(template: str):
    """Parse comma-separated reaction templates and pickle them to disk.

    Templates are separated by commas that lie *outside* square brackets.
    Inside ``[...]`` a comma is the SMARTS atom-level OR operator (e.g.
    ``[C,N]``), so splitting on every comma would cut such a template apart.
    Each template is stripped of surrounding whitespace. Empty entries, for
    example from a trailing comma, are dropped.

    The resulting mapping is written with ``pickle`` to
    ``saved_dictionary.pkl`` in the current working directory. Any existing
    file is overwritten.

    Args:
        template: One or more templates separated by commas.

    Returns:
        dict[int, str]: the templates keyed by their 0-based position.
    """
    cleaned = (part.strip() for part in _split_top_level_commas(template))
    template_dict = dict(enumerate(part for part in cleaned if part))
    with open('saved_dictionary.pkl', 'wb') as f:
        pickle.dump(template_dict, f)
    return template_dict
def get_output(p):
    """Render a reaction SMARTS string as image bytes.

    ``p`` is parsed by RDKit as reaction SMARTS (``useSmiles=False``). The
    reaction is drawn on an 800x200 Cairo canvas with reactant highlighting
    disabled.

    Returns:
        The drawing produced by ``MolDraw2DCairo.GetDrawingText()`` (PNG
        image data for the Cairo backend).
    """
    canvas_width, canvas_height = 800, 200
    reaction = Reaction.ReactionFromSmarts(p, useSmiles=False)
    drawer = rdMolDraw2D.MolDraw2DCairo(canvas_width, canvas_height)
    drawer.DrawReaction(reaction, highlightByReactant=False)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()
def ssretro_prediction(molecule, custom_template=False):
    """Run single-step retrosynthesis on ``molecule`` and render each hit.

    The first model reported by ``list_models()`` is loaded with
    ``load_clf`` on every call. ``ssretro_custom`` is used when
    ``custom_template`` is truthy; otherwise ``ssretro`` is used.

    Args:
        molecule: Input molecule passed straight to the predictor.
        custom_template: Whether to use the custom-template predictor.

    Returns:
        tuple[list, list[str]]: rendered reaction images (from
        ``get_output``) and a caption string for each prediction, in the
        same order.
    """
    # NOTE(review): the classifier is reloaded per request. Caching it is
    # tempting, but ssretro_custom may modify the classifier — confirm
    # before sharing one instance across calls.
    clf = load_clf(list_models()[0])
    predictor = ssretro_custom if custom_template else ssretro

    rendered, captions = [], []
    # A single pass: the predictor's result may be a one-shot iterator.
    for hit in predictor(molecule, clf):
        # NOTE(review): rank is shown as template_rank - 1 and prob gets a
        # '%' suffix; confirm the rank base and that prob is already a
        # percentage.
        captions.append(
            f'predicted top-{hit["template_rank"] - 1}, '
            f'template index: {hit["template_idx"]}, '
            f'prob: {hit["prob"]: 2.1f}%;'
        )
        rendered.append(get_output(hit["reaction"]))
    return rendered, captions
def _caption_image(img, caption, font, left=10, right=10, top=50, bottom=1):
    """Return *img* pasted on a white padded canvas with *caption* drawn in the top margin."""
    width, height = img.size
    canvas = Image.new(
        img.mode,
        (width + left + right, height + top + bottom),
        (255, 255, 255),
    )
    canvas.paste(img, (left, top))
    ImageDraw.Draw(canvas).text((20, 20), caption, font=font, fill=(0, 0, 0))
    return canvas


def mhn_react_backend(mol, use_custom: bool):
    """Gradio callback: predict retrosynthesis steps and return captioned images.

    Each prediction from ``ssretro_prediction`` is written to
    ``outputs/NNN.png`` and reopened with PIL. It is then padded with a white
    border, captioned with its comment, and saved back to the same path.

    Args:
        mol: Input molecule (SMILES string from the UI).
        use_custom: Whether to use the custom-template predictor.

    Returns:
        list[PIL.Image.Image]: one captioned image per prediction (empty if
        there are no predictions).
    """
    output_dir = "outputs"
    predictions, comments = ssretro_prediction(mol, use_custom)
    if not predictions:
        return []

    # The output folder is not guaranteed to exist (e.g. a fresh container).
    os.makedirs(output_dir, exist_ok=True)
    # Load the font once rather than once per prediction.
    font = ImageFont.truetype(r'tools/arial.ttf', 20)

    images = []
    for i, (png_bytes, caption) in enumerate(zip(predictions, comments)):
        output_im = f"{output_dir}/{i:03d}.png"
        with open(output_im, "wb") as fh:
            fh.write(png_bytes)
        # Close the file handle PIL opens before overwriting the same path.
        with Image.open(output_im) as img:
            result = _caption_image(img, caption, font)
        images.append(result)
        result.save(output_im)
    return images
# Page header: citation/environment badges plus a short project blurb.
_HEADER_MD = """
[![JCIM](https://img.shields.io/badge/acs.jcim-1c01065-yellow.svg)](https://doi.org/10.1021/acs.jcim.1c01065)
[![arXiv](https://img.shields.io/badge/arXiv-2104.03279-b31b1b.svg)](https://arxiv.org/abs/2104.03279)
[![Python 3.7](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/downloads/release/python-370/)
[![Pytorch](https://img.shields.io/badge/Pytorch-1.6-red.svg)](https://pytorch.org/get-started/previous-versions/)
[![License](https://img.shields.io/badge/License-BSD%202--Clause-orange.svg)](https://opensource.org/licenses/BSD-2-Clause)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb)
### MHN-react
Adapting modern Hopfield networks (Ramsauer et al., 2021) (MHN) to associate different data modalities,
molecules and reaction templates, to improve predictive performance for rare templates and single-step retrosynthesis.
"""

# Usage hints for the collapsible "Guide" section. Adjacent literals are
# concatenated, so each one must end with its own separator.
_GUIDE_MD = (
    "Information (add) <br> "
    "In case the output is empty => No suitable templates? <br> "
    "Use one of the example molecules: <br> CC(=O)NCCC1=CNc2c1cc(OC)cc2"
)

# Instructions for the "Create custom templates" tab.
_CUSTOM_TEMPLATE_MD = """
Input the templates separated by comma. <br> Please do not upload templates one-by-one
"""

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MD)
    with gr.Accordion("Guide"):
        gr.Markdown(_GUIDE_MD)

    with gr.Tab("Generate Templates"):
        with gr.Row():
            with gr.Column(scale=1):
                inp = gr.Textbox(placeholder="Input molecule in SMILES format", label="input molecule")
                # Start on the built-in templates; without a value the radio begins unset (None).
                radio = gr.Radio([False, True], value=False, label="use custom templates")
                btn = gr.Button(value="Generate")
            with gr.Column(scale=2):
                out = gr.Gallery(label="retro-synthesis")
        btn.click(mhn_react_backend, [inp, radio], out)

    with gr.Tab("Create custom templates"):
        gr.Markdown(_CUSTOM_TEMPLATE_MD)
        with gr.Column():
            inp_t = gr.Textbox(placeholder="custom template", label="add custom template(s)")
            btn_t = gr.Button(value="upload")
            out_t = gr.Textbox(label="added templates")
        btn_t.click(custom_template_file, inp_t, out_t)

# Launch only when run as a script, so importing the module has no server side effect.
if __name__ == "__main__":
    demo.launch()