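"""Gradio demo for extracting synthesis actions from procedure paragraphs."""
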
import functools
import html
import logging
import textwrap
import traceback
from pathlib import Path
from typing import List

import gradio as gr
import pandas as pd
from rxn.utilities.logging import setup_console_logger
from rxn.utilities.strings import remove_postfix

from utils import TranslatorWithSentencePiece, download_cde_data, split_into_sentences

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

_SAC_MODEL_TAG = "Heterogeneous SAC model (ACE)"
_ORGANIC_MODEL_TAG = "Organic chemistry model"
_PRETRAINED_MODEL_TAG = "Pretrained model"

model_type_to_models = {
    _SAC_MODEL_TAG: ["sac.pt"],
    _ORGANIC_MODEL_TAG: ["organic-1.pt", "organic-2.pt", "organic-3.pt"],
    _PRETRAINED_MODEL_TAG: ["pretrained.pt"],
}

SYNTHESIS_TEXT_PLACEHOLDER = (
    "Enter the synthesis procedure here, or click on one of the examples below."
)


@functools.lru_cache
def load_model(model_type: str) -> TranslatorWithSentencePiece:
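    """Load the translation model(s) for the given model type; cached via lru_cache."""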
    logger.info(f"Loading model {model_type}... ")
    model_files = model_type_to_models[model_type]
    sp_model = "sp_model.model"
    model = TranslatorWithSentencePiece(
        translation_model=model_files,
        sentencepiece_model=sp_model,
    )
    logger.info(f"Loading model {model_type}... Done.")
    return model


def sentence_and_actions_to_html(
    sentence: str, action_string: str, show_sentences: bool
) -> str:
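    """Format one sentence and its predicted action string as HTML list items."""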
    output = ""

    li_start = '<li style="margin-left: 12px;">' if show_sentences else "<li>"
    li_end = "</li>"

    action_string = remove_postfix(action_string, ".")

    if show_sentences:
        output += f"<p>{sentence}</p>"

    actions = [f"{li_start}{action}{li_end}" for action in action_string.split("; ")]
    output += "".join(actions)

    if show_sentences:
        # If we show the sentence, we need the list/enumeration delimiters,
        # as there is one list per sentence.
        output = f"<ol>{output}</ol>"

    return output


def try_action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
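    """Extract actions from a paragraph: split into sentences, translate, and format as HTML."""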
    logger.info(f'Extracting actions from paragraph "{textwrap.shorten(text, 60)}"...')

    download_cde_data()
    model = load_model(model_type)

    logger.info("Splitting paragraph into sentences...")
    sentences = split_into_sentences(text)
    logger.info("Splitting paragraph into sentences... Done.")

    logger.info("Translation with OpenNMT...")
    action_strings = model.translate(sentences)
    logger.info("Translation with OpenNMT... Done.")

    output = ""
    for sentence, action_string in zip(sentences, action_strings):
        output += sentence_and_actions_to_html(sentence, action_string, show_sentences)

    if not show_sentences:
        # If the sentences were not shown, we need to add the list/enumeration
        # delimiters here (globally).
        output = f"<ol>{output}</ol>"

    # PostTreatment was renamed to ThermalTreatment; old models still rely on the former.
    output = output.replace("POSTTREATMENT", "THERMALTREATMENT")

    logger.info(
        f'Extracting actions from paragraph "{textwrap.shorten(text, 60)}"... Done.'
    )
    return output


def action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
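    """Wrap try_action_extraction and render any exception as an HTML error message."""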
    try:
        return try_action_extraction(model_type, text, show_sentences)
    except Exception as e:
        tb = "".join(traceback.TracebackException.from_exception(e).format())
        tb_html = f"<pre>{html.escape(tb)}</pre>"
        return f"<p><b>Error!</b> The action extraction failed: {e}</p>{tb_html}"


def launch() -> gr.Interface:
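    """Build the Gradio interface, load examples and metadata, and launch the demo."""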
    logger.info("Launching the Gradio app")

    metadata_dir = Path(__file__).parent / "model_cards"

    examples_df: pd.DataFrame = pd.read_csv(
        metadata_dir / "sac_synthesis_mining_examples.csv", header=None
    ).fillna("")
    examples: List[List[str]] = examples_df.to_numpy().tolist()

    with open(metadata_dir / "sac_synthesis_mining_article.md", "r") as f:
        article = f.read()
    with open(metadata_dir / "sac_synthesis_mining_description.md", "r") as f:
        description = f.read()

    demo = gr.Interface(
        fn=action_extraction,
        title="Extraction of synthesis protocols from paragraphs in machine-readable format",
        inputs=[
            gr.Dropdown(
                [_SAC_MODEL_TAG, _ORGANIC_MODEL_TAG, _PRETRAINED_MODEL_TAG],
                label="Model",
                value=_SAC_MODEL_TAG,
            ),
            gr.Textbox(
                label="Synthesis text", lines=7, placeholder=SYNTHESIS_TEXT_PLACEHOLDER
            ),
            gr.Checkbox(label="Show sentences in the output"),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
        allow_flagging="never",
        theme="gradio/base",
    )

    demo.launch(debug=True, show_error=True)
    return demo


setup_console_logger(level="INFO")
demo = launch()