Bram Vanroy commited on
Commit
05de9a6
Β·
1 Parent(s): 4d85339

Make stateful

Browse files

This also gets rid of this issue which https://github.com/streamlit/streamlit/issues/6451

Files changed (3) hide show
  1. app.py +47 -65
  2. requirements.txt +7 -4
  3. utils.py +64 -26
app.py CHANGED
@@ -3,10 +3,7 @@ from collections import Counter
3
 
4
  import graphviz
5
  import penman
6
- from mbart_amr.data.linearization import linearized2penmanstr
7
- from penman.models.noop import NoOpModel
8
- import streamlit as st
9
- from transformers import LogitsProcessorList
10
 
11
  from utils import get_resources, LANGUAGES, translate
12
 
@@ -17,43 +14,41 @@ st.set_page_config(
17
  page_icon="πŸ‘©β€πŸ’»"
18
  )
19
 
20
- st.title("πŸ‘©β€πŸ’» Multilingual text to AMR ᡇᡉᡗᡃ")
 
 
 
 
 
 
 
21
 
22
- with st.form("input data"):
23
- text_col, lang_col = st.columns((4, 1))
24
- text = text_col.text_input(label="Input text")
25
- src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0)
26
- submitted = st.form_submit_button("Submit")
 
 
27
 
28
  error_ct = st.empty()
29
- if submitted:
30
- text = text.strip()
31
- if not text:
32
- error_ct.error("Text cannot be empty!", icon="⚠️")
33
- else:
34
  error_ct.info("Generating abstract meaning representation (AMR)...", icon="πŸ’»")
35
- multilingual = src_lang != "English"
36
- model, tokenizer, logitsprocessor = get_resources(multilingual)
37
  gen_kwargs = {
38
- "max_length": model.config.max_length,
39
- "num_beams": model.config.num_beams,
40
- "logits_processor": LogitsProcessorList([logitsprocessor])
41
  }
42
 
43
- linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
44
- penman_str = linearized2penmanstr(linearized)
45
  error_ct.empty()
46
 
47
- try:
48
- graph = penman.decode(penman_str, model=NoOpModel())
49
- except Exception as exc:
50
- st.write(f"The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
51
- f" to a valid graph but note that this is invalid Penman.")
52
- st.code(penman_str)
53
-
54
- with st.expander("Error trace"):
55
- st.write(exc)
56
  else:
 
57
  visualized = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box",
58
  "fontcolor": "white"})
59
 
@@ -74,40 +69,27 @@ if submitted:
74
  def get_node_name(item: str):
75
  return nodenames[item] if item in nodenames else item
76
 
77
- try:
78
- for triple in graph.triples:
79
- if triple[1] == ":instance":
80
- continue
81
- else:
82
- visualized.edge(get_node_name(triple[0]), get_node_name(triple[2]), label=triple[1])
83
- except Exception as exc:
84
- st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
85
- " to a valid graph but note that this is probably invalid Penman.")
86
- st.code(penman_str)
87
- st.write("The initial linearized output of the model was:")
88
- st.code(linearized)
89
-
90
- with st.expander("Error trace"):
91
- st.write(exc)
92
- else:
93
- st.subheader("Graph visualization")
94
- st.graphviz_chart(visualized, use_container_width=True)
95
-
96
- # Download link
97
- def create_download_link(img_bytes: bytes):
98
- encoded = base64.b64encode(img_bytes).decode("utf-8")
99
- return f'<a href="data:image/png;charset=utf-8;base64,{encoded}" download="amr-graph.png">Download graph</a>'
100
-
101
- img = visualized.pipe(format="png")
102
- st.markdown(create_download_link(img), unsafe_allow_html=True)
103
-
104
- # Additional info
105
- st.subheader("Model output and Penman graph")
106
- st.write("The linearized output of the model (after some post-processing) is:")
107
- st.code(linearized)
108
- st.write("When converted into Penman, it looks like this:")
109
- st.code(penman.encode(graph))
110
-
111
 
