|
import functools |
|
import html |
|
import logging |
|
import textwrap |
|
import traceback |
|
from pathlib import Path |
|
from typing import List |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
from rxn.utilities.logging import setup_console_logger |
|
from rxn.utilities.strings import remove_postfix |
|
|
|
from utils import TranslatorWithSentencePiece, download_cde_data, split_into_sentences |
|
|
|
# Module-level logger. The NullHandler keeps the logging module from falling
# back to its "last resort" handler when the application configured nothing.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Human-readable model labels; they double as the keys of
# `model_type_to_models` below.
_SAC_MODEL_TAG = "Heterogenous SAC model (ACE)"
_ORGANIC_MODEL_TAG = "Organic chemistry model"
_PRETRAINED_MODEL_TAG = "Pretrained model"

# For every model label, the model file name(s) handed to the translator.
# NOTE(review): presumably translation checkpoints located next to the app;
# confirm against `utils.TranslatorWithSentencePiece`.
model_type_to_models = {
    _SAC_MODEL_TAG: ["sac.pt"],
    _ORGANIC_MODEL_TAG: ["organic-1.pt", "organic-2.pt", "organic-3.pt"],
    _PRETRAINED_MODEL_TAG: ["pretrained.pt"],
}

# Hint displayed in the (empty) synthesis text box.
SYNTHESIS_TEXT_PLACEHOLDER = (
    "Enter the synthesis procedure here, or click on one of the examples below."
)
|
|
|
@functools.lru_cache
def load_model(model_type: str) -> TranslatorWithSentencePiece:
    """Create the translator for ``model_type``, memoized per model type.

    Because of ``functools.lru_cache``, calling this again with the same
    ``model_type`` returns the instance that was already constructed.

    Args:
        model_type: one of the keys of ``model_type_to_models``.

    Returns:
        A ``TranslatorWithSentencePiece`` built from the model file name(s)
        registered for ``model_type`` and the ``sp_model.model``
        SentencePiece model.

    Raises:
        KeyError: if ``model_type`` is not a key of ``model_type_to_models``.
    """
    logger.info(f"Loading model {model_type}... ")

    translator = TranslatorWithSentencePiece(
        translation_model=model_type_to_models[model_type],
        sentencepiece_model="sp_model.model",
    )

    logger.info(f"Loading model {model_type}... Done.")
    return translator
|
|
|
|
|
def sentence_and_actions_to_html(
    sentence: str, action_string: str, show_sentences: bool
) -> str:
    """Render one sentence and its extracted actions as an HTML fragment.

    Args:
        sentence: the sentence the actions were extracted from. Only rendered
            when ``show_sentences`` is true.
        action_string: the actions, separated by ``"; "``; one trailing
            ``"."`` is removed before splitting.
        show_sentences: whether to show the sentence above its actions.

    Returns:
        If ``show_sentences`` is true, a self-contained
        ``<ol><p>sentence</p><li>...</li>...</ol>`` fragment. Otherwise only
        the ``<li>`` items, without the surrounding ``<ol>``: the caller is
        expected to wrap the items of all sentences in one single list.
    """
    li_start = '<li style="margin-left: 12px;">' if show_sentences else "<li>"
    li_end = "</li>"

    action_string = remove_postfix(action_string, ".")

    # The result is interpreted as HTML, while the sentence is free text typed
    # by the user and the actions are generated from it: escape both so that
    # characters such as "<" or "&" are displayed literally instead of being
    # parsed as markup. Split first, escape second: escaping can introduce
    # "; " (e.g. "&amp; "), which would otherwise corrupt the split.
    items = "".join(
        f"{li_start}{html.escape(action)}{li_end}"
        for action in action_string.split("; ")
    )

    if not show_sentences:
        return items

    # With the sentences shown, each sentence gets its own list, so the list
    # delimiters must be added here.
    return f"<ol><p>{html.escape(sentence)}</p>{items}</ol>"
