import gradio as gr
import os
import glob
import cv2
import numpy as np
import torch
from rxnscribe import RxnScribe
from huggingface_hub import hf_hub_download

# Gradio demo for RxnScribe: upload a reaction diagram and get back an
# annotated image plus a table of the reactions extracted from it.
# NOTE(review): os, cv2 and numpy are not referenced in this file; confirm
# they are not needed (e.g. for import side effects) before removing them.

REPO_ID = "yujieq/RxnScribe"
FILENAME = "pix2seq_reaction_full.ckpt"
# Fetch the model checkpoint from the Hugging Face Hub; runs at import time.
ckpt_path = hf_hub_download(REPO_ID, FILENAME)

device = torch.device('cpu')
model = RxnScribe(ckpt_path, device)

# Reaction roles, in the order they appear as table columns after '#'.
ROLES = ('reactants', 'conditions', 'products')

# Kept at column 0: indented lines inside a Markdown string would render as
# a code block.
DESCRIPTION = """
# RxnScribe

Extract chemical reactions from a diagram. Please upload a reaction diagram, and RxnScribe will predict the reaction structures in it. The predicted reactions are visualized in separate images. **Red** boxes are *reactants*. **Green** boxes are *reaction conditions*. **Blue** boxes are *products*.

It usually takes 10-20 seconds to process a diagram with this demo. Check the options to also run [MolScribe](https://huggingface.co/spaces/yujieq/MolScribe) and [OCR](https://huggingface.co/spaces/tomofi/EasyOCR) (processing will take longer, of course).

Code: https://github.com/thomas0809/RxnScribe

Authors: [Yujie Qian](mailto:yujieq@csail.mit.edu), Jiang Guo, Zhengkai Tu, Connor W. Coley, Regina Barzilay. _MIT CSAIL_.
"""


def _entity_label(ent):
    """Return the display text for one predicted entity.

    Prefers the entity's 'smiles' value, then its 'text' value, and falls back
    to its 'category' when neither key is present.
    """
    if 'smiles' in ent:
        return ent['smiles']
    if 'text' in ent:
        # presumably a list of recognized text strings; joined with spaces
        return ' '.join(ent['text'])
    return ent['category']


def get_markdown(reaction, sep='\n'):
    """Format one predicted reaction as three table cells.

    Args:
        reaction: mapping with 'reactants', 'conditions' and 'products' keys,
            each a list of entity dicts carrying 'smiles', 'text' or
            'category'.
        sep: string placed between entities within the same cell.
            NOTE(review): the cells are rendered as Markdown, where a bare
            newline may display as a space; '<br>' may be intended — confirm.

    Returns:
        A list of three strings (reactants, conditions, products).
    """
    # join() separates every entity the same way, including category-only
    # ones, so adjacent entries never run together.
    return [sep.join(_entity_label(ent) for ent in reaction[role]) for role in ROLES]


def predict(image, molscribe=False, ocr=False):
    """Run RxnScribe on an uploaded diagram.

    Args:
        image: PIL image from the upload component, or None if nothing was
            uploaded.
        molscribe: forwarded to ``model.predict_image`` to also run MolScribe.
        ocr: forwarded to ``model.predict_image`` to also run OCR.
        (Both default to False so ``gr.Examples``, which only supplies the
        image, can call this function.)

    Returns:
        A pair ``(annotated_image, rows)`` where each row is
        ``[index, reactants, conditions, products]``.
    """
    if image is None:
        # Nothing uploaded: clear both outputs instead of calling the model.
        return None, []
    predictions = model.predict_image(image, molscribe=molscribe, ocr=ocr)
    pred_image = model.draw_predictions_combined(predictions, image=image)
    rows = [[i] + get_markdown(reaction) for i, reaction in enumerate(predictions)]
    return pred_image, rows


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    # NOTE(review): `.style()` is Gradio 3.x API and was removed in Gradio 4;
    # pin gradio<4 or move these options into the component constructors.
    with gr.Column():
        with gr.Row():
            image = gr.Image(label="Upload reaction diagram", show_label=False, type='pil').style(height=256)
        with gr.Row():
            molscribe = gr.Checkbox(label="Run MolScribe to recognize molecule structures")
            ocr = gr.Checkbox(label="Run OCR to recognize text")
            btn = gr.Button("Submit").style(full_width=False)
        with gr.Row():
            gallery = gr.Image(label='Predicted reactions', show_label=True).style(height='auto')
        markdown = gr.Dataframe(
            headers=['#', 'reactant', 'condition', 'product'],
            datatype=['number'] + ['markdown'] * 3,
            wrap=False,
        )
    btn.click(predict, inputs=[image, molscribe, ocr], outputs=[gallery, markdown])
    # Examples only fill the image input; predict's defaults cover the
    # checkboxes if Gradio runs fn on an example.
    gr.Examples(
        examples=sorted(glob.glob('examples/*.png')),
        inputs=[image],
        outputs=[gallery, markdown],
        fn=predict,
    )

if __name__ == "__main__":
    demo.launch()