File size: 4,344 Bytes
4d89100
 
 
80ffb8e
4d89100
 
 
 
 
 
 
 
 
80ffb8e
4d89100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80ffb8e
 
 
 
4d89100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import functools
import html
import logging
import textwrap
import traceback
from pathlib import Path
from typing import List

import gradio as gr
import pandas as pd
from rxn.utilities.logging import setup_console_logger
from rxn.utilities.strings import remove_postfix

from utils import TranslatorWithSentencePiece, download_cde_data, split_into_sentences

# Module logger. The NullHandler avoids "no handler" warnings when no logging
# configuration has been installed yet.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Display names of the selectable models (used as dropdown values and as keys
# of `model_type_to_models`). NOTE: the spelling "Heterogenous" is part of a
# runtime key and presumably matches the examples CSV, so it must not be
# "corrected" in isolation.
_SAC_MODEL_TAG = "Heterogenous SAC model (ACE)"
_ORGANIC_MODEL_TAG = "Organic chemistry model"

# Model display name -> checkpoint file(s) handed together to the translator
# as `translation_model`.
_ORGANIC_CHECKPOINTS = [f"organic-{index}.pt" for index in (1, 2, 3)]
model_type_to_models = {
    _SAC_MODEL_TAG: ["sac.pt"],
    _ORGANIC_MODEL_TAG: _ORGANIC_CHECKPOINTS,
}


@functools.lru_cache
def load_model(model_type: str) -> TranslatorWithSentencePiece:
    """Build the translator for ``model_type``, memoized per model type.

    Thanks to ``functools.lru_cache``, the checkpoints of a given model type
    are only loaded on the first call; later calls return the same object.

    Args:
        model_type: display name of the model, i.e. a key of
            ``model_type_to_models``.

    Returns:
        A ``TranslatorWithSentencePiece`` wrapping the model checkpoint(s)
        and the SentencePiece tokenizer.

    Raises:
        KeyError: if ``model_type`` is not a key of ``model_type_to_models``.
    """
    logger.info(f"Loading model {model_type}... ")

    # "sp_model.model" is a relative path; presumably resolved against the
    # current working directory by the translator — TODO confirm.
    translator = TranslatorWithSentencePiece(
        translation_model=model_type_to_models[model_type],
        sentencepiece_model="sp_model.model",
    )

    logger.info(f"Loading model {model_type}... Done.")
    return translator


def sentence_and_actions_to_html(
    sentence: str, action_string: str, show_sentences: bool
) -> str:
    """Render one sentence and its extracted actions as an HTML fragment.

    The model output ``action_string`` is a "; "-separated sequence of
    actions, optionally ending with a period (which is stripped). Each action
    becomes one ``<li>`` item.

    Both the sentence (user input) and the actions (model output) are
    HTML-escaped so that characters such as ``<``, ``>`` and ``&``, which are
    common in experimental text (e.g. "stirred at <5 °C"), are displayed
    verbatim instead of being interpreted as markup.

    Args:
        sentence: the source sentence the actions were extracted from.
        action_string: the raw action sequence produced by the model.
        show_sentences: if True, the sentence is shown as a paragraph and the
            fragment is wrapped in its own ``<ol>``. If False, only the bare
            ``<li>`` items are returned, and the caller is responsible for the
            enclosing list element.

    Returns:
        The HTML fragment for this sentence.
    """
    # Indent the items slightly when they are shown under their sentence.
    li_start = '<li style="margin-left: 12px;">' if show_sentences else "<li>"
    li_end = "</li>"

    action_string = remove_postfix(action_string, ".")
    items = "".join(
        f"{li_start}{html.escape(action)}{li_end}"
        for action in action_string.split("; ")
    )

    if not show_sentences:
        return items

    # One list per sentence. NOTE(review): <p> is not a permitted child of
    # <ol> in strict HTML; browsers tolerate it, and it is kept here to
    # preserve the existing layout.
    return f"<ol><p>{html.escape(sentence)}</p>{items}</ol>"