112
  ########################
113
  # Information, socials #
 
3
 
4
  import graphviz
5
  import penman
6
+ from multi_amr.data.postprocessing_graph import ParsedStatus
 
 
 
7
 
8
  from utils import get_resources, LANGUAGES, translate
9
 
 
14
  page_icon="πŸ‘©β€πŸ’»"
15
  )
16
 
17
+ st.title("πŸ‘©β€πŸ’» Multilingual text to AMR")
18
+
19
+ if "text" not in st.session_state:
20
+ st.session_state["text"] = ""
21
+ if "language" not in st.session_state:
22
+ st.session_state["language"] = "English"
23
+ if "use_multilingual" not in st.session_state:
24
+ st.session_state["use_multilingual"] = False
25
 
26
+ text_col, lang_col = st.columns((4, 1))
27
+ text = text_col.text_input(label="Input text", key="text")
28
+ src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0, key="language")
29
+ multilingual = st.checkbox("Use multilingual model", label_visibility="visible", key="use_multilingual",
30
+ help="Whether to use a single multilingual model that was trained on English, Spanish and"
31
+ " Dutch together, or (if not checked) language-specific models. Enabling this will"
32
+ " results in worse performance but can be of interest for research purposes.")
33
 
34
  error_ct = st.empty()
35
+ if st.session_state["text"]:
36
+ if st.button("Submit"):
37
+ text = text.strip()
 
 
38
  error_ct.info("Generating abstract meaning representation (AMR)...", icon="πŸ’»")
39
+ model, tokenizer = get_resources(multilingual, src_lang)
 
40
  gen_kwargs = {
41
+ "max_new_tokens": 512,
42
+ "num_beams": 5,
 
43
  }
44
 
45
+ outputs = translate(text, src_lang, model, tokenizer, **gen_kwargs)
 
46
  error_ct.empty()
47
 
48
+ if outputs["status"][0] == ParsedStatus.BACKOFF:
49
+ st.write(f"The system could not generate a valid graph no matter how hard it tried.")
 
 
 
 
 
 
 
50
  else:
51
+ graph = outputs["graph"][0]
52
  visualized = graphviz.Digraph(node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box",
53
  "fontcolor": "white"})
54
 
 
69
  def get_node_name(item: str):
70
  return nodenames[item] if item in nodenames else item
71
 
72
+ for triple in graph.triples:
73
+ if triple[1] == ":instance":
74
+ continue
75
+ else:
76
+ visualized.edge(get_node_name(triple[0]), get_node_name(triple[2]), label=triple[1])
77
+ st.subheader("Graph visualization")
78
+ st.graphviz_chart(visualized, use_container_width=True)
79
+
80
+ # Download link
81
+ def create_download_link(img_bytes: bytes):
82
+ encoded = base64.b64encode(img_bytes).decode("utf-8")
83
+ return f'<a href="data:image/png;charset=utf-8;base64,{encoded}" download="amr-graph.png">Download graph</a>'
84
+
85
+ img = visualized.pipe(format="png")
86
+ st.markdown(create_download_link(img), unsafe_allow_html=True)
87
+
88
+ # Additional info
89
+ st.subheader("PENMAN representation")
90
+ st.code(penman.encode(graph))
91
+ else:
92
+ error_ct.warning("Text cannot be empty!", icon="⚠️")
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  ########################
95
  # Information, socials #
requirements.txt CHANGED
@@ -1,7 +1,10 @@
 
1
  altair==4.2.2
2
  graphviz==0.20.1
3
- optimum==1.7.1
 
4
  penman==1.2.2
5
- streamlit==1.19.0
6
- torch==1.13.1
7
- git+https://github.com/BramVanroy/multilingual-text-to-amr@5859af0d870acd2f76d71e5a7d12fa35a7a2059b#egg=mbart-amr
 
 
1
+ accelerate==0.22.0
2
  altair==4.2.2
