Yeb Havinga committed
Commit a9f2b23 · 1 Parent(s): 4c45953

Syntactic changes

Files changed (1)
  1. app.py +48 -49
app.py CHANGED
@@ -1,20 +1,24 @@
 import json
 import os
-import pprint
 import time
 from random import randint
 
 import psutil
 import streamlit as st
 import torch
-from transformers import (AutoModelForCausalLM, AutoTokenizer, pipeline,
-                          set_seed)
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    pipeline,
+    set_seed,
+)
 
 device = torch.cuda.device_count() - 1
 
 
 @st.cache(suppress_st_warning=True, allow_output_mutation=True)
-def load_model(model_name):
+def load_model(model_name, task):
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     try:
         if not os.path.exists(".streamlit/secrets.toml"):
@@ -23,70 +27,68 @@ def load_model(model_name):
     except FileNotFoundError:
         access_token = os.environ.get("HF_ACCESS_TOKEN", None)
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name, use_auth_token=access_token
+    if tokenizer.pad_token is None:
+        print("Adding pad_token to the tokenizer")
+        tokenizer.pad_token = tokenizer.eos_token
+    auto_model_class = (
+        AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
     )
+    model = auto_model_class.from_pretrained(model_name, use_auth_token=access_token)
     if device != -1:
         model.to(f"cuda:{device}")
     return tokenizer, model
 
 
-class StoryGenerator:
-    def __init__(self, model_name):
-        self.model_name = model_name
+class ModelTask:
+    def __init__(self, p):
+        self.model_name = p["model_name"]
+        self.task = p["task"]
+        self.desc = p["desc"]
         self.tokenizer = None
         self.model = None
-        self.generator = None
-        self.model_loaded = False
+        self.pipeline = None
+        self.load()
 
     def load(self):
-        if not self.model_loaded:
-            self.tokenizer, self.model = load_model(self.model_name)
-            self.generator = pipeline(
-                "text-generation",
+        if not self.pipeline:
+            print(f"Loading model {self.model_name}")
+            self.tokenizer, self.model = load_model(self.model_name, self.task)
+            self.pipeline = pipeline(
+                task=self.task,
                 model=self.model,
                 tokenizer=self.tokenizer,
                 device=device,
             )
-            self.model_loaded = True
 
     def get_text(self, text: str, **generate_kwargs) -> str:
-        return self.generator(text, **generate_kwargs)
+        return self.pipeline(text, **generate_kwargs)
 
 
-STORY_GENERATORS = [
+PIPELINES = [
     {
         "model_name": "yhavinga/gpt-neo-125M-dutch-nedd",
         "desc": "Dutch GPTNeo Small",
-        "story_generator": None,
+        "task": "text-generation",
+        "pipeline": None,
     },
     {
         "model_name": "yhavinga/gpt2-medium-dutch-nedd",
         "desc": "Dutch GPT2 Medium",
-        "story_generator": None,
+        "task": "text-generation",
+        "pipeline": None,
     },
-    # {
-    #     "model_name": "yhavinga/gpt-neo-125M-dutch",
-    #     "desc": "Dutch GPTNeo Small",
-    #     "story_generator": None,
-    # },
-    # {
-    #     "model_name": "yhavinga/gpt2-medium-dutch",
-    #     "desc": "Dutch GPT2 Medium",
-    #     "story_generator": None,
-    # },
 ]
 
 
 def instantiate_models():
-    for sg in STORY_GENERATORS:
-        sg["story_generator"] = StoryGenerator(sg["model_name"])
-        with st.spinner(text=f"Loading the model {sg['desc']} ..."):
-            sg["story_generator"].load()
+    for p in PIPELINES:
+        p["pipeline"] = ModelTask(p)
+        with st.spinner(text=f"Loading the model {p['desc']} ..."):
+            p["pipeline"].load()
 
 
 def set_new_seed():
-    seed = randint(0, 2 ** 32 - 1)
+    seed = randint(0, 2**32 - 1)
     set_seed(seed)
     return seed
 
@@ -104,14 +106,13 @@ def main():
         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
 
     st.sidebar.image("demon-reading-Stewart-Orr.png", width=200)
-
     st.sidebar.markdown(
         """# Netherator
-Teller of tales from the Netherlands"""
+Nederlandse verhalenverteller"""
     )
 
     model_desc = st.sidebar.selectbox(
-        "Model", [sg["desc"] for sg in STORY_GENERATORS], index=1
+        "Model", [p["desc"] for p in PIPELINES], index=1
     )
 
     st.sidebar.title("Parameters:")
@@ -126,7 +127,7 @@ Teller of tales from the Netherlands"""
     # )
    max_length = st.sidebar.number_input(
         "Lengte van de tekst",
-        value=300,
+        value=200,
         max_value=512,
     )
     no_repeat_ngram_size = st.sidebar.number_input(
@@ -147,7 +148,7 @@ Teller of tales from the Netherlands"""
             "Num beams", min_value=1, max_value=10, value=4
         )
         length_penalty = st.sidebar.number_input(
-            "Length penalty", min_value=0.0, max_value=5.0, value=1.5, step=0.1
+            "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
        )
         params = {
             "max_length": max_length,
@@ -159,14 +160,12 @@ Teller of tales from the Netherlands"""
             "length_penalty": length_penalty,
         }
     else:
-        top_k = st.sidebar.number_input(
-            "Top K", min_value=0, max_value=100, value=50
-        )
+        top_k = st.sidebar.number_input("Top K", min_value=0, max_value=100, value=50)
         top_p = st.sidebar.number_input(
             "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
         )
         temperature = st.sidebar.number_input(
-            "Temperature", min_value=0.05, max_value=1.0, value=0.8, step=0.05
+            "Temperature", min_value=0.05, max_value=1.0, value=1.0, step=0.05
         )
         params = {
             "max_length": max_length,
@@ -204,17 +203,17 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
                 text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
             ):
                 memory = psutil.virtual_memory()
-                story_generator = next(
+                generator = next(
                     (
-                        x["story_generator"]
-                        for x in STORY_GENERATORS
+                        x["pipeline"]
+                        for x in PIPELINES
                         if x["desc"] == model_desc
                     ),
                     None,
                 )
                 seed = set_new_seed()
                 time_start = time.time()
-                result = story_generator.get_text(text=st.session_state.text, **params)
+                result = generator.get_text(text=st.session_state.text, **params)
                 time_end = time.time()
                 time_diff = time_end - time_start
 
@@ -235,7 +234,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
 
                 params["seed"] = seed
                 params["prompt"] = st.session_state.text
-                params["model"] = story_generator.model_name
+                params["model"] = generator.model_name
                 params_text = json.dumps(params)
                 print(params_text)
                 st.json(params_text)
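
For context when reading the diff: load_model now takes a task argument and selects AutoModelForSeq2SeqLM for translation tasks and AutoModelForCausalLM otherwise, and the StoryGenerator / STORY_GENERATORS pair is renamed to ModelTask / PIPELINES around a transformers pipeline. The sketch below shows that flow outside Streamlit. It is a minimal illustration, not part of the commit; it assumes the model repository is publicly downloadable, so no access token is passed, and the prompt and seed are arbitrary.

# Minimal sketch of the flow introduced in this commit, without Streamlit.
# Assumes the model repo is public, so no use_auth_token is needed.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
    set_seed,
)

device = torch.cuda.device_count() - 1  # -1 selects CPU in the pipeline API


def load_model(model_name, task):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-style models ship without a pad token
    # Translation tasks need a seq2seq model; text generation uses a causal LM.
    auto_model_class = (
        AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
    )
    model = auto_model_class.from_pretrained(model_name)
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model


task = "text-generation"
tokenizer, model = load_model("yhavinga/gpt-neo-125M-dutch-nedd", task)
generator = pipeline(task, model=model, tokenizer=tokenizer, device=device)
set_seed(42)
result = generator("Er was eens", max_length=50, do_sample=True, top_p=0.95)
print(result[0]["generated_text"])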