RxnScribe / app.py
thomas0809
add markdown and examples
c157f5a
raw
history blame
3.03 kB
import gradio as gr
import os
import glob
import cv2
import numpy as np
import torch
from rxnscribe import RxnScribe
from huggingface_hub import hf_hub_download
# Hugging Face Hub coordinates of the pretrained RxnScribe checkpoint.
REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"

# Fetch the checkpoint to a local path and build the model once, at import
# time, so every request handled by the UI reuses the same instance.
# Inference runs on the CPU.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
device = torch.device('cpu')
model = RxnScribe(ckpt_path, device)
def get_markdown(reaction):
    """Render one predicted reaction as three markdown table cells.

    Args:
        reaction: mapping with the keys ``'reactants'``, ``'conditions'`` and
            ``'products'``, each holding a list of entity dicts. An entity is
            rendered from its ``'smiles'`` string if present, otherwise from
            its ``'text'`` entry joined with spaces (presumably a list of OCR
            tokens -- confirm against RxnScribe's output format), otherwise
            from its ``'category'`` label.

    Returns:
        ``[reactants_cell, conditions_cell, products_cell]``. Every rendered
        entity is terminated by ``'<br>'`` so each one sits on its own line
        inside the markdown cell; an empty role yields ``''``.
    """
    output = []
    for role in ('reactants', 'conditions', 'products'):
        lines = []
        for ent in reaction[role]:
            if 'smiles' in ent:
                lines.append(ent['smiles'])
            elif 'text' in ent:
                lines.append(' '.join(ent['text']))
            else:
                # Category-only entities get the same '<br>' terminator as the
                # other branches so they don't run into the next entity.
                lines.append(ent['category'])
        output.append(''.join(line + '<br>' for line in lines))
    return output
def predict(image, molscribe, ocr):
    """Run RxnScribe on a single diagram for the Gradio callbacks.

    Args:
        image: the uploaded diagram (presumably a PIL image, as delivered by
            a ``type='pil'`` image component -- confirm against the caller).
        molscribe: whether to also run MolScribe on detected molecules.
        ocr: whether to also run OCR on detected text regions.

    Returns:
        ``(visualization, rows)``: the model's combined drawing of all
        predictions on ``image``, and one table row per predicted reaction of
        the form ``[index, reactants_md, conditions_md, products_md]``.
    """
    found = model.predict_image(image, molscribe=molscribe, ocr=ocr)
    visualization = model.draw_predictions_combined(found, image=image)
    rows = []
    for idx, rxn in enumerate(found):
        rows.append([idx, *get_markdown(rxn)])
    return visualization, rows
# Gradio UI: an upload/options/submit area, the annotated output image, a
# per-reaction results table, and clickable example diagrams.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from an indentation-stripped copy of this file -- confirm against upstream.
with gr.Blocks() as demo:
    # NOTE(review): the mailto address in the Authors line appears to have been
    # replaced by an email-obfuscation placeholder ("[email protected]");
    # restore the real address.
    gr.Markdown("""
<center> <h1>RxnScribe</h1> </center>
Extract chemical reactions from a diagram. Please upload a reaction diagram, RxnScribe will predict the reaction structures in the diagram.
The predicted reactions are visualized in separate images.
<span style="color:red">**Red** boxes are <ins>*reactants*</ins>.</span>
<span style="color:green">**Green** boxes are <ins>*reaction conditions*</ins>.</span>
<span style="color:blue">**Blue** boxes are <ins>*products*</ins>.</span>
It usually takes 10-20 seconds to process a diagram with this demo.
Check the options to run [MolScribe](https://huggingface.co/spaces/yujieq/MolScribe) and [OCR](https://huggingface.co/spaces/tomofi/EasyOCR) (it will take a longer time, of course).
Code: https://github.com/thomas0809/RxnScribe
Authors: [Yujie Qian](mailto:[email protected]), Jiang Guo, Zhengkai Tu, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.
""")
    with gr.Column():
        with gr.Row():
            # type='pil' so predict() receives a PIL image.
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        with gr.Row():
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
            btn = gr.Button("Submit").style(full_width=False)
        with gr.Row():
            gallery = gr.Image(label='Predicted reactions', show_label=True).style(height='auto')
        # One row per reaction; the three role columns hold markdown produced
        # by get_markdown().
        markdown = gr.Dataframe(
            headers=['#', 'reactant', 'condition', 'product'],
            datatype=['number'] + ['markdown'] * 3,
            wrap=False
        )
    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])
    gr.Examples(
        examples=sorted(glob.glob('examples/*.png')),
        inputs=[image],
        outputs=[gallery, markdown],
        # Examples only supply the image, but predict() also requires the
        # molscribe/ocr flags; run examples with both off (the checkbox
        # defaults) rather than calling predict(image) with missing arguments.
        fn=lambda img: predict(img, False, False),
    )

if __name__ == "__main__":
    demo.launch()