Alain Vaucher
Explicitly download the CDE data; add logs
80ffb8e
raw
history blame
4.34 kB
import functools
import html
import logging
import textwrap
import traceback
from pathlib import Path
from typing import List
import gradio as gr
import pandas as pd
from rxn.utilities.logging import setup_console_logger
from rxn.utilities.strings import remove_postfix
from utils import TranslatorWithSentencePiece, download_cde_data, split_into_sentences
# Module-level logger. Only a NullHandler is attached here; console output is
# configured separately at the bottom of the file via setup_console_logger.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Names of the selectable model types. They are used both as the choices of
# the "Model" dropdown and as the keys of the mapping below.
_SAC_MODEL_TAG = "Heterogenous SAC model (ACE)"
_ORGANIC_MODEL_TAG = "Organic chemistry model"

# For each model type, the list of model file names that load_model passes to
# the translator as its `translation_model` argument.
model_type_to_models = {
    _SAC_MODEL_TAG: ["sac.pt"],
    _ORGANIC_MODEL_TAG: ["organic-1.pt", "organic-2.pt", "organic-3.pt"],
}
@functools.lru_cache
def load_model(model_type: str) -> TranslatorWithSentencePiece:
    """Instantiate the translator for the given model type.

    The result is memoized with ``functools.lru_cache``, so repeated calls
    with the same ``model_type`` reuse the translator built the first time.

    Args:
        model_type: key of ``model_type_to_models`` selecting the model files.

    Returns:
        The translator built from the model files of that type and the
        SentencePiece model ``sp_model.model``.

    Raises:
        KeyError: if ``model_type`` is not a key of ``model_type_to_models``.
    """
    logger.info(f"Loading model {model_type}... ")
    translator = TranslatorWithSentencePiece(
        translation_model=model_type_to_models[model_type],
        sentencepiece_model="sp_model.model",
    )
    logger.info(f"Loading model {model_type}... Done.")
    return translator
def sentence_and_actions_to_html(
    sentence: str, action_string: str, show_sentences: bool
) -> str:
    """Render one sentence and its extracted actions as an HTML fragment.

    Args:
        sentence: the source sentence (part of the text typed by the user).
        action_string: the actions for that sentence, separated by ``"; "``
            and optionally terminated by a single ``"."``.
        show_sentences: whether to display the sentence above its actions.

    Returns:
        If ``show_sentences`` is true, a self-contained ``<ol>`` holding the
        sentence in a ``<p>`` followed by one indented ``<li>`` per action
        (one list per sentence). Otherwise only the bare ``<li>`` items; the
        caller is then responsible for wrapping them in a list element.
    """
    # Split BEFORE escaping: escaped entities such as "&lt;" contain a ";"
    # and could otherwise be mistaken for the "; " action separator.
    actions = remove_postfix(action_string, ".").split("; ")

    li_start = '<li style="margin-left: 12px;">' if show_sentences else "<li>"
    li_end = "</li>"

    # The sentence is user input and the actions are model output derived
    # from it; neither may be interpreted as markup, so both are escaped.
    items = "".join(
        f"{li_start}{html.escape(action)}{li_end}" for action in actions
    )

    if not show_sentences:
        return items

    # If we show the sentence, we need the list/enumeration delimiters,
    # as there is one list per sentence.
    return f"<ol><p>{html.escape(sentence)}</p>{items}</ol>"
def try_action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Extract the actions from a paragraph and render them as HTML.

    Args:
        model_type: model type to use (a key of ``model_type_to_models``).
        text: the paragraph to process.
        show_sentences: whether each sentence is displayed with its actions
            (one list per sentence) or all actions form one global list.

    Returns:
        The HTML for the extracted actions.

    Raises:
        Any exception raised while downloading data, loading the model,
        splitting the text or translating is propagated to the caller.
    """
    logger.info(f'Extracting actions from paragraph "{textwrap.shorten(text, 60)}".')

    # NOTE(review): presumably the CDE data is what split_into_sentences
    # relies on - confirm in utils.
    download_cde_data()
    model = load_model(model_type)

    sentences = split_into_sentences(text)

    # PostTreatment was renamed to ThermalTreatment, old model still relies on
    # the former. The rename is applied to the model output only, so that the
    # user's own sentences are displayed unmodified.
    action_strings = [
        action_string.replace("POSTTREATMENT", "THERMALTREATMENT")
        for action_string in model.translate(sentences)
    ]

    output = "".join(
        sentence_and_actions_to_html(sentence, action_string, show_sentences)
        for sentence, action_string in zip(sentences, action_strings)
    )

    if not show_sentences:
        # If the sentences were not shown, we need to add the list/enumeration
        # delimiters here (globally)
        output = f"<ol>{output}</ol>"

    return output
def action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Extract actions from a paragraph, reporting any failure as HTML.

    This is the function wired into the Gradio interface: it never raises
    for an ``Exception``; instead the error is rendered in the output.

    Args:
        model_type: model type to use.
        text: the paragraph to process.
        show_sentences: whether to display the sentences with their actions.

    Returns:
        The HTML produced by ``try_action_extraction``, or, if it raised, an
        HTML error message followed by the formatted traceback.
    """
    try:
        return try_action_extraction(model_type, text, show_sentences)
    except Exception as e:
        tb = "".join(traceback.TracebackException.from_exception(e).format())
        tb_html = f"<pre>{html.escape(tb)}</pre>"
        # The exception message may contain user text or characters such as
        # "<" and "&"; escape it like the traceback so it renders verbatim.
        message = html.escape(str(e))
        return f"<p><b>Error!</b> The action extraction failed: {message}</p>{tb_html}"
def launch() -> gr.Interface:
    """Build the Gradio interface for the action extraction and launch it.

    The examples, article and description are read from the ``model_cards``
    directory located next to this file.

    Returns:
        The Gradio interface, after ``launch()`` has been called on it.
    """
    logger.info("Launching the Gradio app")

    metadata_dir = Path(__file__).parent / "model_cards"

    # The CSV has no header row; empty cells become empty strings.
    examples_df: pd.DataFrame = pd.read_csv(
        metadata_dir / "sac_synthesis_mining_examples.csv", header=None
    ).fillna("")
    examples: List[List[str]] = examples_df.to_numpy().tolist()

    # Read the markdown explicitly as UTF-8 instead of relying on the
    # platform's default encoding (pandas already reads the CSV as UTF-8).
    article = (metadata_dir / "sac_synthesis_mining_article.md").read_text(
        encoding="utf-8"
    )
    description = (metadata_dir / "sac_synthesis_mining_description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=action_extraction,
        title="Extraction of synthesis protocols from paragraphs in machine readable format",
        inputs=[
            gr.Dropdown(
                [_SAC_MODEL_TAG, _ORGANIC_MODEL_TAG],
                label="Model",
                value=_SAC_MODEL_TAG,
            ),
            # The second column of the first example row serves as placeholder.
            gr.Textbox(label="Synthesis text", lines=7, placeholder=examples[0][1]),
            gr.Checkbox(label="Show sentences in the output"),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
        allow_flagging="never",
        theme="gradio/base",
    )
    demo.launch(debug=True, show_error=True)
    return demo
# Set up console logging at INFO level before the app is started.
setup_console_logger(level="INFO")
# Build and launch the app when this module is executed/imported; the
# interface is kept in the module-level name `demo`.
demo = launch()