3
  graphviz==0.20.1
4
+ multi_amr @ git+https://github.com/BramVanroy/multilingual-text-to-amr@v1.0.0-alpha
5
+ optimum==1.10.1
6
  penman==1.2.2
7
+ streamlit==1.26.0
8
+ torch==2.0.1
9
+ transformers==4.33.1
10
+ wheel
utils.py CHANGED
@@ -1,63 +1,101 @@
1
- from typing import Tuple
2
 
 
 
3
  from optimum.bettertransformer import BetterTransformer
4
- from mbart_amr.constraints.constraints import AMRLogitsProcessor
5
- from mbart_amr.data.tokenization import AMRMBartTokenizer
6
  import streamlit as st
7
  import torch
8
  from torch.quantization import quantize_dynamic
9
  from torch import nn, qint8
10
- from transformers import MBartForConditionalGeneration
 
11
 
12
 
13
  @st.cache_resource(show_spinner=False)
14
- def get_resources(multilingual: bool, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRMBartTokenizer, AMRLogitsProcessor]:
15
  """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the multilingual
16
  model is requested, or not. If not, an English-only model is loaded. The model can be optionally quantized
17
  for better performance.
18
 
19
- :param multilingual: whether or not to load the multilingual model. If not, loads the English-only model
 
20
  :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
21
  :param no_cuda: whether to disable CUDA, even if it is available
22
- :return: the loaded model, tokenizer, and logits processor
23
  """
24
- if multilingual:
25
- # Tokenizer src_lang is reset during translation to the right language
26
- tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr", src_lang="nl_XX")
27
- model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-es-nl-to-amr")
28
- else:
29
- tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
30
- model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  model = BetterTransformer.transform(model, keep_original_model=False)
33
- model.resize_token_embeddings(len(tokenizer))
34
 
35
  if torch.cuda.is_available() and not no_cuda:
36
  model = model.to("cuda")
37
  elif quantize: # Quantization not supported on CUDA
38
  model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)
39
 
40
- logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)
41
 
42
- return model, tokenizer, logits_processor
43
 
44
-
45
- def translate(text: str, src_lang: str, model: MBartForConditionalGeneration, tokenizer: AMRMBartTokenizer, **gen_kwargs) -> str:
46
  """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
47
  potential keyword-arguments, which can include arguments such as max length, logits processors, etc.
48
 
49
- :param text: source text to translate
50
  :param src_lang: source language
51
  :param model: MBART model
52
- :param tokenizer: MBART tokenizer
53
  :param gen_kwargs: potential keyword arguments for the generation process
54
  :return: the translation (linearized AMR graph)
55
  """
