Spaces:
Running
Running
Bram Vanroy
committed on
Commit
·
05b9456
1
Parent(s):
f8b0e70
add check for empty input and show info/error
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from collections import Counter
|
2 |
|
3 |
import graphviz
|
@@ -18,83 +19,94 @@ with st.form("input data"):
|
|
18 |
src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0)
|
19 |
submitted = st.form_submit_button("Submit")
|
20 |
|
|
|
21 |
if submitted:
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
"max_length": model.config.max_length,
|
26 |
-
"num_beams": model.config.num_beams,
|
27 |
-
"logits_processor": LogitsProcessorList([logitsprocessor])
|
28 |
-
}
|
29 |
-
|
30 |
-
linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
|
31 |
-
penman_str = linearized2penmanstr(linearized)
|
32 |
-
|
33 |
-
try:
|
34 |
-
graph = penman.decode(penman_str, model=NoOpModel())
|
35 |
-
except Exception as exc:
|
36 |
-
st.write(f"The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
|
37 |
-
f" to a valid graph but note that this is invalid Penman.")
|
38 |
-
st.code(penman_str)
|
39 |
-
|
40 |
-
with st.expander("Error trace"):
|
41 |
-
st.write(exc)
|
42 |
else:
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
nodename = nodenames[varname]
|
56 |
-
if nodename_c[nodename] > 1:
|
57 |
-
nodename_str_c[nodename] += 1
|
58 |
-
nodenames[varname] = f"{nodename} ({nodename_str_c[nodename]})"
|
59 |
-
|
60 |
-
def get_node_name(item: str):
|
61 |
-
return nodenames[item] if item in nodenames else item
|
62 |
|
63 |
try:
|
64 |
-
|
65 |
-
if triple[1] == ":instance":
|
66 |
-
continue
|
67 |
-
else:
|
68 |
-
visualized.edge(get_node_name(triple[0]), get_node_name(triple[2]), label=triple[1])
|
69 |
except Exception as exc:
|
70 |
-
st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
|
71 |
-
" to a valid graph but note that this is
|
72 |
st.code(penman_str)
|
73 |
-
st.write("The initial linearized output of the model was:")
|
74 |
-
st.code(linearized)
|
75 |
|
76 |
with st.expander("Error trace"):
|
77 |
st.write(exc)
|
78 |
else:
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
#
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
|
94 |
########################
|
95 |
# Information, socials #
|
96 |
########################
|
97 |
-
st.
|
98 |
|
99 |
st.markdown("""
|
100 |
<div style="display: flex">
|
@@ -108,7 +120,7 @@ st.markdown("""
|
|
108 |
""", unsafe_allow_html=True)
|
109 |
|
110 |
|
111 |
-
st.
|
112 |
|
113 |
st.markdown("Would you like additional functionality in the demo? Or just want to get in touch?"
|
114 |
" Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
|
|
|
1 |
+
import base64
|
2 |
from collections import Counter
|
3 |
|
4 |
import graphviz
|
|
|
19 |
src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0)
|
20 |
submitted = st.form_submit_button("Submit")
|
21 |
|
22 |
+
error_ct = st.empty()
|
23 |
if submitted:
|
24 |
+
text = text.strip()
|
25 |
+
if not text:
|
26 |
+
error_ct.error("Text cannot be empty!", icon="⚠️")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
else:
|
28 |
+
error_ct.info("Generating abstract meaning representation (AMR)...", icon="💻")
|
29 |
+
multilingual = src_lang != "English"
|
30 |
+
model, tokenizer, logitsprocessor = get_resources(multilingual)
|
31 |
+
gen_kwargs = {
|
32 |
+
"max_length": model.config.max_length,
|
33 |
+
"num_beams": model.config.num_beams,
|
34 |
+
"logits_processor": LogitsProcessorList([logitsprocessor])
|
35 |
+
}
|
36 |
+
|
37 |
+
linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
|
38 |
+
penman_str = linearized2penmanstr(linearized)
|
39 |
+
error_ct.empty()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
try:
|
42 |
+
graph = penman.decode(penman_str, model=NoOpModel())
|
|
|
|
|
|
|
|
|
43 |
except Exception as exc:
|
44 |
+
st.write(f"The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
|
45 |
+
f" to a valid graph but note that this is invalid Penman.")
|
46 |
st.code(penman_str)
|
|
|
|
|
47 |
|
48 |
with st.expander("Error trace"):
|
49 |
st.write(exc)
|
50 |
else:
|
51 |
+
visualized = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box",
|
52 |
+
"fontcolor": "white"})
|
53 |
+
|
54 |
+
# Count which names occur multiple times, e.g. t/talk-01 t2/talk-01
|
55 |
+
nodename_c = Counter([item[2] for item in graph.triples if item[1] == ":instance"])
|
56 |
+
# Generated initial nodenames for each variable, e.g. {"t": "talk-01", "t2": "talk-01"}
|
57 |
+
nodenames = {item[0]: item[2] for item in graph.triples if item[1] == ":instance"}
|
58 |
+
|
59 |
+
# Modify nodenames, so that the values are unique, e.g. {"t": "talk-01 (1)", "t2": "talk-01 (2)"}
|
60 |
+
# but only if the value occurs more than once
|
61 |
+
nodename_str_c = Counter()
|
62 |
+
for varname in nodenames:
|
63 |
+
nodename = nodenames[varname]
|
64 |
+
if nodename_c[nodename] > 1:
|
65 |
+
nodename_str_c[nodename] += 1
|
66 |
+
nodenames[varname] = f"{nodename} ({nodename_str_c[nodename]})"
|
67 |
+
|
68 |
+
def get_node_name(item: str):
|
69 |
+
return nodenames[item] if item in nodenames else item
|
70 |
+
|
71 |
+
try:
|
72 |
+
for triple in graph.triples:
|
73 |
+
if triple[1] == ":instance":
|
74 |
+
continue
|
75 |
+
else:
|
76 |
+
visualized.edge(get_node_name(triple[0]), get_node_name(triple[2]), label=triple[1])
|
77 |
+
except Exception as exc:
|
78 |
+
st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
|
79 |
+
" to a valid graph but note that this is probably invalid Penman.")
|
80 |
+
st.code(penman_str)
|
81 |
+
st.write("The initial linearized output of the model was:")
|
82 |
+
st.code(linearized)
|
83 |
+
|
84 |
+
with st.expander("Error trace"):
|
85 |
+
st.write(exc)
|
86 |
+
else:
|
87 |
+
st.subheader("Graph visualization")
|
88 |
+
st.graphviz_chart(visualized, use_container_width=True)
|
89 |
+
|
90 |
+
# Download link
|
91 |
+
def create_download_link(img_bytes: bytes):
|
92 |
+
encoded = base64.b64encode(img_bytes).decode("utf-8")
|
93 |
+
return f'<a href="data:image/png;charset=utf-8;base64,{encoded}" download="amr-graph.png">Download graph</a>'
|
94 |
+
|
95 |
+
img = visualized.pipe(format="png")
|
96 |
+
st.markdown(create_download_link(img), unsafe_allow_html=True)
|
97 |
+
|
98 |
+
# Additional info
|
99 |
+
st.subheader("Model output and Penman graph")
|
100 |
+
st.write("The linearized output of the model (after some post-processing) is:")
|
101 |
+
st.code(linearized)
|
102 |
+
st.write("When converted into Penman, it looks like this:")
|
103 |
+
st.code(penman.encode(graph))
|
104 |
|
105 |
|
106 |
########################
|
107 |
# Information, socials #
|
108 |
########################
|
109 |
+
st.header("Project: SignON 🤝")
|
110 |
|
111 |
st.markdown("""
|
112 |
<div style="display: flex">
|
|
|
120 |
""", unsafe_allow_html=True)
|
121 |
|
122 |
|
123 |
+
st.header("Contact ✒️")
|
124 |
|
125 |
st.markdown("Would you like additional functionality in the demo? Or just want to get in touch?"
|
126 |
" Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
|