Yeb Havinga commited on
Commit
4c45953
·
1 Parent(s): 43037cf
Files changed (7) hide show
  1. .gitignore +4 -0
  2. .streamlit/config.toml +8 -0
  3. README.md +6 -5
  4. app.py +245 -0
  5. demon-reading-Stewart-Orr.png +0 -0
  6. requirements.txt +7 -0
  7. style.css +42 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ venv
2
+ .idea
3
+ __pycache__
4
+ *~
.streamlit/config.toml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ [server]
2
+ headless = true
3
+
4
+ [theme]
5
+ base="dark"
6
+ primaryColor="#139ace"
7
+ secondaryBackgroundColor="#2b2b39"
8
+ textColor="#cdd8d3"
README.md CHANGED
@@ -1,11 +1,12 @@
1
  ---
2
- title: Netherator
3
- emoji: 💩
4
- colorFrom: yellow
5
- colorTo: yellow
6
  sdk: streamlit
7
  app_file: app.py
8
- pinned: false
 
9
  ---
10
 
11
  # Configuration
 
1
  ---
2
+ title: Netherator - teller of tales from the Netherlands
3
+ emoji: 🧙
4
+ colorFrom: gray
5
+ colorTo: indigo
6
  sdk: streamlit
7
  app_file: app.py
8
+ pinned: true
9
+ sdk_version: 1.0.0
10
  ---
11
 
