# Spaces: Runtime error
import io
import os
import pickle

import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from rdkit.Chem import rdChemReactions as Reaction
from rdkit.Chem.Draw import rdMolDraw2D

from mhnreact.inspect import list_models, load_clf
from ssretro_template import ssretro, ssretro_custom
def custom_template_file(template: str, path: str = "saved_dictionary.pkl") -> dict:
    """Parse comma-separated reaction templates and pickle them to disk.

    Templates are numbered in input order and written to *path*, replacing any
    previously saved set. All templates must therefore be submitted in a
    single call rather than one by one.

    Args:
        template: Templates separated by commas. Surrounding whitespace is
            stripped. Empty entries (e.g. from a trailing comma or blank
            input) are skipped.
        path: Destination pickle file. Defaults to ``saved_dictionary.pkl``
            in the current working directory.

    Returns:
        Mapping of consecutive integer indices (starting at 0) to template
        strings.
    """
    # NOTE(review): splitting on every comma breaks SMARTS that use ',' as an
    # OR operator (e.g. "[C,N]"); confirm the expected template syntax.
    entries = (part.strip() for part in template.split(","))
    template_dict = dict(enumerate(entry for entry in entries if entry))
    # Mode "wb" truncates the file, so each call overwrites the previous set.
    with open(path, "wb") as f:
        pickle.dump(template_dict, f)
    return template_dict
def get_output(p):
    """Render a reaction SMARTS string as an 800x200 image.

    Args:
        p: Reaction SMARTS, parsed with ``useSmiles=False``.

    Returns:
        The raw drawing produced by RDKit's Cairo drawer. The caller writes
        these bytes to a ``.png`` file.
    """
    reaction = Reaction.ReactionFromSmarts(p, useSmiles=False)
    canvas = rdMolDraw2D.MolDraw2DCairo(800, 200)
    canvas.DrawReaction(reaction, highlightByReactant=False)
    canvas.FinishDrawing()
    return canvas.GetDrawingText()
def ssretro_prediction(molecule, custom_template=False):
    """Run single-step retrosynthesis on *molecule* with the first listed model.

    Args:
        molecule: Input molecule (the UI asks for SMILES).
        custom_template: If truthy, predict with ``ssretro_custom``.
            Otherwise use ``ssretro``.

    Returns:
        Tuple ``(images, captions)`` with one entry per prediction:
        - images: the rendered reaction drawing from ``get_output``.
        - captions: a text line with rank, template index and probability.
    """
    # The classifier is (re)loaded on every call.
    retro_clf = load_clf(list_models()[0])
    run = ssretro_custom if custom_template else ssretro

    images, captions = [], []
    for prediction in run(molecule, retro_clf):
        rank = prediction["template_rank"] - 1
        captions.append(
            f'predicted top-{rank}, template index: {prediction["template_idx"]}, '
            f'prob: {prediction["prob"]: 2.1f}%;'
        )
        images.append(get_output(prediction["reaction"]))
    return images, captions
def mhn_react_backend(mol, use_custom: bool):
    """Predict retrosynthesis reactions for *mol* and return captioned images.

    For each prediction, the rendered reaction image is processed as follows:
    - padded with a white border;
    - captioned with its rank/template/probability text;
    - saved as ``outputs/NNN.png``;
    - returned in the result list.

    Args:
        mol: Input molecule in SMILES format.
        use_custom: Whether to use the uploaded custom templates instead of
            the built-in ones.

    Returns:
        List of PIL images, one per prediction. The list is empty when no
        applicable template was found.
    """
    output_dir = "outputs"
    # Padding in pixels around each reaction image. The tall top margin
    # leaves room for the caption drawn at (20, 20).
    left, right, top, bottom = 10, 10, 50, 1

    predictions, comments = ssretro_prediction(mol, use_custom)
    if not predictions:
        return []

    # Writing into a missing directory would raise FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)
    # Load the font once rather than once per prediction.
    font = ImageFont.truetype(r"tools/arial.ttf", 20)

    images = []
    for i, (png_bytes, comment) in enumerate(zip(predictions, comments)):
        output_im = os.path.join(output_dir, f"{i:03d}.png")
        # Decode the rendered PNG in memory; only the annotated result is
        # written to disk.
        img = Image.open(io.BytesIO(png_bytes))
        width, height = img.size
        result = Image.new(
            img.mode,
            (width + left + right, height + top + bottom),
            (255, 255, 255),
        )
        result.paste(img, (left, top))
        ImageDraw.Draw(result).text((20, 20), comment, font=font, fill=(0, 0, 0))
        images.append(result)
        result.save(output_im)
    return images
# Gradio UI with two tabs:
# - "Generate Templates": predict reactions for an input molecule.
# - "Create custom templates": upload custom reaction templates.
with gr.Blocks() as demo:
    gr.Markdown(
        """
[![Github](https://img.shields.io/badge/github-%20mhn--react-blue)](https://github.com/ml-jku/mhn-react)
[![arXiv](https://img.shields.io/badge/acs.jcim-1c01065-yellow.svg)](https://doi.org/10.1021/acs.jcim.1c01065)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-jku/mhn-react/blob/main/notebooks/colab_MHNreact_demo.ipynb)
### MHN-react
Adapting modern Hopfield networks (Ramsauer et al., 2021) (MHN) to associate different data modalities,
molecules and reaction templates, to improve predictive performance for rare templates and single-step retrosynthesis.
"""
    )
    with gr.Accordion("Information"):
        gr.Markdown(
            "use one of example molecules <br> CC(=O)NCCC1=CNc2c1cc(OC)cc2, <br> CN1CCC[C@H]1c2cccnc2, <br> OCCc1c(C)[n+](cs1)Cc2cnc(C)nc2N"
            # Adjacent literals are joined with no separator, so without this
            # " <br> " the last SMILES ran straight into the next sentence.
            " <br> In case the output is empty, no applicable templates were found"
        )
    with gr.Tab("Generate Templates"):
        with gr.Row():
            with gr.Column(scale=1):
                inp = gr.Textbox(placeholder="Input molecule in SMILES format", label="input molecule")
                # Start with the built-in templates selected so the backend
                # always receives a definite choice.
                radio = gr.Radio([False, True], value=False, label="use custom templates")
                btn = gr.Button(value="Generate")
            with gr.Column(scale=2):
                out = gr.Gallery(label="retro-synthesis")
        btn.click(mhn_react_backend, [inp, radio], out)
    with gr.Tab("Create custom templates"):
        gr.Markdown(
            """
Input the templates separated by comma. <br> Please do not upload templates one-by-one
"""
        )
        with gr.Column():
            inp_t = gr.Textbox(placeholder="custom template", label="add custom template(s)")
            btn_t = gr.Button(value="upload")
            out_t = gr.Textbox(label="added templates")
        btn_t.click(custom_template_file, inp_t, out_t)
demo.launch(debug=True)