|
|
|
|
|
def try_action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Extract the actions from a paragraph and return them as HTML.

    The paragraph is split into sentences, the sentences are translated into
    action strings by the model selected with ``model_type``, and each
    sentence/action-string pair is rendered with
    ``sentence_and_actions_to_html``.

    Args:
        model_type: key identifying the model, forwarded to ``load_model``.
        text: the paragraph to process.
        show_sentences: whether each sentence is shown above its actions.

    Returns:
        The HTML for the whole paragraph. Exceptions raised by any of the
        steps are not handled here and propagate to the caller.
    """
    # The same shortened preview is used in the first and the last log line.
    preview = textwrap.shorten(text, 60)
    logger.info(f'Extracting actions from paragraph "{preview}"...')

    # NOTE(review): presumably makes sure the data needed by the sentence
    # splitter is available locally; confirm in `utils`.
    download_cde_data()

    translator = load_model(model_type)

    logger.info("Splitting paragraph into sentences...")
    sentences = split_into_sentences(text)
    logger.info("Splitting paragraph into sentences... Done.")

    logger.info("Translation with OpenNMT...")
    action_strings = translator.translate(sentences)
    logger.info("Translation with OpenNMT... Done.")

    html_output = "".join(
        sentence_and_actions_to_html(sentence, actions, show_sentences)
        for sentence, actions in zip(sentences, action_strings)
    )

    if not show_sentences:
        # Without the sentences, the per-sentence fragments are bare list
        # items: wrap all of them in one single list.
        html_output = f"<ol>{html_output}</ol>"

    # Displayed name differs from the name produced by the model.
    # NOTE(review): this is applied to the whole fragment, i.e. also to the
    # echoed sentences when `show_sentences` is set.
    html_output = html_output.replace("POSTTREATMENT", "THERMALTREATMENT")

    logger.info(f'Extracting actions from paragraph "{preview}"... Done.')
    return html_output
|
|
|
|
|
def action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Extract actions from a paragraph, reporting failures as HTML.

    Thin wrapper around ``try_action_extraction`` that never raises for
    ordinary errors: it is the boundary between the extraction code and the
    user interface.

    Args:
        model_type: key identifying the model to use.
        text: the paragraph to process.
        show_sentences: whether each sentence is shown above its actions.

    Returns:
        The HTML produced by ``try_action_extraction`` or, if it raised, an
        HTML error message followed by the formatted traceback.
    """
    try:
        return try_action_extraction(model_type, text, show_sentences)
    except Exception as e:
        # Keep a server-side record; the error is otherwise only visible to
        # the user who triggered it.
        logger.exception("Action extraction failed")

        tb = "".join(traceback.TracebackException.from_exception(e).format())
        tb_html = f"<pre>{html.escape(tb)}</pre>"

        # The exception message can contain arbitrary text (including parts of
        # the user input or reprs such as "<class ...>"): escape it like the
        # traceback so that it is shown literally rather than parsed as HTML.
        message = html.escape(str(e))
        return f"<p><b>Error!</b> The action extraction failed: {message}</p>{tb_html}"
|
|
|
|
|
def launch() -> gr.Interface:
    """Build the Gradio interface for the action extraction and launch it.

    The examples, the article and the description shown in the interface are
    read from the ``model_cards`` directory located next to this file.

    Returns:
        The ``gr.Interface`` instance, after ``launch()`` was called on it.
    """
    logger.info("Launching the Gradio app")

    metadata_dir = Path(__file__).parent / "model_cards"

    # One example per CSV row (no header row); empty cells become "".
    examples_df: pd.DataFrame = pd.read_csv(
        metadata_dir / "sac_synthesis_mining_examples.csv", header=None
    ).fillna("")
    examples: List[List[str]] = examples_df.to_numpy().tolist()

    # Read the markdown files explicitly as UTF-8: without an encoding, the
    # platform's locale encoding is used, which can fail or garble non-ASCII
    # characters on systems where it is not UTF-8.
    article = (metadata_dir / "sac_synthesis_mining_article.md").read_text(
        encoding="utf-8"
    )
    description = (metadata_dir / "sac_synthesis_mining_description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=action_extraction,
        title="Extraction of synthesis protocols from paragraphs in machine readable format",
        # The inputs are passed positionally to `action_extraction`:
        # model type, text, and whether to show the sentences.
        inputs=[
            gr.Dropdown(
                [_SAC_MODEL_TAG, _ORGANIC_MODEL_TAG, _PRETRAINED_MODEL_TAG],
                label="Model",
                value=_SAC_MODEL_TAG,
            ),
            gr.Textbox(
                label="Synthesis text", lines=7, placeholder=SYNTHESIS_TEXT_PLACEHOLDER
            ),
            gr.Checkbox(label="Show sentences in the output"),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
        allow_flagging="never",
        theme="gradio/base",
    )
    demo.launch(debug=True, show_error=True)
    return demo
|
|
|
|
|
# Configure console logging before anything is logged by the app.
setup_console_logger(level="INFO")

# Deliberately executed at module level (no `__main__` guard): importing or
# running this module builds and launches the interface, and exposes it under
# the module-level name `demo`.
demo = launch()
|
|