Add streaming, disable translation

* Also upgrade transformers, add sentencepiece
- app.py +110 -81
- requirements.txt +8 -5
app.py
CHANGED
@@ -10,6 +10,7 @@ from transformers import (
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
+    TextIteratorStreamer,
     pipeline,
     set_seed,
 )
@@ -41,6 +42,20 @@ def load_model(model_name, task):
     return tokenizer, model
 
 
+class StreamlitTextIteratorStreamer(TextIteratorStreamer):
+    def __init__(
+        self, output_placeholder, tokenizer, skip_prompt=False, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.output_placeholder = output_placeholder
+        self.output_text = ""
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        self.output_text += text
+        self.output_placeholder.markdown(self.output_text, unsafe_allow_html=True)
+        super().on_finalized_text(text, stream_end)
+
+
 class Generator:
     def __init__(self, model_name, task, desc):
         self.model_name = model_name
@@ -52,18 +67,38 @@ class Generator:
         self.load()
 
     def load(self):
-        if not self.…
+        if not self.model:
             print(f"Loading model {self.model_name}")
             self.tokenizer, self.model = load_model(self.model_name, self.task)
-            self.pipeline = pipeline(
-                task=self.task,
-                model=self.model,
-                tokenizer=self.tokenizer,
-                device=device,
-            )
 
-    def …
-        …
+    def generate(self, text: str, streamer=None, **generate_kwargs) -> (str, dict):
+        batch_encoded = self.tokenizer(
+            text,
+            max_length=generate_kwargs["max_length"],
+            padding=False,
+            truncation=False,
+            return_tensors="pt",
+        )
+        if device != -1:
+            batch_encoded.to(f"cuda:{device}")
+        logits = self.model.generate(
+            batch_encoded["input_ids"],
+            attention_mask=batch_encoded["attention_mask"],
+            streamer=streamer,
+            **generate_kwargs,
+        )
+        decoded_preds = self.tokenizer.batch_decode(
+            logits.cpu().numpy(), skip_special_tokens=False
+        )
+
+        def replace_tokens(pred):
+            pred = pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
+            if hasattr(self.tokenizer, "newline_token"):
+                pred = pred.replace(self.tokenizer.newline_token, "\n")
+            return pred
+
+        decoded_preds = list(map(replace_tokens, decoded_preds))
+        return decoded_preds[0], generate_kwargs
 
 
 class GeneratorFactory:
@@ -82,11 +117,11 @@ class GeneratorFactory:
                 "desc": "GPT2 Medium Dutch (book finetune)",
                 "task": "text-generation",
             },
-            {
-                "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
-                "desc": "Dutch<->English T5 small 24 layers",
-                "task": TRANSLATION_NL_TO_EN,
-            },
+            # {
+            #     "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
+            #     "desc": "Dutch<->English T5 small 24 layers",
+            #     "task": TRANSLATION_NL_TO_EN,
+            # },
         ]
         for g in GENERATOR_LIST:
             with st.spinner(text=f"Loading the model {g['desc']} ..."):
@@ -148,12 +183,13 @@ def main():
     repetition_penalty = st.sidebar.number_input(
         "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
     )
-    num_return_sequences = st.sidebar.number_input(
-        "Num return sequences", min_value=1, max_value=5, value=1
-    )
+    num_return_sequences = 1
+    # st.sidebar.number_input(
+    #     "Num return sequences", min_value=1, max_value=5, value=1
+    # )
     seed_placeholder = st.sidebar.empty()
     if "seed" not in st.session_state:
-        print(f"Session state …
+        print(f"Session state does not contain seed")
         st.session_state["seed"] = 4162549114
     print(f"Seed is set to: {st.session_state['seed']}")
 
@@ -218,69 +254,62 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
     )
 
     if st.button("Run"):
-        …
-        params["seed"] = seed
-        params["prompt"] = st.session_state.text
-        params["model"] = generator.model_name
-        params_text = json.dumps(params)
-        print(params_text)
-        st.json(params_text)
+        memory = psutil.virtual_memory()
+        st.subheader("Result")
+        container = st.container()
+        output_placeholder = container.empty()
+        streaming_enabled = True  # sampling_mode != "Beam Search" or num_beams == 1
+        generator = generators.get_generator(desc=model_desc)
+        streamer = (
+            StreamlitTextIteratorStreamer(output_placeholder, generator.tokenizer)
+            if streaming_enabled
+            else None
+        )
+        set_seed(seed)
+        time_start = time.time()
+        result = generator.generate(
+            text=st.session_state.text, streamer=streamer, **params
+        )
+        time_end = time.time()
+        time_diff = time_end - time_start
+
+        # for text in result:
+        #     st.write(text.get("generated_text").replace("\n", " \n"))
+        #     st.text("*Translation*")
+        #     translate_params = {
+        #         "num_return_sequences": 1,
+        #         "num_beams": 4,
+        #         "early_stopping": True,
+        #         "length_penalty": 1.1,
+        #         "max_length": 200,
+        #     }
+        #     text_lines = [
+        #         "translate Dutch to English: " + t
+        #         for t in text.get("generated_text").splitlines()
+        #     ]
+        #     translated_lines = [
+        #         t["translation_text"]
+        #         for t in generators.get_generator(
+        #             task=TRANSLATION_NL_TO_EN
+        #         ).get_text(text_lines, **translate_params)
+        #     ]
+        #     translation = " \n".join(translated_lines)
+        #     st.write(translation)
+        #     st.write("---")
+        #
+        info = f"""
+        ---
+        *Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
+        *Text generated using seed {seed} in {time_diff:.5} seconds*
+        """
+        st.write(info)
+
+        params["seed"] = seed
+        params["prompt"] = st.session_state.text
+        params["model"] = generator.model_name
+        params_text = json.dumps(params)
+        # print(params_text)
+        st.json(params_text)
 
 
 if __name__ == "__main__":
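The streaming change hinges on `generate()` accepting a `streamer` object and invoking its `on_finalized_text` callback with each decoded chunk, which is presumably why transformers is pinned to a git commit below. A minimal sketch of the same callback pattern outside Streamlit (the model name and prompt are illustrative, not taken from the Space):

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

class AccumulatingStreamer(TextStreamer):
    # Collects each finalized chunk, mirroring how StreamlitTextIteratorStreamer
    # above appends to self.output_text and re-renders a st.empty() placeholder.
    def __init__(self, tokenizer, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt=False, **decode_kwargs)
        self.output_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.output_text += text          # incremental render target
        print(text, end="", flush=True)   # stand-in for placeholder.markdown()

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative small model
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Streaming test:", return_tensors="pt")
model.generate(**inputs, streamer=AccumulatingStreamer(tokenizer), max_new_tokens=20)

Note that the Space subclasses TextIteratorStreamer, whose queue/iterator machinery goes unused once on_finalized_text pushes text straight into the placeholder; subclassing TextStreamer, as sketched here, would presumably behave the same.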
requirements.txt
CHANGED
@@ -1,7 +1,10 @@
-…
+#-f https://download.pytorch.org/whl/torch_stable.html
+-f https://download.pytorch.org/whl/cu116
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+protobuf<3.20
+streamlit>=1.4.0,<=1.10.0
+torch
+git+https://github.com/huggingface/transformers.git@1905384fd576acf4b645a8216907f980b4788d9b
 mtranslate
 psutil
+sentencepiece