Make stateful
Bram Vanroy committed · Commit 05de9a6 · Parent(s): 4d85339

This also gets rid of this Streamlit issue: https://github.com/streamlit/streamlit/issues/6451

Files changed:
- app.py +47 -65
- requirements.txt +7 -4
- utils.py +64 -26
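The heart of the change: instead of rebuilding the inputs through a one-shot submit flow on every run, app.py now seeds st.session_state once and binds each widget to it with key=..., so the text, language and model choice persist across Streamlit reruns, which is what sidesteps the issue linked above. A minimal sketch of the pattern, as a standalone toy app rather than the Space's exact code:

import streamlit as st

# Seed the state once per session; a widget with a matching `key`
# reads from and writes to this entry on every rerun.
if "text" not in st.session_state:
    st.session_state["text"] = ""

text = st.text_input(label="Input text", key="text")

# The stored value survives reruns, so the logic below can rely on it
# instead of on a one-shot form submission.
if st.session_state["text"] and st.button("Submit"):
    st.write(f"You typed: {text.strip()}")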
app.py
CHANGED
@@ -3,10 +3,7 @@ from collections import Counter
 
 import graphviz
 import penman
-from …
-from penman.models.noop import NoOpModel
-import streamlit as st
-from transformers import LogitsProcessorList
+from multi_amr.data.postprocessing_graph import ParsedStatus
 
 from utils import get_resources, LANGUAGES, translate
 
@@ -17,43 +14,41 @@ st.set_page_config(
     page_icon="👩‍💻"
 )
 
-st.title("👩‍💻 Multilingual text to AMR…
-…
+st.title("👩‍💻 Multilingual text to AMR")
+
+if "text" not in st.session_state:
+    st.session_state["text"] = ""
+if "language" not in st.session_state:
+    st.session_state["language"] = "English"
+if "use_multilingual" not in st.session_state:
+    st.session_state["use_multilingual"] = False
 
+text_col, lang_col = st.columns((4, 1))
+text = text_col.text_input(label="Input text", key="text")
+src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0, key="language")
+multilingual = st.checkbox("Use multilingual model", label_visibility="visible", key="use_multilingual",
+                           help="Whether to use a single multilingual model that was trained on English, Spanish and"
+                                " Dutch together, or (if not checked) language-specific models. Enabling this will"
+                                " result in worse performance but can be of interest for research purposes.")
 
 error_ct = st.empty()
-if …
-…
-    error_ct.error("Text cannot be empty!", icon="⚠️")
-else:
+if st.session_state["text"]:
+    if st.button("Submit"):
+        text = text.strip()
         error_ct.info("Generating abstract meaning representation (AMR)...", icon="💻")
-    …
-    model, tokenizer, logitsprocessor = get_resources(multilingual)
+        model, tokenizer = get_resources(multilingual, src_lang)
         gen_kwargs = {
-        "…
-        "num_beams": …
-        "logits_processor": LogitsProcessorList([logitsprocessor])
+            "max_new_tokens": 512,
+            "num_beams": 5,
         }
 
-    …
-    penman_str = linearized2penmanstr(linearized)
+        outputs = translate(text, src_lang, model, tokenizer, **gen_kwargs)
         error_ct.empty()
 
-    try:
-        graph …
-    except Exception as exc:
-        st.write(f"The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
-                 f" to a valid graph but note that this is invalid Penman.")
-        st.code(penman_str)
-
-        with st.expander("Error trace"):
-            st.write(exc)
+        if outputs["status"][0] == ParsedStatus.BACKOFF:
+            st.write(f"The system could not generate a valid graph no matter how hard it tried.")
         else:
+            graph = outputs["graph"][0]
             visualized = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box",
                                                      "fontcolor": "white"})
 
@@ -74,40 +69,27 @@ if submitted:
             def get_node_name(item: str):
                 return nodenames[item] if item in nodenames else item
 
-…
-        encoded = base64.b64encode(img_bytes).decode("utf-8")
-        return f'<a href="data:image/png;charset=utf-8;base64,{encoded}" download="amr-graph.png">Download graph</a>'
-
-    img = visualized.pipe(format="png")
-    st.markdown(create_download_link(img), unsafe_allow_html=True)
-
-    # Additional info
-    st.subheader("Model output and Penman graph")
-    st.write("The linearized output of the model (after some post-processing) is:")
-    st.code(linearized)
-    st.write("When converted into Penman, it looks like this:")
-    st.code(penman.encode(graph))
-
+            for triple in graph.triples:
+                if triple[1] == ":instance":
+                    continue
+                else:
+                    visualized.edge(get_node_name(triple[0]), get_node_name(triple[2]), label=triple[1])
+            st.subheader("Graph visualization")
+            st.graphviz_chart(visualized, use_container_width=True)
+
+            # Download link
+            def create_download_link(img_bytes: bytes):
+                encoded = base64.b64encode(img_bytes).decode("utf-8")
+                return f'<a href="data:image/png;charset=utf-8;base64,{encoded}" download="amr-graph.png">Download graph</a>'
+
+            img = visualized.pipe(format="png")
+            st.markdown(create_download_link(img), unsafe_allow_html=True)
+
+            # Additional info
+            st.subheader("PENMAN representation")
+            st.code(penman.encode(graph))
+    else:
+        error_ct.warning("Text cannot be empty!", icon="⚠️")
 
 ########################
 # Information, socials #
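The new rendering path in app.py is easy to try in isolation: decode an AMR string with penman, map variables to their concepts, and turn every non-:instance triple into a labeled edge. A standalone sketch with a toy AMR (hypothetical example, not the Space's exact code; producing PNG bytes additionally requires the system Graphviz binaries):

import graphviz
import penman

# Toy AMR for "the boy wants to go"
graph = penman.decode("(w / want-01 :ARG0 (b / boy) :ARG1 (g / go-02 :ARG0 b))")

# Map variables (w, b, g) to their concepts (want-01, boy, go-02)
nodenames = {instance.source: instance.target for instance in graph.instances()}

def get_node_name(item: str) -> str:
    return nodenames.get(item, item)

dot = graphviz.Digraph(node_attr={"shape": "box"})
for src, role, tgt in graph.triples:
    if role == ":instance":  # concept triples label the nodes, they are not edges
        continue
    dot.edge(get_node_name(src), get_node_name(tgt), label=role)

png_bytes = dot.pipe(format="png")  # the same call the app feeds to its download link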
requirements.txt
CHANGED
@@ -1,7 +1,10 @@
+accelerate==0.22.0
 altair==4.2.2
 graphviz==0.20.1
-…
+multi_amr @ git+https://github.com/BramVanroy/multilingual-text-to-amr@v1.0.0-alpha
+optimum==1.10.1
 penman==1.2.2
-streamlit==1.…
-torch==…
-…
+streamlit==1.26.0
+torch==2.0.1
+transformers==4.33.1
+wheel
utils.py
CHANGED
@@ -1,63 +1,101 @@
-from typing import Tuple
+from typing import Tuple, Union, Dict, List
 
+from multi_amr.data.postprocessing_graph import ParsedStatus
+from multi_amr.data.tokenization import AMRTokenizerWrapper
 from optimum.bettertransformer import BetterTransformer
-from mbart_amr.constraints.constraints import AMRLogitsProcessor
-from mbart_amr.data.tokenization import AMRMBartTokenizer
 import streamlit as st
 import torch
 from torch.quantization import quantize_dynamic
 from torch import nn, qint8
-from transformers import MBartForConditionalGeneration
+from transformers import MBartForConditionalGeneration, AutoConfig
+import penman
 
 
 @st.cache_resource(show_spinner=False)
-def get_resources(multilingual: bool, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, …
+def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
     """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the multilingual
     model is requested, or not. If not, an English-only model is loaded. The model can be optionally quantized
     for better performance.
 
-    :param multilingual: whether …
+    :param multilingual: whether to load the multilingual model or not
+    :param src_lang: source language
     :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
     :param no_cuda: whether to disable CUDA, even if it is available
-    :return: the loaded model, …
+    :return: the loaded model, and tokenizer wrapper
     """
-    …
+    model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en_es_nl"
+    if not multilingual:
+        if src_lang == "English":
+            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en"
+        elif src_lang == "Spanish":
+            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-es"
+        elif src_lang == "Dutch":
+            model_name = "BramVanroy/mbart-large-cc25-ft-amr30-nl"
+        else:
+            raise ValueError(f"Language {src_lang} not supported")
+
+    # Tokenizer src_lang is reset during translation to the right language
+    tok_wrapper = AMRTokenizerWrapper.from_pretrained(model_name, src_lang="en_XX")
+
+    config = AutoConfig.from_pretrained(model_name)
+    config.decoder_start_token_id = tok_wrapper.amr_token_id
+
+    model = MBartForConditionalGeneration.from_pretrained(model_name, config=config)
+    model.eval()
+
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tok_wrapper.tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tok_wrapper.tokenizer))
 
     model = BetterTransformer.transform(model, keep_original_model=False)
-    model.resize_token_embeddings(len(tokenizer))
 
     if torch.cuda.is_available() and not no_cuda:
         model = model.to("cuda")
     elif quantize:  # Quantization not supported on CUDA
         model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)
 
-    …
-    return model, tokenizer, logits_processor
+    return model, tok_wrapper
 
 
-…
-def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs) -> str:
+def translate(texts: List[str], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
     """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
     potential keyword-arguments, which can include arguments such as max length, logits processors, etc.
 
-    :param …
+    :param texts: source text to translate (potentially a batch)
     :param src_lang: source language
     :param model: MBART model
-    :param …
+    :param tok_wrapper: MBART tokenizer wrapper
     :param gen_kwargs: potential keyword arguments for the generation process
     :return: the translation (linearized AMR graph)
     """
-    …
+    if isinstance(texts, str):
+        texts = [texts]
+
+    tok_wrapper.src_lang = LANGUAGES[src_lang]
+    encoded = tok_wrapper(texts, return_tensors="pt").to(model.device)
+    with torch.no_grad():
+        generated = model.generate(**encoded, output_scores=True, return_dict_in_generate=True, **gen_kwargs)
+
+    generated["sequences"] = generated["sequences"].cpu()
+    generated["sequences_scores"] = generated["sequences_scores"].cpu()
+    best_scoring_results = {"graph": [], "status": []}
+    beam_size = gen_kwargs["num_beams"]
+
+    # Select the best item from the beam: the sequence with best status and highest score
+    for sample_idx in range(0, len(generated["sequences_scores"]), beam_size):
+        sequences = generated["sequences"][sample_idx: sample_idx + beam_size]
+        scores = generated["sequences_scores"][sample_idx: sample_idx + beam_size].tolist()
+        outputs = tok_wrapper.batch_decode_amr_ids(sequences)
+        statuses = outputs["status"]
+        graphs = outputs["graph"]
+        zipped = zip(statuses, scores, graphs)
+        # Lowest status first (OK=0, FIXED=1, BACKOFF=2), highest score second
+        best = sorted(zipped, key=lambda item: (item[0].value, -item[1]))[0]
+        best_scoring_results["graph"].append(best[2])
+        best_scoring_results["status"].append(best[0])
 
+    # Returns dictionary with "graph" and "status" keys
+    return best_scoring_results
 
 
 LANGUAGES = {
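End to end, the new utils API is used the way app.py uses it. A sketch under two assumptions: the BramVanroy/mbart-large-cc25-ft-amr30-* checkpoints can be downloaded (the first call fetches several GB), and the code runs inside a Streamlit session, since get_resources is wrapped in st.cache_resource:

from utils import get_resources, translate

# Cached across reruns by st.cache_resource.
model, tok_wrapper = get_resources(multilingual=False, src_lang="English")

outputs = translate("The boy wants to go.", "English", model, tok_wrapper,
                    max_new_tokens=512, num_beams=5)

print(outputs["status"][0])  # ParsedStatus.OK, FIXED or BACKOFF
print(outputs["graph"][0])   # penman.Graph for the best beam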
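The most interesting new logic in utils.py is the beam re-ranking in translate: generate returns num_beams candidate sequences per input, and the chosen one has the best parse status first and the highest sequence score second, so a well-formed graph beats a higher-scoring broken one. The sort key in miniature, on toy data (the ParsedStatus values mirror the comment in the diff):

from enum import Enum

class ParsedStatus(Enum):  # mirrors multi_amr's statuses: OK=0, FIXED=1, BACKOFF=2
    OK = 0
    FIXED = 1
    BACKOFF = 2

# One (status, sequence score, graph) tuple per beam candidate; scores are log-probs.
candidates = [
    (ParsedStatus.FIXED, -0.10, "graph-a"),
    (ParsedStatus.OK, -0.35, "graph-b"),
    (ParsedStatus.OK, -0.20, "graph-c"),
]

# Lowest status value first, highest score second -- the same key as in translate().
best = sorted(candidates, key=lambda item: (item[0].value, -item[1]))[0]
print(best[2])  # graph-c: OK beats FIXED, then -0.20 beats -0.35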