Yeb Havinga committed
Commit f839da7 · 1 Parent(s): 5da87aa

Add translation. Keep GeneratorFactory in the session cache.

Files changed (1):
  1. app.py +71 -49
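The second half of the commit message refers to Streamlit's per-session cache: `st.session_state` survives the script reruns Streamlit triggers on every widget interaction, so an object stored there is built once per browser session instead of on every rerun. A minimal sketch of the pattern this commit applies (the `ExpensiveResource` name is illustrative, not from this repo):

```python
import streamlit as st


class ExpensiveResource:
    """Stand-in for something costly to construct, e.g. a factory of loaded models."""

    def __init__(self):
        self.ready = True


# Streamlit reruns the whole script on each interaction; session_state
# persists across those reruns, so construction happens only once.
if "resource" not in st.session_state:
    st.session_state["resource"] = ExpensiveResource()

resource = st.session_state["resource"]
```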
app.py CHANGED
@@ -6,11 +6,18 @@ from random import randint
 import psutil
 import streamlit as st
 import torch
-from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
-                          AutoTokenizer, pipeline, set_seed)
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    pipeline,
+    set_seed,
+)
 
 device = torch.cuda.device_count() - 1
 
+TRANSLATION_NL_TO_EN = "translation_nl_to_en"
+
 
 @st.cache(suppress_st_warning=True, allow_output_mutation=True)
 def load_model(model_name, task):
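A note on `device = torch.cuda.device_count() - 1`: transformers pipelines interpret `device=-1` as CPU and a non-negative integer as a CUDA device index, so this expression selects the last GPU when one is available and falls back to CPU otherwise. A short illustration (the checkpoint name is taken from the list below; passing it straight to `pipeline` is a sketch, not necessarily how `load_model` wires it up):

```python
import torch
from transformers import pipeline

# -1 => CPU; 0, 1, ... => CUDA device index.
device = torch.cuda.device_count() - 1

generator = pipeline(
    "text-generation",
    model="yhavinga/gpt-neo-125M-dutch-nedd",
    device=device,
)
print(generator("Het was een koude winterdag", max_length=30)[0]["generated_text"])
```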
@@ -63,43 +70,45 @@ class GeneratorFactory:
     def __init__(self):
         self.generators = []
 
+    def instantiate_generators(self):
+        GENERATOR_LIST = [
+            {
+                "model_name": "yhavinga/gpt-neo-125M-dutch-nedd",
+                "desc": "GPT-Neo Small Dutch (book finetune)",
+                "task": "text-generation",
+            },
+            {
+                "model_name": "yhavinga/gpt2-medium-dutch-nedd",
+                "desc": "GPT2 Medium Dutch (book finetune)",
+                "task": "text-generation",
+            },
+            {
+                "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
+                "desc": "Dutch<->English T5 small 24 layers",
+                "task": TRANSLATION_NL_TO_EN,
+            },
+        ]
+        for g in GENERATOR_LIST:
+            with st.spinner(text=f"Loading the model {g['desc']} ..."):
+                self.add_generator(**g)
+
+        return self
+
     def add_generator(self, model_name, task, desc):
-        g = Generator(model_name, task, desc)
-        g.load()
-        self.generators.append(g)
+        # If the generator is not yet present, add it
+        if not self.get_generator(model_name=model_name, task=task, desc=desc):
+            g = Generator(model_name, task, desc)
+            g.load()
+            self.generators.append(g)
 
-    def get_generator(self, model_desc):
+    def get_generator(self, **kwargs):
         for g in self.generators:
-            if g.desc == model_desc:
+            if all([g.__dict__.get(k) == v for k, v in kwargs.items()]):
                 return g
         return None
 
-
-GENERATORS = [
-    {
-        "model_name": "yhavinga/gpt-neo-125M-dutch-nedd",
-        "desc": "GPT-Neo Small Dutch (book finetune)",
-        "task": "text-generation",
-    },
-    {
-        "model_name": "yhavinga/gpt2-medium-dutch-nedd",
-        "desc": "GPT2 Medium Dutch (book finetune)",
-        "task": "text-generation",
-    },
-    {
-        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
-        "desc": "Dutch<->English T5 small 24 layers",
-        "task": "translation_nl_to_en",
-    },
-]
-
-generators = GeneratorFactory()
-
-
-def instantiate_generators():
-    for g in GENERATORS:
-        with st.spinner(text=f"Loading the model {g['desc']} ..."):
-            generators.add_generator(**g)
+    def gpt_descs(self):
+        return [g.desc for g in self.generators if g.task == "text-generation"]
 
 
 def main():
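The reworked `get_generator` matches on arbitrary attributes via `g.__dict__`, so callers can look a generator up by `desc`, by `task`, or any combination, and `add_generator` reuses it as an existence check. A self-contained sketch of the lookup semantics, with a simplified stand-in for `Generator`:

```python
class FakeGenerator:
    """Illustrative stand-in; the real Generator also loads a model."""

    def __init__(self, model_name, task, desc):
        self.model_name = model_name
        self.task = task
        self.desc = desc


def get_generator(generators, **kwargs):
    # Return the first generator whose attributes equal all given kwargs.
    for g in generators:
        if all(g.__dict__.get(k) == v for k, v in kwargs.items()):
            return g
    return None


gens = [
    FakeGenerator("m1", "text-generation", "GPT2 Medium Dutch (book finetune)"),
    FakeGenerator("m2", "translation_nl_to_en", "Dutch<->English T5 small 24 layers"),
]
assert get_generator(gens, task="translation_nl_to_en") is gens[1]
assert get_generator(gens, desc="GPT2 Medium Dutch (book finetune)") is gens[0]
assert get_generator(gens, task="missing") is None  # no match -> None
```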
@@ -109,7 +118,11 @@ def main():
         initial_sidebar_state="expanded",  # Can be "auto", "expanded", "collapsed"
         page_icon="📚",  # String, anything supported by st.image, or None.
     )
-    instantiate_generators()
+
+    if "generators" not in st.session_state:
+        st.session_state["generators"] = GeneratorFactory().instantiate_generators()
+
+    generators = st.session_state["generators"]
 
     with open("style.css") as f:
         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
@@ -119,18 +132,11 @@ def main():
         """# Netherator
     Nederlandse verhalenverteller"""
     )
-
-    model_desc = st.sidebar.selectbox(
-        "Model", [p["desc"] for p in GENERATORS if "generation" in p["task"]], index=1
-    )
-
+    model_desc = st.sidebar.selectbox("Model", generators.gpt_descs(), index=1)
     st.sidebar.title("Parameters:")
-
     if "prompt_box" not in st.session_state:
         st.session_state["prompt_box"] = "Het was een koude winterdag"
-
     st.session_state["text"] = st.text_area("Enter text", st.session_state.prompt_box)
-
     max_length = st.sidebar.number_input(
         "Lengte van de tekst",
         value=200,
@@ -145,7 +151,6 @@ def main():
     num_return_sequences = st.sidebar.number_input(
         "Num return sequences", min_value=1, max_value=5, value=1
     )
-
    seed_placeholder = st.sidebar.empty()
    if "seed" not in st.session_state:
        print(f"Session state {st.session_state} does not contain seed")
@@ -231,20 +236,37 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
             text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
         ):
             memory = psutil.virtual_memory()
-            generator = generators.get_generator(model_desc)
+            generator = generators.get_generator(desc=model_desc)
             set_seed(seed)
             time_start = time.time()
             result = generator.get_text(text=st.session_state.text, **params)
             time_end = time.time()
             time_diff = time_end - time_start
-
             st.subheader("Result")
+
             for text in result:
                 st.write(text.get("generated_text").replace("\n", " \n"))
-
-            # st.text("*Translation*")
-            # translation = translate(result, "en", "nl")
-            # st.write(translation.replace("\n", " \n"))
+                st.text("*Translation*")
+                translate_params = {
+                    "num_return_sequences": 1,
+                    "num_beams": 4,
+                    "early_stopping": True,
+                    "length_penalty": 1.1,
+                    "max_length": 200,
+                }
+                text_lines = [
+                    "translate Dutch to English: " + t
+                    for t in text.get("generated_text").splitlines()
+                ]
+                translated_lines = [
+                    t["translation_text"]
+                    for t in generators.get_generator(
+                        task=TRANSLATION_NL_TO_EN
+                    ).get_text(text_lines, **translate_params)
+                ]
+                translation = " \n".join(translated_lines)
+                st.write(translation)
+                st.write("---")
             #
             info = f"""
 ---
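The new translation step prefixes every generated line with a T5 task prompt, translates the lines in one batched call with beam search, and rejoins the results for display. A standalone sketch of that flow against a raw transformers pipeline (constructing the pipeline directly is an assumption; in the app this goes through the cached `Generator` wrapper's `get_text`):

```python
from transformers import pipeline

# Assumption: direct pipeline construction; the app loads this model once
# via its GeneratorFactory and calls it through Generator.get_text().
translator = pipeline(
    "translation_nl_to_en",
    model="yhavinga/t5-small-24L-ccmatrix-multi",
    device=-1,  # CPU
)

generated_text = "Het was een koude winterdag.\nDe sneeuw viel zacht neer."

# T5 checkpoints expect an explicit task prefix on each input line.
text_lines = ["translate Dutch to English: " + t for t in generated_text.splitlines()]

outputs = translator(
    text_lines,
    num_beams=4,
    early_stopping=True,
    length_penalty=1.1,
    max_length=200,
    num_return_sequences=1,
)
# Rejoin per-line translations, mirroring the app's line-break handling.
translation = " \n".join(o["translation_text"] for o in outputs)
print(translation)
```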
 