text-to-amr / app.py
Bram Vanroy
add description
55fbc57
raw
history blame
5.02 kB
from collections import Counter
import graphviz
import penman
from penman.models.noop import NoOpModel
from mbart_amr.data.linearization import linearized2penmanstr
from transformers import LogitsProcessorList
import streamlit as st
from utils import get_resources, LANGUAGES, translate
# Page title. The leading emoji is "woman technologist" (woman + zero-width joiner + laptop); it is
# written with Unicode escapes so that it cannot be garbled by a wrongly decoded source file.
st.title("\U0001F469\u200D\U0001F4BB Multilingual text to AMR")

# Input form: a wide text field next to a narrower language selector (4:1 column ratio).
# `text`, `src_lang` and `submitted` are read by the processing code that follows the form.
with st.form("input data"):
    text_col, lang_col = st.columns((4, 1))
    text = text_col.text_input(label="Input text")
    # The options are the keys of LANGUAGES; the first key is preselected.
    src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0)
    submitted = st.form_submit_button("Submit")
def _unique_node_labels(triples) -> dict:
    """Return a ``{variable: label}`` mapping built from the ``:instance`` triples.

    The label of a variable is its concept, e.g. ``{"t": "talk-01"}``. When one concept is
    used by several variables a running number is appended so the labels stay distinct,
    e.g. ``{"t": "talk-01 (1)", "t2": "talk-01 (2)"}``. Concepts that occur only once are
    left unchanged.

    :param triples: iterable of ``(source, role, target)`` triples of a decoded graph
    :return: dictionary mapping each variable to its display label
    """
    labels = {source: target for source, role, target in triples if role == ":instance"}
    # Count which concepts occur multiple times, e.g. t/talk-01 t2/talk-01
    concept_counts = Counter(target for _, role, target in triples if role == ":instance")

    seen = Counter()
    # Iterate over a snapshot because the values are rewritten inside the loop.
    for variable, concept in list(labels.items()):
        if concept_counts[concept] > 1:
            seen[concept] += 1
            labels[variable] = f"{concept} ({seen[concept]})"

    return labels


def _build_amr_digraph(triples):
    """Build the Graphviz directed graph that visualizes the given triples.

    Every distinct display string (a variable's label, or the raw value for anything that
    is not a variable) becomes one node. Nodes get opaque ids (``n0``, ``n1``, ...) and
    carry the display string as an escaped ``label``: Graphviz splits the node names given
    to ``edge()`` on ``:`` into node/port/compass, so values such as ``"10:30"`` or URLs
    must not be used as node names directly.

    All variables are registered before the edges are added, so a variable that has no
    relations (e.g. a graph consisting of a single concept) is still drawn.

    :param triples: iterable of ``(source, role, target)`` triples of a decoded graph
    :return: the populated ``graphviz.Digraph``
    """
    digraph = graphviz.Digraph(
        node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box", "fontcolor": "white"}
    )
    labels = _unique_node_labels(triples)
    node_ids = {}

    def node_id_for(item):
        # Variables are shown with their (unique) label, everything else as-is.
        display = labels.get(item, item)
        if display not in node_ids:
            node_ids[display] = f"n{len(node_ids)}"
            # escape() keeps backslashes and "<...>" in the value from being interpreted.
            digraph.node(node_ids[display], label=graphviz.escape(display))
        return node_ids[display]

    for variable in labels:
        node_id_for(variable)

    for source, role, target in triples:
        # :instance triples define the nodes themselves; only real relations become edges.
        if role != ":instance":
            digraph.edge(node_id_for(source), node_id_for(target), label=role)

    return digraph


# Runs only for the script run in which the form's submit button was pressed.
if submitted:
    # Any language other than English is handled by the multilingual resources.
    multilingual = src_lang != "English"
    model, tokenizer, logitsprocessor = get_resources(multilingual)
    gen_kwargs = {
        "max_length": model.config.max_length,
        "num_beams": model.config.num_beams,
        "logits_processor": LogitsProcessorList([logitsprocessor]),
    }

    linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
    penman_str = linearized2penmanstr(linearized)

    try:
        graph = penman.decode(penman_str, model=NoOpModel())
    except Exception as exc:
        # The model output could not be parsed as Penman: show the raw attempt and the error.
        st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                 " to a valid graph but note that this is invalid Penman.")
        st.code(penman_str)
        with st.expander("Error trace"):
            st.write(exc)
    else:
        try:
            visualized = _build_amr_digraph(graph.triples)
        except Exception as exc:
            # The graph was parsed but could not be turned into a drawing (e.g. missing values).
            st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                     " to a valid graph but note that this is probably invalid Penman.")
            st.code(penman_str)
            st.write("The initial linearized output of the model was:")
            st.code(linearized)
            with st.expander("Error trace"):
                st.write(exc)
        else:
            st.subheader("Graph visualization")
            st.graphviz_chart(visualized, use_container_width=True)

            # Download
            img = visualized.pipe(format="png")
            st.download_button("Download graph", img, mime="image/png")

            # Additional info
            st.subheader("Model output and Penman graph")
            st.write("The linearized output of the model (after some post-processing) is:")
            st.code(linearized)
            st.write("When converted into Penman, it looks like this:")
            st.code(penman.encode(graph))
########################
# Information, socials #
########################
# Static footer: project description and contact links. The emoji in the headers are written
# with Unicode escapes ("love-you gesture" and "black nib") so they cannot be garbled by a
# wrongly decoded source file.
st.markdown("## Project: SignON \U0001F91F")

# Raw HTML is needed for the logo/text flex layout, hence unsafe_allow_html=True.
st.markdown("""
<div style="display: flex">
<img style="margin-right: 1em" alt="SignON logo" src="https://signon-project.eu/wp-content/uploads/2021/05/SignOn_Favicon_500x500px.png" width=64 height=64>
<p><a href="https://signon-project.eu/" target="_blank" title="SignON homepage">SignON</a> aims to bridge the
communication gap between Deaf, hard of hearing and hearing people through an accessible translation service to
translate between languages and modalities with particular attention to sign languages.</p>
</div>
<p>This space and the accompanying models and public code are part of the SignON-project. AMR (abstract meaning
representation) is used as an interlingua to translate between modalities and languages.</p>
""", unsafe_allow_html=True)

st.markdown("## Contact \u2712\uFE0F")

st.markdown("Would you like additional functionality in the demo? Or just want to get in touch?"
            " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
            " or add me on [LinkedIn](https://www.linkedin.com/in/bramvanroy/)!")