File size: 4,918 Bytes
4d89100
 
 
80ffb8e
4d89100
 
 
 
 
 
 
 
 
80ffb8e
4d89100
 
 
 
 
 
8bbfef5
4d89100
 
 
 
8bbfef5
4d89100
 
275093a
 
 
4d89100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e49cab3
80ffb8e
 
 
4d89100
e49cab3
 
4d89100
e49cab3
 
 
4d89100
e49cab3
4d89100
 
 
 
 
 
 
 
 
 
 
 
 
e49cab3
 
 
4d89100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bbfef5
4d89100
 
 
275093a
4d89100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import functools
import html
import logging
import textwrap
import traceback
from pathlib import Path
from typing import List

import gradio as gr
import pandas as pd
from rxn.utilities.logging import setup_console_logger
from rxn.utilities.strings import remove_postfix

from utils import TranslatorWithSentencePiece, download_cde_data, split_into_sentences

# Module-level logger; NullHandler avoids "no handler" warnings when this
# module is imported without logging configured by the host application.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Human-readable tags shown in the model-selection dropdown.
_SAC_MODEL_TAG = "Heterogenous SAC model (ACE)"
_ORGANIC_MODEL_TAG = "Organic chemistry model"
_PRETRAINED_MODEL_TAG = "Pretrained model"

# Maps each dropdown tag to the OpenNMT checkpoint file(s) to load.
# Multiple files (organic model) are used as an ensemble by the translator.
model_type_to_models = {
    _SAC_MODEL_TAG: ["sac.pt"],
    _ORGANIC_MODEL_TAG: ["organic-1.pt", "organic-2.pt", "organic-3.pt"],
    _PRETRAINED_MODEL_TAG: ["pretrained.pt"],
}

# Placeholder shown in the empty synthesis-text input box.
SYNTHESIS_TEXT_PLACEHOLDER = (
    "Enter the synthesis procedure here, or click on one of the examples below."
)

@functools.lru_cache
def load_model(model_type: str) -> TranslatorWithSentencePiece:
    """Instantiate the translator for the given model tag (cached per tag).

    Args:
        model_type: one of the keys of ``model_type_to_models``.

    Returns:
        A ready-to-use translator wrapping the checkpoint(s) for that tag
        and the shared SentencePiece model.
    """
    logger.info(f"Loading model {model_type}... ")

    translator = TranslatorWithSentencePiece(
        translation_model=model_type_to_models[model_type],
        sentencepiece_model="sp_model.model",
    )

    logger.info(f"Loading model {model_type}... Done.")
    return translator


def sentence_and_actions_to_html(
    sentence: str, action_string: str, show_sentences: bool
) -> str:
    """Format one sentence and its extracted actions as an HTML fragment.

    Args:
        sentence: the original (user-provided) sentence.
        action_string: the model output; actions separated by "; ", with an
            optional trailing period.
        show_sentences: whether to show the sentence above its action list.

    Returns:
        HTML: one ``<li>`` per action; when ``show_sentences`` is set, the
        sentence in a ``<p>`` followed by its own ``<ol>`` of actions.
    """
    li_start = '<li style="margin-left: 12px;">' if show_sentences else "<li>"
    li_end = "</li>"

    # Drop a single trailing period, if present (stdlib equivalent of the
    # previously used rxn.utilities.strings.remove_postfix).
    if action_string.endswith("."):
        action_string = action_string[:-1]

    output = ""
    if show_sentences:
        # Escape: the sentence comes from untrusted user input and is
        # embedded directly into the rendered HTML.
        output += f"<p>{html.escape(sentence)}</p>"

    output += "".join(
        f"{li_start}{html.escape(action)}{li_end}"
        for action in action_string.split("; ")
    )

    if show_sentences:
        # If we show the sentence, we need the list/enumeration delimiters,
        # as there is one list per sentence.
        output = f"<ol>{output}</ol>"

    return output


