# text-to-amr / app.py: Streamlit demo by Bram Vanroy
# Source revision f3fd096 ("init space"), 5.17 kB
from collections import Counter
import graphviz
from optimum.bettertransformer import BetterTransformer
import penman
from penman.models.noop import NoOpModel
from mbart_amr.constraints.constraints import AMRLogitsProcessor
from mbart_amr.data.linearization import linearized2penmanstr
from mbart_amr.data.tokenization import AMRMBartTokenizer
from transformers import MBartForConditionalGeneration, LogitsProcessorList
import streamlit as st
# Per-session slots read by the generation code below. They start out as None and
# are filled in once the model has been loaded successfully.
if "logits_processor" not in st.session_state:
    st.session_state["logits_processor"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None
if "model" not in st.session_state:
    model_name = "BramVanroy/mbart-en-to-amr"
    # Build every resource in locals and only store them in the session once all steps
    # have succeeded. Storing the model early would mean that a failure in transform()/
    # resize/processor construction leaves a half-initialized model in the session,
    # which every later rerun would then reuse because "model" is already present.
    tokenizer = AMRMBartTokenizer.from_pretrained(model_name, src_lang="en_XX")
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    model = BetterTransformer.transform(model, keep_original_model=False)
    # Make the embedding matrix match the tokenizer's vocabulary size (presumably the AMR
    # tokenizer adds tokens on top of the base mBART vocabulary; verify if this changes).
    model.resize_token_embeddings(len(tokenizer))
    logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)

    st.session_state["tokenizer"] = tokenizer
    st.session_state["model"] = model
    st.session_state["logits_processor"] = logits_processor
st.title("πŸ“ Parse text into AMR")
text = st.text_input(label="Text to transform (en)")
def _concept_labels(triples):
    """Return a display label for every instance variable in *triples*.

    *triples* are (source, role, target) tuples such as ``penman.Graph.triples``.
    Variables come from the ``:instance`` triples and are labelled with their concept.
    A concept that occurs more than once gets a running " (n)" suffix, e.g.
    {"t": "talk-01", "t2": "talk-01"} becomes {"t": "talk-01 (1)", "t2": "talk-01 (2)"}.
    A variable whose concept is missing (None) is labelled with its own name.
    """
    instance_pairs = [(var, concept) for var, role, concept in triples if role == ":instance"]
    concept_counts = Counter(concept for _, concept in instance_pairs)
    seen = Counter()
    labels = {}
    for var, concept in dict(instance_pairs).items():
        if concept is None:
            labels[var] = var
        elif concept_counts[concept] > 1:
            seen[concept] += 1
            labels[var] = f"{concept} ({seen[concept]})"
        else:
            labels[var] = concept
    return labels


def _build_digraph(triples):
    """Build a graphviz Digraph for the given penman triples.

    Nodes get synthetic identifiers ("v0", "v1", ... for variables and "c<k>" for
    constants), and the readable text is passed as the ``label`` attribute. This is
    needed because graphviz reads a colon in a node name as a port separator, so
    constants such as times or URLs cannot be used as node names directly.
    Every constant occurrence also gets its own node, so repeated values (e.g. two
    ``-`` polarities) are not merged into one. Variable nodes are added explicitly,
    so a graph that consists of a single concept (no edges) is still drawn.
    """
    digraph = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box",
                                          "fontcolor": "white"})
    labels = _concept_labels(triples)
    relations = [triple for triple in triples if triple[1] != ":instance"]
    # Triple sources are always variables. Register those without an :instance triple
    # (only possible in malformed output) so their edges stay connected.
    for source, _, _ in relations:
        labels.setdefault(source, source)

    var_ids = {}
    for index, (var, label) in enumerate(labels.items()):
        var_ids[var] = f"v{index}"
        digraph.node(var_ids[var], label=label)

    for rel_index, (source, role, target) in enumerate(relations):
        if target in var_ids:
            target_id = var_ids[target]
        else:
            # Constant (attribute value): one node per occurrence.
            target_id = f"c{rel_index}"
            digraph.node(target_id, label=str(target))
        digraph.edge(var_ids[source], target_id, label=role)
    return digraph


if text and "model" in st.session_state:
    model = st.session_state["model"]
    tokenizer = st.session_state["tokenizer"]
    logits_processor = st.session_state["logits_processor"]
    gen_kwargs = {
        "max_length": model.config.max_length,
        "num_beams": model.config.num_beams,
        # Only constrain decoding when the AMR logits processor has been initialised.
        "logits_processor": LogitsProcessorList([logits_processor]) if logits_processor else None,
    }
    encoded = tokenizer(text, return_tensors="pt")
    generated = model.generate(**encoded, **gen_kwargs)
    # decode_and_fix presumably returns one linearized string per generated sequence;
    # only the first one is used.
    linearized = tokenizer.decode_and_fix(generated)[0]
    penman_str = linearized2penmanstr(linearized)

    try:
        graph = penman.decode(penman_str, model=NoOpModel())
    except Exception as exc:
        st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                 " to a valid graph but note that this is invalid Penman.")
        st.code(penman_str)
        st.write("The initial linearized output of the model was:")
        st.code(linearized)
        with st.expander("Error trace"):
            st.write(exc)
    else:
        try:
            visualized = _build_digraph(graph.triples)
        except Exception as exc:
            st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                     " to a valid graph but note that this is probably invalid Penman.")
            st.code(penman_str)
            st.write("The initial linearized output of the model was:")
            st.code(linearized)
            with st.expander("Error trace"):
                st.write(exc)
        else:
            st.subheader("Graph visualization")
            st.graphviz_chart(visualized, use_container_width=True)

            # Download: give the file a name with a .png extension (Streamlit's default has none).
            img = visualized.pipe(format="png")
            st.download_button("Download graph", img, file_name="amr-graph.png", mime="image/png")

            # Additional info
            st.subheader("Model output and Penman graph")
            st.write("The linearized output of the model (after some post-processing) is:")
            st.code(linearized)
            st.write("When converted into Penman, it looks like this:")
            st.code(penman.encode(graph))
########################
# Information, socials #
########################
# Static footer with the author's contact links.
st.markdown("## Contact ✒️")
st.markdown("Would you like additional functionality in the demo? Or just want to get in touch?"
            " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
            " or add me on [LinkedIn](https://www.linkedin.com/in/bramvanroy/)!")