"""Gradio app extracting synthesis actions from free-text procedure paragraphs.

A paragraph is split into sentences, each sentence is translated into a string
of actions by an OpenNMT model wrapped in ``TranslatorWithSentencePiece``, and
the result is rendered as HTML.
"""

import functools
import html
import logging
import textwrap
import traceback
from pathlib import Path
from typing import List

import gradio as gr
import pandas as pd
from rxn.utilities.logging import setup_console_logger
from rxn.utilities.strings import remove_postfix

from utils import TranslatorWithSentencePiece, download_cde_data, split_into_sentences

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

_SAC_MODEL_TAG = "Heterogenous SAC model (ACE)"
_ORGANIC_MODEL_TAG = "Organic chemistry model"

# Checkpoint file(s) handed to the translator for each model choice.
model_type_to_models = {
    _SAC_MODEL_TAG: ["sac.pt"],
    _ORGANIC_MODEL_TAG: ["organic-1.pt", "organic-2.pt", "organic-3.pt"],
}

SYNTHESIS_TEXT_PLACEHOLDER = (
    "Enter the synthesis procedure here, or click on one of the examples below."
)

# NOTE(review): the HTML markup below was reconstructed; the tags had been
# stripped from the string literals in the reviewed copy of this file. The tag
# choices (<li>/<ul>/<p>/<pre>) and the indentation style applied to list items
# shown under a sentence are assumptions -- confirm against the deployed app.
_LI_START_WITH_SENTENCE = '<li style="margin-left: 12px;">'
_LI_START_PLAIN = "<li>"
_LI_END = "</li>"
_LIST_TEMPLATE = "<ul>{}</ul>"


@functools.lru_cache
def load_model(model_type: str) -> TranslatorWithSentencePiece:
    """Build (once per model type) the translator for ``model_type``.

    Args:
        model_type: a key of ``model_type_to_models``.

    Returns:
        The translator built from the checkpoint file(s) registered for
        ``model_type`` and the ``sp_model.model`` SentencePiece model.

    Raises:
        KeyError: if ``model_type`` is not in ``model_type_to_models``.
    """
    logger.info("Loading model %s... ", model_type)

    model_files = model_type_to_models[model_type]
    sp_model = "sp_model.model"

    model = TranslatorWithSentencePiece(
        translation_model=model_files,
        sentencepiece_model=sp_model,
    )
    logger.info("Loading model %s... Done.", model_type)

    return model


def sentence_and_actions_to_html(
    sentence: str, action_string: str, show_sentences: bool
) -> str:
    """Render one sentence and its extracted actions as an HTML fragment.

    Args:
        sentence: the original sentence (user-provided text).
        action_string: actions for that sentence, separated by ``"; "`` and
            optionally terminated by a ``"."``.
        show_sentences: if True, the sentence is shown followed by its own
            complete list of actions; if False, only the bare list items are
            returned and the caller is expected to wrap them in a list.

    Returns:
        The HTML fragment. Sentence and action text are HTML-escaped.
    """
    li_start = _LI_START_WITH_SENTENCE if show_sentences else _LI_START_PLAIN

    action_string = remove_postfix(action_string, ".")

    # Both the sentence and the actions derive from user input, so they are
    # escaped before being embedded in markup.
    items = "".join(
        f"{li_start}{html.escape(action)}{_LI_END}"
        for action in action_string.split("; ")
    )

    if not show_sentences:
        return items

    # If we show the sentence, we need the list/enumeration delimiters,
    # as there is one list per sentence.
    return _LIST_TEMPLATE.format(f"<p>{html.escape(sentence)}</p>{items}")


def try_action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Extract the actions from ``text`` and return them as HTML.

    Args:
        model_type: a key of ``model_type_to_models``.
        text: the synthesis paragraph.
        show_sentences: whether each sentence is displayed above its actions.

    Returns:
        HTML with the extracted actions. Exceptions raised by the helpers
        propagate to the caller.
    """
    short_text = textwrap.shorten(text, 60)
    logger.info('Extracting actions from paragraph "%s"...', short_text)

    download_cde_data()
    model = load_model(model_type)

    logger.info("Splitting paragraph into sentences...")
    sentences = split_into_sentences(text)
    logger.info("Splitting paragraph into sentences... Done.")

    logger.info("Translation with OpenNMT...")
    action_strings = model.translate(sentences)
    logger.info("Translation with OpenNMT... Done.")

    # PostTreatment was renamed to ThermalTreatment, old model still relies on
    # the former. The rename is applied to the model output only, so that the
    # user's own sentences are never rewritten.
    output = "".join(
        sentence_and_actions_to_html(
            sentence,
            action_string.replace("POSTTREATMENT", "THERMALTREATMENT"),
            show_sentences,
        )
        for sentence, action_string in zip(sentences, action_strings)
    )

    if not show_sentences:
        # If the sentences were not shown, we need to add the list/enumeration
        # delimiters here (globally)
        output = _LIST_TEMPLATE.format(output)

    logger.info('Extracting actions from paragraph "%s"... Done.', short_text)

    return output


def action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Gradio callback: like ``try_action_extraction``, but never raises.

    Any exception is logged and converted into an HTML error message that
    includes the (escaped) traceback.
    """
    try:
        return try_action_extraction(model_type, text, show_sentences)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        logger.exception("The action extraction failed")
        tb = "".join(traceback.TracebackException.from_exception(e).format())
        tb_html = f"<pre>{html.escape(tb)}</pre>"
        # The exception message may echo user input, hence the escaping.
        message = html.escape(str(e))
        return f"<p>Error! The action extraction failed: {message}</p>{tb_html}"


def launch() -> gr.Interface:
    """Build the Gradio interface, launch it, and return it.

    The examples, article and description are read from the ``model_cards``
    directory located next to this file.
    """
    logger.info("Launching the Gradio app")

    metadata_dir = Path(__file__).parent / "model_cards"

    examples_df: pd.DataFrame = pd.read_csv(
        metadata_dir / "sac_synthesis_mining_examples.csv", header=None
    ).fillna("")
    examples: List[List[str]] = examples_df.to_numpy().tolist()

    # Explicit encoding: do not depend on the locale of the host.
    article = (metadata_dir / "sac_synthesis_mining_article.md").read_text(
        encoding="utf-8"
    )
    description = (metadata_dir / "sac_synthesis_mining_description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=action_extraction,
        title="Extraction of synthesis protocols from paragraphs in machine readable format",
        inputs=[
            gr.Dropdown(
                [_SAC_MODEL_TAG, _ORGANIC_MODEL_TAG],
                label="Model",
                value=_SAC_MODEL_TAG,
            ),
            gr.Textbox(
                label="Synthesis text",
                lines=7,
                placeholder=SYNTHESIS_TEXT_PLACEHOLDER,
            ),
            gr.Checkbox(label="Show sentences in the output"),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
        allow_flagging="never",
        theme="gradio/base",
    )
    demo.launch(debug=True, show_error=True)
    return demo


setup_console_logger(level="INFO")
demo = launch()