yhavinga committed

Commit cd377f8 · 1 Parent(s): f839da7

Add streaming, disable translation


* Also upgrade transformers, add sentencepiece

Files changed (2):
  1. app.py +110 -81
  2. requirements.txt +8 -5
app.py CHANGED

@@ -10,6 +10,7 @@ from transformers import (
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
+    TextIteratorStreamer,
     pipeline,
     set_seed,
 )
@@ -41,6 +42,20 @@ def load_model(model_name, task):
     return tokenizer, model


+class StreamlitTextIteratorStreamer(TextIteratorStreamer):
+    def __init__(
+        self, output_placeholder, tokenizer, skip_prompt=False, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.output_placeholder = output_placeholder
+        self.output_text = ""
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        self.output_text += text
+        self.output_placeholder.markdown(self.output_text, unsafe_allow_html=True)
+        super().on_finalized_text(text, stream_end)
+
+
 class Generator:
     def __init__(self, model_name, task, desc):
         self.model_name = model_name
@@ -52,18 +67,38 @@ class Generator:
         self.load()

     def load(self):
-        if not self.pipeline:
+        if not self.model:
             print(f"Loading model {self.model_name}")
             self.tokenizer, self.model = load_model(self.model_name, self.task)
-            self.pipeline = pipeline(
-                task=self.task,
-                model=self.model,
-                tokenizer=self.tokenizer,
-                device=device,
-            )

-    def get_text(self, text: str, **generate_kwargs) -> str:
-        return self.pipeline(text, **generate_kwargs)
+    def generate(self, text: str, streamer=None, **generate_kwargs) -> (str, dict):
+        batch_encoded = self.tokenizer(
+            text,
+            max_length=generate_kwargs["max_length"],
+            padding=False,
+            truncation=False,
+            return_tensors="pt",
+        )
+        if device != -1:
+            batch_encoded.to(f"cuda:{device}")
+        logits = self.model.generate(
+            batch_encoded["input_ids"],
+            attention_mask=batch_encoded["attention_mask"],
+            streamer=streamer,
+            **generate_kwargs,
+        )
+        decoded_preds = self.tokenizer.batch_decode(
+            logits.cpu().numpy(), skip_special_tokens=False
+        )
+
+        def replace_tokens(pred):
+            pred = pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
+            if hasattr(self.tokenizer, "newline_token"):
+                pred = pred.replace(self.tokenizer.newline_token, "\n")
+            return pred
+
+        decoded_preds = list(map(replace_tokens, decoded_preds))
+        return decoded_preds[0], generate_kwargs


 class GeneratorFactory:
@@ -82,11 +117,11 @@ class GeneratorFactory:
             "desc": "GPT2 Medium Dutch (book finetune)",
             "task": "text-generation",
         },
-        {
-            "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
-            "desc": "Dutch<->English T5 small 24 layers",
-            "task": TRANSLATION_NL_TO_EN,
-        },
+        # {
+        #     "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
+        #     "desc": "Dutch<->English T5 small 24 layers",
+        #     "task": TRANSLATION_NL_TO_EN,
+        # },
     ]
     for g in GENERATOR_LIST:
         with st.spinner(text=f"Loading the model {g['desc']} ..."):
@@ -148,12 +183,13 @@ def main():
     repetition_penalty = st.sidebar.number_input(
         "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
     )
-    num_return_sequences = st.sidebar.number_input(
-        "Num return sequences", min_value=1, max_value=5, value=1
-    )
+    num_return_sequences = 1
+    # st.sidebar.number_input(
+    #     "Num return sequences", min_value=1, max_value=5, value=1
+    # )
     seed_placeholder = st.sidebar.empty()
     if "seed" not in st.session_state:
-        print(f"Session state {st.session_state} does not contain seed")
+        print(f"Session state does not contain seed")
        st.session_state["seed"] = 4162549114
     print(f"Seed is set to: {st.session_state['seed']}")

@@ -218,69 +254,62 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
     )

     if st.button("Run"):
-        estimate = max_length / 18
-        if device == -1:
-            ## cpu
-            estimate = estimate * (1 + 0.7 * (num_return_sequences - 1))
-            if sampling_mode == "Beam Search":
-                estimate = estimate * (1.1 + 0.3 * (num_beams - 1))
-        else:
-            ## gpu
-            estimate = estimate * (1 + 0.1 * (num_return_sequences - 1))
-            estimate = 0.5 + estimate / 5
-            if sampling_mode == "Beam Search":
-                estimate = estimate * (1.0 + 0.1 * (num_beams - 1))
-        estimate = int(estimate)
-
-        with st.spinner(
-            text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
-        ):
-            memory = psutil.virtual_memory()
-            generator = generators.get_generator(desc=model_desc)
-            set_seed(seed)
-            time_start = time.time()
-            result = generator.get_text(text=st.session_state.text, **params)
-            time_end = time.time()
-            time_diff = time_end - time_start
-            st.subheader("Result")
-
-            for text in result:
-                st.write(text.get("generated_text").replace("\n", " \n"))
-                st.text("*Translation*")
-                translate_params = {
-                    "num_return_sequences": 1,
-                    "num_beams": 4,
-                    "early_stopping": True,
-                    "length_penalty": 1.1,
-                    "max_length": 200,
-                }
-                text_lines = [
-                    "translate Dutch to English: " + t
-                    for t in text.get("generated_text").splitlines()
-                ]
-                translated_lines = [
-                    t["translation_text"]
-                    for t in generators.get_generator(
-                        task=TRANSLATION_NL_TO_EN
-                    ).get_text(text_lines, **translate_params)
-                ]
-                translation = " \n".join(translated_lines)
-                st.write(translation)
-                st.write("---")
-            #
-            info = f"""
-            ---
-            *Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*
-            *Text generated using seed {seed} in {time_diff:.5} seconds*
-            """
-            st.write(info)
-
-            params["seed"] = seed
-            params["prompt"] = st.session_state.text
-            params["model"] = generator.model_name
-            params_text = json.dumps(params)
-            print(params_text)
-            st.json(params_text)
+        memory = psutil.virtual_memory()
+        st.subheader("Result")
+        container = st.container()
+        output_placeholder = container.empty()
+        streaming_enabled = True  # sampling_mode != "Beam Search" or num_beams == 1
+        generator = generators.get_generator(desc=model_desc)
+        streamer = (
+            StreamlitTextIteratorStreamer(output_placeholder, generator.tokenizer)
+            if streaming_enabled
+            else None
+        )
+        set_seed(seed)
+        time_start = time.time()
+        result = generator.generate(
+            text=st.session_state.text, streamer=streamer, **params
+        )
+        time_end = time.time()
+        time_diff = time_end - time_start
+
+        # for text in result:
+        #     st.write(text.get("generated_text").replace("\n", " \n"))
+        #     st.text("*Translation*")
+        #     translate_params = {
+        #         "num_return_sequences": 1,
+        #         "num_beams": 4,
+        #         "early_stopping": True,
+        #         "length_penalty": 1.1,
+        #         "max_length": 200,
+        #     }
+        #     text_lines = [
+        #         "translate Dutch to English: " + t
+        #         for t in text.get("generated_text").splitlines()
+        #     ]
+        #     translated_lines = [
+        #         t["translation_text"]
+        #         for t in generators.get_generator(
+        #             task=TRANSLATION_NL_TO_EN
+        #         ).get_text(text_lines, **translate_params)
+        #     ]
+        #     translation = " \n".join(translated_lines)
+        #     st.write(translation)
+        #     st.write("---")
+        #
+        info = f"""
+        ---
+        *Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
+        *Text generated using seed {seed} in {time_diff:.5} seconds*
+        """
+        st.write(info)
+
+        params["seed"] = seed
+        params["prompt"] = st.session_state.text
+        params["model"] = generator.model_name
+        params_text = json.dumps(params)
+        # print(params_text)
+        st.json(params_text)


 if __name__ == "__main__":
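
Note on the streaming change: the new StreamlitTextIteratorStreamer overrides on_finalized_text(), which model.generate() calls synchronously as chunks of text are decoded, so the Streamlit placeholder updates while generation is still running and no background thread is needed. A minimal self-contained sketch of the same pattern (the PlaceholderStreamer name, checkpoint, and prompt are illustrative, not from this repo):

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


class PlaceholderStreamer(TextIteratorStreamer):
    """Push every finalized text chunk into a Streamlit placeholder."""

    def __init__(self, placeholder, tokenizer, skip_prompt=False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.placeholder = placeholder
        self.text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Called by generate() each time a chunk is finalized, so the
        # placeholder re-renders mid-generation.
        self.text += text
        self.placeholder.markdown(self.text)
        super().on_finalized_text(text, stream_end)


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Het was een koude dag in", return_tensors="pt")
streamer = PlaceholderStreamer(st.empty(), tokenizer)
model.generate(**inputs, streamer=streamer, max_new_tokens=40)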
requirements.txt CHANGED

@@ -1,7 +1,10 @@
--f https://download.pytorch.org/whl/torch_stable.html
-streamlit==1.4.0
-torch==1.6.0+cpu
-torchvision==0.7.0+cpu
-transformers>=4.13.0
+#-f https://download.pytorch.org/whl/torch_stable.html
+-f https://download.pytorch.org/whl/cu116
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+protobuf<3.20
+streamlit>=1.4.0,<=1.10.0
+torch
+git+https://github.com/huggingface/transformers.git@1905384fd576acf4b645a8216907f980b4788d9b
 mtranslate
 psutil
+sentencepiece
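
transformers is pinned to a specific main-branch commit rather than a release, presumably because the streamer keyword on generate() that app.py now relies on had not yet shipped in a stable release. A hedged sanity check (not part of the repo) for whether an installed build exposes that hook:

import inspect

# Both imports fail on builds that predate the streaming API.
from transformers import GenerationMixin, TextIteratorStreamer

# app.py passes its StreamlitTextIteratorStreamer through this keyword.
assert "streamer" in inspect.signature(GenerationMixin.generate).parameters
print("transformers build supports streaming")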