Yeb Havinga committed
Commit 5da87aa · 1 Parent(s): 5cf4ee2

Refactor model+task code using a factory. Run black

Files changed (1)
  1. app.py +48 -38
app.py CHANGED
@@ -6,13 +6,8 @@ from random import randint
 import psutil
 import streamlit as st
 import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
-    pipeline,
-    set_seed,
-)
+from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
+                          AutoTokenizer, pipeline, set_seed)
 
 device = torch.cuda.device_count() - 1
 
@@ -39,11 +34,11 @@ def load_model(model_name, task):
     return tokenizer, model
 
 
-class ModelTask:
-    def __init__(self, p):
-        self.model_name = p["model_name"]
-        self.task = p["task"]
-        self.desc = p["desc"]
+class Generator:
+    def __init__(self, model_name, task, desc):
+        self.model_name = model_name
+        self.task = task
+        self.desc = desc
         self.tokenizer = None
         self.model = None
         self.pipeline = None
@@ -64,27 +59,47 @@ class ModelTask:
         return self.pipeline(text, **generate_kwargs)
 
 
-PIPELINES = [
+class GeneratorFactory:
+    def __init__(self):
+        self.generators = []
+
+    def add_generator(self, model_name, task, desc):
+        g = Generator(model_name, task, desc)
+        g.load()
+        self.generators.append(g)
+
+    def get_generator(self, model_desc):
+        for g in self.generators:
+            if g.desc == model_desc:
+                return g
+        return None
+
+
+GENERATORS = [
     {
         "model_name": "yhavinga/gpt-neo-125M-dutch-nedd",
-        "desc": "Dutch GPTNeo Small",
+        "desc": "GPT-Neo Small Dutch(book finetune)",
         "task": "text-generation",
-        "pipeline": None,
     },
     {
         "model_name": "yhavinga/gpt2-medium-dutch-nedd",
-        "desc": "Dutch GPT2 Medium",
+        "desc": "GPT2 Medium Dutch (book finetune)",
        "task": "text-generation",
-        "pipeline": None,
+    },
+    {
+        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
+        "desc": "Dutch<->English T5 small 24 layers",
+        "task": "translation_nl_to_en",
     },
 ]
 
+generators = GeneratorFactory()
+
 
-def instantiate_models():
-    for p in PIPELINES:
-        p["pipeline"] = ModelTask(p)
-        with st.spinner(text=f"Loading the model {p['desc']} ..."):
-            p["pipeline"].load()
+def instantiate_generators():
+    for g in GENERATORS:
+        with st.spinner(text=f"Loading the model {g['desc']} ..."):
+            generators.add_generator(**g)
 
 
 def main():
@@ -94,7 +109,7 @@ def main():
         initial_sidebar_state="expanded",  # Can be "auto", "expanded", "collapsed"
         page_icon="📚",  # String, anything supported by st.image, or None.
     )
-    instantiate_models()
+    instantiate_generators()
 
     with open("style.css") as f:
         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
@@ -106,7 +121,7 @@ def main():
     )
 
     model_desc = st.sidebar.selectbox(
-        "Model", [p["desc"] for p in PIPELINES], index=1
+        "Model", [p["desc"] for p in GENERATORS if "generation" in p["task"]], index=1
     )
 
     st.sidebar.title("Parameters:")
@@ -138,13 +153,13 @@ def main():
     print(f"Seed is set to: {st.session_state['seed']}")
 
     seed = seed_placeholder.number_input(
-        "Seed", min_value=0, max_value=2 ** 32 - 1, value=st.session_state["seed"]
+        "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
    )
 
     def set_random_seed():
-        st.session_state["seed"] = randint(0, 2 ** 32 - 1)
+        st.session_state["seed"] = randint(0, 2**32 - 1)
         seed = seed_placeholder.number_input(
-            "Seed", min_value=0, max_value=2 ** 32 - 1, value=st.session_state["seed"]
+            "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
         )
         print(f"New random seed set to: {seed}")
 
@@ -152,7 +167,7 @@ def main():
         set_random_seed()
 
     if sampling_mode := st.sidebar.selectbox(
-        "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
+        "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
     ):
         if sampling_mode == "Beam Search":
             num_beams = st.sidebar.number_input(
@@ -171,7 +186,9 @@ def main():
             "length_penalty": length_penalty,
         }
     else:
-        top_k = st.sidebar.number_input("Top K", min_value=0, max_value=100, value=50)
+        top_k = st.sidebar.number_input(
+            "Top K", min_value=0, max_value=100, value=50
+        )
         top_p = st.sidebar.number_input(
             "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
         )
@@ -211,17 +228,10 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
         estimate = int(estimate)
 
         with st.spinner(
-            text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
+            text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
        ):
             memory = psutil.virtual_memory()
-            generator = next(
-                (
-                    x["pipeline"]
-                    for x in PIPELINES
-                    if x["desc"] == model_desc
-                ),
-                None,
-            )
+            generator = generators.get_generator(model_desc)
             set_seed(seed)
             time_start = time.time()
             result = generator.get_text(text=st.session_state.text, **params)
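
For reference, the refactored factory can be driven on its own roughly as below. This is a minimal sketch against the Generator/GeneratorFactory names added in this commit; the prompt text and the max_length generation parameter are illustrative, not something the commit sets.

# Minimal usage sketch, assuming the Generator/GeneratorFactory classes
# and the GENERATORS list from app.py as added in this commit.
generators = GeneratorFactory()
for spec in GENERATORS:
    # Each spec carries model_name, task, and desc; add_generator builds
    # a Generator and eagerly calls load() on it.
    generators.add_generator(**spec)

# Generators are looked up by their sidebar description string;
# get_generator returns None when no description matches.
gen = generators.get_generator("GPT2 Medium Dutch (book finetune)")
if gen is not None:
    # get_text forwards its keyword arguments to the underlying pipeline;
    # max_length here is an illustrative generation parameter.
    print(gen.get_text(text="Het was een koude winterdag", max_length=50))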