12
  # Configuration
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pprint
4
+ import time
5
+ from random import randint
6
+
7
+ import psutil
8
+ import streamlit as st
9
+ import torch
10
+ from transformers import (AutoModelForCausalLM, AutoTokenizer, pipeline,
11
+ set_seed)
12
+
13
+ device = torch.cuda.device_count() - 1
14
+
15
+
16
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True)
17
+ def load_model(model_name):
18
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
+ try:
20
+ if not os.path.exists(".streamlit/secrets.toml"):
21
+ raise FileNotFoundError
22
+ access_token = st.secrets.get("netherator")
23
+ except FileNotFoundError:
24
+ access_token = os.environ.get("HF_ACCESS_TOKEN", None)
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)
26
+ model = AutoModelForCausalLM.from_pretrained(
27
+ model_name, use_auth_token=access_token
28
+ )
29
+ if device != -1:
30
+ model.to(f"cuda:{device}")
31
+ return tokenizer, model
32
+
33
+
34
+ class StoryGenerator:
35
+ def __init__(self, model_name):
36
+ self.model_name = model_name
37
+ self.tokenizer = None
38
+ self.model = None
39
+ self.generator = None
40
+ self.model_loaded = False
41
+
42
+ def load(self):
43
+ if not self.model_loaded:
44
+ self.tokenizer, self.model = load_model(self.model_name)
45
+ self.generator = pipeline(
46
+ "text-generation",
47
+ model=self.model,
48
+ tokenizer=self.tokenizer,
49
+ device=device,
50
+ )
51
+ self.model_loaded = True
52
+
53
+ def get_text(self, text: str, **generate_kwargs) -> str:
54
+ return self.generator(text, **generate_kwargs)
55
+
56
+
57
+ STORY_GENERATORS = [
58
+ {
59
+ "model_name": "yhavinga/gpt-neo-125M-dutch-nedd",
60
+ "desc": "Dutch GPTNeo Small",
61
+ "story_generator": None,
62
+ },
63
+ {
64
+ "model_name": "yhavinga/gpt2-medium-dutch-nedd",
65
+ "desc": "Dutch GPT2 Medium",
66
+ "story_generator": None,
67
+ },
68
+ # {
69
+ # "model_name": "yhavinga/gpt-neo-125M-dutch",
70
+ # "desc": "Dutch GPTNeo Small",
71
+ # "story_generator": None,
72
+ # },
73
+ # {
74
+ # "model_name": "yhavinga/gpt2-medium-dutch",
75
+ # "desc": "Dutch GPT2 Medium",
76
+ # "story_generator": None,
77
+ # },
78
+ ]
79
+
80
+
81
+ def instantiate_models():
82
+ for sg in STORY_GENERATORS:
83
+ sg["story_generator"] = StoryGenerator(sg["model_name"])
84
+ with st.spinner(text=f"Loading the model {sg['desc']} ..."):
85
+ sg["story_generator"].load()
86
+
87
+
88
+ def set_new_seed():
89
+ seed = randint(0, 2 ** 32 - 1)
90
+ set_seed(seed)
91
+ return seed
92
+
93
+
94
+ def main():
95
+ st.set_page_config( # Alternate names: setup_page, page, layout
96
+ page_title="Netherator", # String or None. Strings get appended with "• Streamlit".
97
+ layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
98
+ initial_sidebar_state="expanded", # Can be "auto", "expanded", "collapsed"
99
+ page_icon="📚", # String, anything supported by st.image, or None.
100
+ )
101
+ instantiate_models()
102
+
103
+ with open("style.css") as f:
104
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
105
+
106
+ st.sidebar.image("demon-reading-Stewart-Orr.png", width=200)
107
+
108
+ st.sidebar.markdown(
109
+ """# Netherator
110
+ Teller of tales from the Netherlands"""
111
+ )
112
+
113
+ model_desc = st.sidebar.selectbox(
114
+ "Model", [sg["desc"] for sg in STORY_GENERATORS], index=1
115
+ )
116
+
117
+ st.sidebar.title("Parameters:")
118
+
119
+ if "prompt_box" not in st.session_state:
120
+ st.session_state["prompt_box"] = "Het was een koude winterdag"
121
+
122
+ st.session_state["text"] = st.text_area("Enter text", st.session_state.prompt_box)
123
+
124
+ # min_length = st.sidebar.number_input(
125
+ # "Min length", min_value=10, max_value=150, value=75
126
+ # )
127
+ max_length = st.sidebar.number_input(
128
+ "Lengte van de tekst",
129
+ value=300,
130
+ max_value=512,
131
+ )
132
+ no_repeat_ngram_size = st.sidebar.number_input(
133
+ "No-repeat NGram size", min_value=1, max_value=5, value=3
134
+ )
135
+ repetition_penalty = st.sidebar.number_input(
136
+ "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
137
+ )
138
+ num_return_sequences = st.sidebar.number_input(
139
+ "Num return sequences", min_value=1, max_value=5, value=1
140
+ )
141
+
142
+ if sampling_mode := st.sidebar.selectbox(
143
+ "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
144
+ ):
145
+ if sampling_mode == "Beam Search":
146
+ num_beams = st.sidebar.number_input(
147
+ "Num beams", min_value=1, max_value=10, value=4
148
+ )
149
+ length_penalty = st.sidebar.number_input(
150
+ "Length penalty", min_value=0.0, max_value=5.0, value=1.5, step=0.1
151
+ )
152
+ params = {
153
+ "max_length": max_length,
154
+ "no_repeat_ngram_size": no_repeat_ngram_size,
155
+ "repetition_penalty": repetition_penalty,
156
+ "num_return_sequences": num_return_sequences,
157
+ "num_beams": num_beams,
158
+ "early_stopping": True,
159
+ "length_penalty": length_penalty,
160
+ }
161
+ else:
162
+ top_k = st.sidebar.number_input(
163
+ "Top K", min_value=0, max_value=100, value=50
164
+ )
165
+ top_p = st.sidebar.number_input(
166
+ "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
167
+ )
168
+ temperature = st.sidebar.number_input(
169
+ "Temperature", min_value=0.05, max_value=1.0, value=0.8, step=0.05
170
+ )
171
+ params = {
172
+ "max_length": max_length,
173
+ "no_repeat_ngram_size": no_repeat_ngram_size,
174
+ "repetition_penalty": repetition_penalty,
175
+ "num_return_sequences": num_return_sequences,
176
+ "do_sample": True,
177
+ "top_k": top_k,
178
+ "top_p": top_p,
179
+ "temperature": temperature,
180
+ }
181
+
182
+ st.sidebar.markdown(
183
+ """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
184
+ and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
185
+ """
186
+ )
187
+
188
+ if st.button("Run"):
189
+ estimate = max_length / 18
190
+ if device == -1:
191
+ ## cpu
192
+ estimate = estimate * (1 + 0.7 * (num_return_sequences - 1))
193
+ if sampling_mode == "Beam Search":
194
+ estimate = estimate * (1.1 + 0.3 * (num_beams - 1))
195
+ else:
196
+ ## gpu
197
+ estimate = estimate * (1 + 0.1 * (num_return_sequences - 1))
198
+ estimate = 0.5 + estimate / 5
199
+ if sampling_mode == "Beam Search":
200
+ estimate = estimate * (1.0 + 0.1 * (num_beams - 1))
201
+ estimate = int(estimate)
202
+
203
+ with st.spinner(
204
+ text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
205
+ ):
206
+ memory = psutil.virtual_memory()
207
+ story_generator = next(
208
+ (
209
+ x["story_generator"]
210
+ for x in STORY_GENERATORS
211
+ if x["desc"] == model_desc
212
+ ),
213
+ None,
214
+ )
215
+ seed = set_new_seed()
216
+ time_start = time.time()
217
+ result = story_generator.get_text(text=st.session_state.text, **params)
218
+ time_end = time.time()
219
+ time_diff = time_end - time_start
220
+
221
+ st.subheader("Result")
222
+ for text in result:
223
+ st.write(text.get("generated_text").replace("\n", " \n"))
224
+
225
+ # st.text("*Translation*")
226
+ # translation = translate(result, "en", "nl")
227
+ # st.write(translation.replace("\n", " \n"))
228
+ #
229
+ info = f"""
230
+ ---
231
+ *Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*
232
+ *Text generated using seed {seed} in {time_diff:.5} seconds*
233
+ """
234
+ st.write(info)
235
+
236
+ params["seed"] = seed
237
+ params["prompt"] = st.session_state.text
238
+ params["model"] = story_generator.model_name
239
+ params_text = json.dumps(params)
240
+ print(params_text)
241
+ st.json(params_text)
242
+
243
+
244
+ if __name__ == "__main__":
245
+ main()
demon-reading-Stewart-Orr.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ -f https://download.pytorch.org/whl/torch_stable.html
2
+ streamlit==1.4.0
3
+ torch==1.6.0+cpu
4
+ torchvision==0.7.0+cpu
5
+ transformers>=4.13.0
6
+ mtranslate
7
+ psutil
style.css ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ body {
2
+ background-color: #eee;
3
+ }
4
+ /*.fullScreenFrame > div {*/
5
+ /* display: flex;*/
6
+ /* justify-content: center;*/
7
+ /*}*/
8
+ /*.stButton>button {*/
9
+ /* color: #4F8BF9;*/
10
+ /* border-radius: 50%;*/
11
+ /* height: 3em;*/
12
+ /* width: 3em;*/
13
+ /*}*/
14
+
15
+ .stTextInput>div>div>input {
16
+ color: #4F8BF9;
17
+ }
18
+ .stTextArea>div>div>input {
19
+ color: #4F8BF9;
20
+ min-height: 500px;
21
+ }
22
+
23
+
24
+ /*.st-cj {*/
25
+ /* min-height: 500px;*/
26
+ /* spellcheck="false";*/
27
+ /* color: #4F8BF9;*/
28
+ /*}*/
29
+ /*.st-ch {*/
30
+ /* min-height: 500px;*/
31
+ /* spellcheck="false";*/
32
+ /* color: #4F8BF9;*/
33
+ /*}*/
34
+ /*.st-bb {*/
35
+ /* min-height: 500px;*/
36
+ /* spellcheck="false";*/
37
+ /* color: #4F8BF9;*/
38
+ /*}*/
39
+
40
+ /*body {*/
41
+ /* background-color: #f1fbff*/
42
+ /*}*/