yhavinga committed on
Commit 0ed2b71
• 1 Parent(s): 1a85226

Adapt to streaming interface (only when num_beams is equal to 1)

Files changed (3)
  1. app.py +30 -3
  2. generator.py +14 -12
  3. requirements.txt +1 -1
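
The commit wires Hugging Face's generation streamers into the Streamlit app: a `TextIteratorStreamer` subclass pushes each finalized chunk of generated text into a `st.empty()` placeholder while `model.generate` runs, so output renders incrementally. Streaming is gated on `num_beams == 1` because `generate` rejects a `streamer` under beam search. A minimal sketch of the upstream streaming interface, independent of this app (the `t5-small` model id and the prompt are placeholders):

```python
# Minimal sketch of the transformers streaming interface, assuming a build
# that ships TextIteratorStreamer (hence the pinned git revision in
# requirements.txt). Model id and prompt are placeholders.
from threading import Thread

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
inputs = tokenizer("translate English to German: Hello world", return_tensors="pt")

# generate() blocks, so it runs in a worker thread while the main thread
# drains the streamer chunk by chunk; streaming requires num_beams == 1.
thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer})
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```

The app's subclass below skips the consumer thread entirely: `generate` runs synchronously, and `on_finalized_text` writes each update straight into the Streamlit placeholder as it arrives.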
app.py CHANGED

```diff
@@ -4,6 +4,7 @@ import psutil
 import streamlit as st
 import torch
 from langdetect import detect
+from transformers import TextIteratorStreamer
 
 from default_texts import default_texts
 from generator import GeneratorFactory
@@ -60,6 +61,20 @@ GENERATOR_LIST = [
 ]
 
 
+class StreamlitTextIteratorStreamer(TextIteratorStreamer):
+    def __init__(
+        self, output_placeholder, tokenizer, skip_prompt=False, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.output_placeholder = output_placeholder
+        self.output_text = ""
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        self.output_text += text
+        self.output_placeholder.markdown(self.output_text, unsafe_allow_html=True)
+        super().on_finalized_text(text, stream_end)
+
+
 def main():
     st.set_page_config(  # Alternate names: setup_page, page, layout
         page_title="Rosetta en/nl",  # String or None. Strings get appended with "• Streamlit".
@@ -132,16 +147,28 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
         left.error("Num beams should be a multiple of num beam groups")
         return
 
+    streaming_enabled = num_beams == 1
+    if not streaming_enabled:
+        left.markdown("*`num_beams > 1` so streaming is disabled*")
+
     for generator in generators.filter(task=task):
-        right.markdown(f"🧮 **Model `{generator}`**")
+        model_container = right.container()
+        model_container.markdown(f"🧮 **Model `{generator}`**")
+        output_placeholder = model_container.empty()
+        streamer = (
+            StreamlitTextIteratorStreamer(output_placeholder, generator.tokenizer)
+            if streaming_enabled
+            else None
+        )
         time_start = time.time()
         result, params_used = generator.generate(
-            text=st.session_state.text, **params
+            text=st.session_state.text, streamer=streamer, **params
         )
         time_end = time.time()
         time_diff = time_end - time_start
 
-        right.write(result.replace("\n", "  \n"))
+        if not streaming_enabled:
+            right.write(result.replace("\n", "  \n"))
         text_line = ", ".join([f"{k}={v}" for k, v in params_used.items()])
         right.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")
```
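The Streamlit side relies on `st.empty()`, a placeholder slot whose contents can be overwritten in place; each `on_finalized_text` call redraws the same slot with the accumulated text. A standalone sketch of that mechanism (the demo text and delay are made up):

```python
# Sketch of in-place redrawing with st.empty(); run with `streamlit run demo.py`.
import time

import streamlit as st

placeholder = st.empty()  # a slot that can be redrawn in place
text = ""
for word in "tokens appear one chunk at a time".split():
    text += word + " "
    placeholder.markdown(text)  # each call replaces the previous render
    time.sleep(0.2)
```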
generator.py CHANGED

```diff
@@ -20,7 +20,7 @@ def get_access_token():
 
 @st.cache(suppress_st_warning=True, allow_output_mutation=True)
 def load_model(model_name):
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
     tokenizer = AutoTokenizer.from_pretrained(
         model_name,
         from_flax=True,
@@ -30,19 +30,18 @@ def load_model(model_name):
     if tokenizer.pad_token is None:
         print("Adding pad_token to the tokenizer")
         tokenizer.pad_token = tokenizer.eos_token
-    try:
-        model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_name, use_auth_token=get_access_token()
-        )
-    except EnvironmentError:
+    for framework in [None, "flax", "tf"]:
         try:
             model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_name, from_flax=True, use_auth_token=get_access_token()
+                model_name,
+                from_flax=(framework == "flax"),
+                from_tf=(framework == "tf"),
+                use_auth_token=get_access_token(),
             )
+            break
         except EnvironmentError:
-            model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_name, from_tf=True, use_auth_token=get_access_token()
-            )
+            if framework == "tf":
+                raise
     if device != -1:
         model.to(f"cuda:{device}")
     return tokenizer, model
@@ -89,7 +88,7 @@ class Generator:
         except TypeError:
             pass
 
-    def generate(self, text: str, **generate_kwargs) -> (str, dict):
+    def generate(self, text: str, streamer=None, **generate_kwargs) -> (str, dict):
         # Replace two or more newlines with a single newline in text
         text = re.sub(r"\n{2,}", "\n", text)
 
@@ -98,7 +97,9 @@ class Generator:
         # if there are newlines in the text, and the model needs line-splitting, split the text and recurse
         if re.search(r"\n", text) and self.split_sentences:
             lines = text.splitlines()
-            translated = [self.generate(line, **generate_kwargs)[0] for line in lines]
+            translated = [
+                self.generate(line, streamer, **generate_kwargs)[0] for line in lines
+            ]
             return "\n".join(translated), generate_kwargs
 
         # if self.tokenizer has a newline_token attribute, replace \n with it
@@ -117,6 +118,7 @@ class Generator:
         logits = self.model.generate(
             batch_encoded["input_ids"],
             attention_mask=batch_encoded["attention_mask"],
+            streamer=streamer,
             **generate_kwargs,
         )
         decoded_preds = self.tokenizer.batch_decode(
```
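The `load_model` rewrite collapses the nested try/except into one fallback loop: native PyTorch weights first, then Flax, then TensorFlow, re-raising only when the last framework also fails. The same idea as a standalone helper (the function name and `token` parameter are illustrative):

```python
from transformers import AutoModelForSeq2SeqLM


def load_seq2seq(model_name: str, token=None):
    """Try PyTorch weights first, then Flax, then TF; raise if all fail."""
    for framework in [None, "flax", "tf"]:
        try:
            return AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                from_flax=(framework == "flax"),
                from_tf=(framework == "tf"),
                use_auth_token=token,
            )
        except EnvironmentError:
            if framework == "tf":  # no checkpoint found in any framework
                raise
```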
requirements.txt CHANGED

```diff
@@ -4,7 +4,7 @@
 protobuf<3.20
 streamlit>=1.4.0,<=1.10.0
 torch
-transformers>=4.13.0
+git+https://github.com/huggingface/transformers.git@1905384fd576acf4b645a8216907f980b4788d9b
 langdetect
 psutil
 jax[cuda]==0.3.16
```
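
The `transformers` requirement moves from a release range to a pinned git revision, presumably because the streamer classes had not yet appeared in a tagged release when this commit was made. A quick import check verifies that the installed build exposes them:

```python
# Sanity check: the pinned transformers revision must expose the streamer API.
import transformers

print(transformers.__version__)
from transformers import TextIteratorStreamer  # ImportError on older releases
```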