def try_action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Extract synthesis actions from a paragraph and render them as HTML.

    The paragraph is split into sentences, every sentence is translated into
    an action sequence by the selected model, and the results are rendered as
    an ordered list. That is one list per sentence if ``show_sentences`` is
    True, otherwise a single global list.

    Exceptions are not caught here; they propagate to the caller.

    Args:
        model_type: display name of the model (see ``model_type_to_models``).
        text: the synthesis paragraph entered by the user.
        show_sentences: whether to show each source sentence above its actions.

    Returns:
        The HTML to display.
    """
    logger.info(f'Extracting actions from paragraph "{textwrap.shorten(text, 60)}".')

    # Called on every request; presumably a no-op once the data is available
    # locally — TODO confirm in utils.
    download_cde_data()

    model = load_model(model_type)
    sentences = split_into_sentences(text)

    # PostTreatment was renamed to ThermalTreatment, but the old model still
    # emits the former name. The rename is applied to the model output only,
    # so the user's own sentences are displayed unmodified.
    action_strings = [
        action_string.replace("POSTTREATMENT", "THERMALTREATMENT")
        for action_string in model.translate(sentences)
    ]

    output = "".join(
        sentence_and_actions_to_html(sentence, action_string, show_sentences)
        for sentence, action_string in zip(sentences, action_strings)
    )

    if not show_sentences:
        # Without the sentences, the per-sentence fragments are bare <li>
        # items, so the list delimiters are added here (globally).
        output = f"<ol>{output}</ol>"

    return output


def action_extraction(model_type: str, text: str, show_sentences: bool) -> str:
    """Gradio entry point: run the extraction and report failures as HTML.

    Any ``Exception`` raised by ``try_action_extraction`` is logged with its
    traceback and turned into an HTML error message (including the escaped
    traceback) instead of propagating to Gradio.

    Args:
        model_type: display name of the model.
        text: the synthesis paragraph entered by the user.
        show_sentences: whether to show each source sentence above its actions.

    Returns:
        The extraction HTML on success, or an HTML error report on failure.
    """
    try:
        return try_action_extraction(model_type, text, show_sentences)
    except Exception as e:
        # Broad catch on purpose: this is the UI boundary. Log server-side,
        # since the error is otherwise only visible in the returned HTML.
        logger.exception("The action extraction failed.")
        tb = "".join(traceback.TracebackException.from_exception(e).format())
        tb_html = f"<pre>{html.escape(tb)}</pre>"
        # The message may echo user input or contain "<"/"&"; escape it like
        # the traceback so it cannot be interpreted as markup.
        message = html.escape(str(e))
        return f"<p><b>Error!</b> The action extraction failed: {message}</p>{tb_html}"


def launch() -> gr.Interface:
    """Build the Gradio interface, launch it, and return it.

    The examples, the article, and the description are read from the
    ``model_cards`` directory next to this file.

    NOTE(review): Gradio documents ``debug=True`` as blocking the main thread,
    so this presumably returns only once the server stops — confirm.

    Returns:
        The launched ``gr.Interface``.
    """
    logger.info("Launching the Gradio app")

    metadata_dir = Path(__file__).parent / "model_cards"

    examples_df: pd.DataFrame = pd.read_csv(
        metadata_dir / "sac_synthesis_mining_examples.csv", header=None
    ).fillna("")
    examples: List[List[str]] = examples_df.to_numpy().tolist()

    # Read the model cards explicitly as UTF-8. The default encoding of a
    # text-mode open() is platform/locale dependent and may fail on
    # non-ASCII characters.
    article = (metadata_dir / "sac_synthesis_mining_article.md").read_text(
        encoding="utf-8"
    )
    description = (metadata_dir / "sac_synthesis_mining_description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=action_extraction,
        title="Extraction of synthesis protocols from paragraphs in machine readable format",
        inputs=[
            gr.Dropdown(
                [_SAC_MODEL_TAG, _ORGANIC_MODEL_TAG],
                label="Model",
                value=_SAC_MODEL_TAG,
            ),
            # The second CSV column of the first example is used as a hint.
            gr.Textbox(label="Synthesis text", lines=7, placeholder=examples[0][1]),
            gr.Checkbox(label="Show sentences in the output"),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
        allow_flagging="never",
        theme="gradio/base",
    )
    demo.launch(debug=True, show_error=True)
    return demo


# Script entry: configure console logging at INFO level, then build and
# launch the app. This runs at import time (there is no `__main__` guard).
# The interface is kept in the module-level name `demo`, presumably so that
# tooling importing this module can find it — TODO confirm before adding a
# guard.
setup_console_logger(level="INFO")
demo = launch()