def try_action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Run the full extraction pipeline on a paragraph and return HTML.

    Args:
        model_type: dropdown tag selecting the model (key of
            ``model_type_to_models``).
        text: the synthesis paragraph to process.
        show_sentences: whether to interleave the source sentences with the
            extracted actions in the output.

    Returns:
        The extracted actions formatted as HTML.

    Raises:
        Whatever the underlying steps raise; ``action_extraction`` converts
        failures into user-facing HTML.
    """
    logger.info(f'Extracting actions from paragraph "{textwrap.shorten(text, 60)}"...')

    download_cde_data()

    model = load_model(model_type)

    logger.info("Splitting paragraph into sentences...")
    sentences = split_into_sentences(text)
    logger.info("Splitting paragraph into sentences... Done.")

    logger.info("Translation with OpenNMT...")
    action_strings = model.translate(sentences)
    logger.info("Translation with OpenNMT... Done.")

    output = "".join(
        sentence_and_actions_to_html(sentence, action_string, show_sentences)
        for sentence, action_string in zip(sentences, action_strings)
    )

    if not show_sentences:
        # If the sentences were not shown, we need to add the list/enumeration
        # delimiters here (globally)
        output = f"<ol>{output}</ol>"

    # PostTreatment was renamed to ThermalTreatment, old model still relies on the former
    output = output.replace("POSTTREATMENT", "THERMALTREATMENT")

    logger.info(
        f'Extracting actions from paragraph "{textwrap.shorten(text, 60)}"... Done.'
    )
    return output


def action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Gradio entry point: extract actions, rendering failures as HTML.

    Delegates to ``try_action_extraction`` and converts any exception into an
    error message (with escaped traceback) instead of crashing the UI.
    """
    try:
        return try_action_extraction(model_type, text, show_sentences)
    except Exception as e:
        # Log server-side with the full traceback; the bare return below is
        # only what the user sees.
        logger.exception("Action extraction failed")
        tb = "".join(traceback.TracebackException.from_exception(e).format())
        tb_html = f"<pre>{html.escape(tb)}</pre>"
        # Escape the message too: exception text can contain arbitrary input.
        return f"<p><b>Error!</b> The action extraction failed: {html.escape(str(e))}</p>{tb_html}"


def launch() -> gr.Interface:
    """Build, launch (blocking, debug mode), and return the Gradio interface.

    Loads the example CSV and the markdown article/description from the
    ``model_cards`` directory next to this file.
    """
    logger.info("Launching the Gradio app")

    metadata_dir = Path(__file__).parent / "model_cards"

    # The examples CSV has no header row; empty cells become empty strings
    # so Gradio does not display "nan".
    examples_df: pd.DataFrame = pd.read_csv(
        metadata_dir / "sac_synthesis_mining_examples.csv", header=None
    ).fillna("")
    examples: List[List[str]] = examples_df.to_numpy().tolist()

    # Explicit UTF-8: the platform default encoding may fail on non-ASCII
    # characters in the markdown files.
    article = (metadata_dir / "sac_synthesis_mining_article.md").read_text(
        encoding="utf-8"
    )
    description = (metadata_dir / "sac_synthesis_mining_description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=action_extraction,
        title="Extraction of synthesis protocols from paragraphs in machine readable format",
        inputs=[
            gr.Dropdown(
                [_SAC_MODEL_TAG, _ORGANIC_MODEL_TAG, _PRETRAINED_MODEL_TAG],
                label="Model",
                value=_SAC_MODEL_TAG,
            ),
            gr.Textbox(label="Synthesis text", lines=7, placeholder=SYNTHESIS_TEXT_PLACEHOLDER),
            gr.Checkbox(label="Show sentences in the output"),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
        allow_flagging="never",
        theme="gradio/base",
    )
    demo.launch(debug=True, show_error=True)
    return demo


# Runs at import time on purpose: Gradio deployments execute this file as a
# script and expect `demo` to be created (and launched) at module level.
setup_console_logger(level="INFO")
demo = launch()