56
- tokenizer.src_lang = LANGUAGES[src_lang]
57
- encoded = tokenizer(text, return_tensors="pt")
58
- encoded = {k: v.to(model.device) for k, v in encoded.items()}
59
- generated = model.generate(**encoded, **gen_kwargs).cpu()
60
- return tokenizer.decode_and_fix(generated)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  LANGUAGES = {
 
1
+ from typing import Tuple, Union, Dict, List
2
 
3
+ from multi_amr.data.postprocessing_graph import ParsedStatus
4
+ from multi_amr.data.tokenization import AMRTokenizerWrapper
5
  from optimum.bettertransformer import BetterTransformer
 
 
6
  import streamlit as st
7
  import torch
8
  from torch.quantization import quantize_dynamic
9
  from torch import nn, qint8
10
+ from transformers import MBartForConditionalGeneration, AutoConfig
11
+ import penman
12
 
13
 
14
  @st.cache_resource(show_spinner=False)
15
+ def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
16
  """Get the relevant model, tokenizer and logits_processor. The loaded model depends on whether the multilingual
17
  model is requested, or not. If not, an English-only model is loaded. The model can be optionally quantized
18
  for better performance.
19
 
20
+ :param multilingual: whether to load the multilingual model or not
21
+ :param src_lang: source language
22
  :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
23
  :param no_cuda: whether to disable CUDA, even if it is available
24
+ :return: the loaded model, and tokenizer wrapper
25
  """
26
+ model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en_es_nl"
27
+ if not multilingual:
28
+ if src_lang == "English":
29
+ model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en"
30
+ elif src_lang == "Spanish":
31
+ model_name = "BramVanroy/mbart-large-cc25-ft-amr30-es"
32
+ elif src_lang == "Dutch":
33
+ model_name = "BramVanroy/mbart-large-cc25-ft-amr30-nl"
34
+ else:
35
+ raise ValueError(f"Language {src_lang} not supported")
36
+
37
+ # Tokenizer src_lang is reset during translation to the right language
38
+ tok_wrapper = AMRTokenizerWrapper.from_pretrained(model_name, src_lang="en_XX")
39
+
40
+ config = AutoConfig.from_pretrained(model_name)
41
+ config.decoder_start_token_id = tok_wrapper.amr_token_id
42
+
43
+ model = MBartForConditionalGeneration.from_pretrained(model_name, config=config)
44
+ model.eval()
45
+
46
+ embedding_size = model.get_input_embeddings().weight.shape[0]
47
+ if len(tok_wrapper.tokenizer) > embedding_size:
48
+ model.resize_token_embeddings(len(tok_wrapper.tokenizer))
49
 
50
  model = BetterTransformer.transform(model, keep_original_model=False)
 
51
 
52
  if torch.cuda.is_available() and not no_cuda:
53
  model = model.to("cuda")
54
  elif quantize: # Quantization not supported on CUDA
55
  model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)
56
 
57
+ return model, tok_wrapper
58
 
 
59
 
60
+ def translate(texts: List[str], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
 
61
  """Translates a given text of a given source language with a given model and tokenizer. The generation is guided by
62
  potential keyword-arguments, which can include arguments such as max length, logits processors, etc.
63
 
64
+ :param texts: source text to translate (potentially a batch)
65
  :param src_lang: source language
66
  :param model: MBART model
67
+ :param tok_wrapper: MBART tokenizer wrapper
68
  :param gen_kwargs: potential keyword arguments for the generation process
69
  :return: the translation (linearized AMR graph)
70
  """
71
+ if isinstance(texts, str):
72
+ texts = [texts]
73
+
74
+ tok_wrapper.src_lang = LANGUAGES[src_lang]
75
+ encoded = tok_wrapper(texts, return_tensors="pt").to(model.device)
76
+ with torch.no_grad():
77
+ generated = model.generate(**encoded, output_scores=True, return_dict_in_generate=True, **gen_kwargs)
78
+
79
+ generated["sequences"] = generated["sequences"].cpu()
80
+ generated["sequences_scores"] = generated["sequences_scores"].cpu()
81
+ best_scoring_results = {"graph": [], "status": []}
82
+ beam_size = gen_kwargs["num_beams"]
83
+
84
+ # Select the best item from the beam: the sequence with best status and highest score
85
+ for sample_idx in range(0, len(generated["sequences_scores"]), beam_size):
86
+ sequences = generated["sequences"][sample_idx: sample_idx + beam_size]
87
+ scores = generated["sequences_scores"][sample_idx: sample_idx + beam_size].tolist()
88
+ outputs = tok_wrapper.batch_decode_amr_ids(sequences)
89
+ statuses = outputs["status"]
90
+ graphs = outputs["graph"]
91
+ zipped = zip(statuses, scores, graphs)
92
+ # Lowest status first (OK=0, FIXED=1, BACKOFF=2), highest score second
93
+ best = sorted(zipped, key=lambda item: (item[0].value, -item[1]))[0]
94
+ best_scoring_results["graph"].append(best[2])
95
+ best_scoring_results["status"].append(best[0])
96
+
97
+ # Returns dictionary with "graph" and "status" keys
98
+ return best_scoring_results
99
 
100
 
101
  LANGUAGES = {