File size: 5,173 Bytes
f3fd096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from collections import Counter

import graphviz
from optimum.bettertransformer import BetterTransformer
import penman
from penman.models.noop import NoOpModel
from mbart_amr.constraints.constraints import AMRLogitsProcessor
from mbart_amr.data.linearization import linearized2penmanstr
from mbart_amr.data.tokenization import AMRMBartTokenizer
from transformers import MBartForConditionalGeneration, LogitsProcessorList

import streamlit as st

# Initialize optional session entries to None so later code can truth-test them.
if "logits_processor" not in st.session_state:
    st.session_state["logits_processor"] = None

if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None

# Load the tokenizer, model and logits processor once per session.
# Build everything locally and commit to session_state only after every step
# succeeded: previously "model" was stored before BetterTransformer.transform /
# resize_token_embeddings / AMRLogitsProcessor ran, so an exception in any of
# those steps permanently cached a half-initialized model (the guard below
# would then skip re-initialization on every rerun).
if "model" not in st.session_state:
    tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
    model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")
    model = BetterTransformer.transform(model, keep_original_model=False)
    # Embedding table must match the AMR-extended tokenizer vocabulary.
    model.resize_token_embeddings(len(tokenizer))

    st.session_state["tokenizer"] = tokenizer
    st.session_state["model"] = model
    st.session_state["logits_processor"] = AMRLogitsProcessor(tokenizer, model.config.max_length)

st.title("📝 Parse text into AMR")

text = st.text_input(label="Text to transform (en)")

# Only attempt parsing when the user entered text and the model finished loading.
if text and "model" in st.session_state:
    gen_kwargs = {
        "max_length": st.session_state["model"].config.max_length,
        "num_beams": st.session_state["model"].config.num_beams,
        # generate() expects a LogitsProcessorList; pass None when no processor
        # is available so generate() falls back to its defaults.
        "logits_processor": (
            LogitsProcessorList([st.session_state["logits_processor"]])
            if st.session_state["logits_processor"]
            else None
        ),
    }

    # Tokenize, generate, and convert the linearized model output to a Penman string.
    encoded = st.session_state["tokenizer"](text, return_tensors="pt")
    generated = st.session_state["model"].generate(**encoded, **gen_kwargs)
    linearized = st.session_state["tokenizer"].decode_and_fix(generated)[0]
    penman_str = linearized2penmanstr(linearized)

    try:
        # NoOpModel keeps the graph exactly as generated (no normalization).
        graph = penman.decode(penman_str, model=NoOpModel())
    except Exception as exc:
        st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                 " to a valid graph but note that this is invalid Penman.")
        st.code(penman_str)

        with st.expander("Error trace"):
            st.write(exc)
    else:
        visualized = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box",
                                                 "fontcolor": "white"})

        # Count which concept names occur multiple times, e.g. t/talk-01 t2/talk-01
        nodename_c = Counter(item[2] for item in graph.triples if item[1] == ":instance")
        # Initial node name for each variable, e.g. {"t": "talk-01", "t2": "talk-01"}
        nodenames = {item[0]: item[2] for item in graph.triples if item[1] == ":instance"}

        # Make duplicated node names unique, e.g. {"t": "talk-01 (1)", "t2": "talk-01 (2)"},
        # but only when the concept occurs more than once
        nodename_str_c = Counter()
        for varname, nodename in nodenames.items():
            if nodename_c[nodename] > 1:
                nodename_str_c[nodename] += 1
                nodenames[varname] = f"{nodename} ({nodename_str_c[nodename]})"

        def get_node_name(item: str) -> str:
            """Return the display label for a variable, or the item itself for constants/literals."""
            return nodenames.get(item, item)

        try:
            # One edge per relation triple; :instance triples are already encoded
            # in the node labels above.
            for source, relation, target in graph.triples:
                if relation == ":instance":
                    continue
                visualized.edge(get_node_name(source), get_node_name(target), label=relation)
        except Exception as exc:
            st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                     " to a valid graph but note that this is probably invalid Penman.")
            st.code(penman_str)
            st.write("The initial linearized output of the model was:")
            st.code(linearized)

            with st.expander("Error trace"):
                st.write(exc)
        else:
            st.subheader("Graph visualization")
            st.graphviz_chart(visualized, use_container_width=True)

            # Download; give the file an explicit name so the browser saves a .png
            img = visualized.pipe(format="png")
            st.download_button("Download graph", img, file_name="graph.png", mime="image/png")

            # Additional info
            st.subheader("Model output and Penman graph")
            st.write("The linearized output of the model (after some post-processing) is:")
            st.code(linearized)
            st.write("When converted into Penman, it looks like this:")
            st.code(penman.encode(graph))


########################
# Information, socials #
########################
st.markdown("## Contact ✒️")

st.markdown("Would you like additional functionality in the demo? Or just want to get in touch?"
            " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
            " or add me on [LinkedIn](https://www.linkedin.com/in/bramvanroy/)!")