netherator / app.py
yhavinga's picture
Unpin streamlit, update cache function usage
6202886
import json
import os
import time
from random import randint
import psutil
import streamlit as st
import torch
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
TextIteratorStreamer,
pipeline,
set_seed,
)
# Task string for the (currently commented-out) Dutch/English translation model.
# NOTE(review): the name says NL->EN but the value is "translation_en_to_nl";
# load_model only checks for the "translation" substring, so confirm which
# direction is actually intended before relying on the value.
TRANSLATION_NL_TO_EN = "translation_en_to_nl"

# CUDA device index the models run on: the last GPU torch can see, or -1 when
# no GPU is visible, which the rest of this file treats as "stay on the CPU".
device = torch.cuda.device_count() - 1
def _hf_access_token():
    """Return the Hugging Face access token to use, or None if none is configured.

    The ``netherator`` entry of ``.streamlit/secrets.toml`` takes precedence.
    When that file is absent, the entry is missing or empty, or Streamlit cannot
    load the secrets, the ``HF_ACCESS_TOKEN`` environment variable is used.
    """
    token = None
    if os.path.exists(".streamlit/secrets.toml"):
        try:
            token = st.secrets.get("netherator")
        except FileNotFoundError:
            # Streamlit could not load a secrets file after all; fall back to
            # the environment variable exactly as when no file exists.
            token = None
    return token or os.environ.get("HF_ACCESS_TOKEN", None)


@st.cache_resource()
def load_model(model_name, task):
    """Load the tokenizer and model for ``model_name``.

    Args:
        model_name: Hugging Face model id, e.g. ``"yhavinga/gpt2-medium-dutch-nedd"``.
        task: Task string. Any task containing ``"translation"`` loads a
            sequence-to-sequence model; everything else loads a causal LM.

    Returns:
        ``(tokenizer, model)``. The model is moved to ``cuda:{device}`` when
        ``device != -1``.

    Wrapped in ``st.cache_resource`` so repeated calls with the same arguments
    reuse the already-loaded objects instead of loading them again.
    """
    # NOTE(review): presumably disables HF tokenizers' internal parallelism to
    # avoid its fork/thread warnings under Streamlit -- confirm.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    access_token = _hf_access_token()
    # NOTE(review): recent transformers releases deprecate ``use_auth_token`` in
    # favour of ``token``; switch once the installed version supports it.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)
    if tokenizer.pad_token is None:
        # Tokenizers without a pad token (presumably the GPT-style ones used
        # here) reuse the EOS token for padding.
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
    auto_model_class = (
        AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
    )
    model = auto_model_class.from_pretrained(model_name, use_auth_token=access_token)
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model
class StreamlitTextIteratorStreamer(TextIteratorStreamer):
    """Text streamer that mirrors generated text into a Streamlit placeholder.

    Each finalized chunk of decoded text is appended to a running buffer. The
    whole buffer is re-rendered as markdown in ``output_placeholder``, and the
    chunk is then passed on to ``TextIteratorStreamer.on_finalized_text``.
    """

    def __init__(
        self, output_placeholder, tokenizer, skip_prompt=False, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        # Streamlit element (main() passes one from ``.empty()``) that is
        # overwritten with the full text on every update.
        self.output_placeholder = output_placeholder
        # Everything finalized so far.
        self.output_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Append ``text`` to the buffer, redraw the placeholder, then defer to the parent."""
        buffered = "".join((self.output_text, text))
        self.output_text = buffered
        # NOTE(review): generated text is rendered with unsafe_allow_html=True,
        # so any HTML it contains is interpreted by the browser -- confirm intended.
        self.output_placeholder.markdown(buffered, unsafe_allow_html=True)
        # NOTE(review): the parent presumably queues the chunk for iteration;
        # nothing in this file iterates over the streamer, so that queue is
        # only ever filled during a generation.
        super().on_finalized_text(text, stream_end)
class Generator:
    """A text-generation model together with its tokenizer.

    ``model_name``, ``task`` and ``desc`` identify the generator; they are the
    attributes ``GeneratorFactory.get_generator`` matches on. The tokenizer and
    model come from ``load_model`` and are loaded on construction.
    """

    def __init__(self, model_name, task, desc):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.tokenizer = None
        self.model = None
        # NOTE(review): never assigned or read elsewhere in this file; kept so
        # any code that inspects the attribute keeps working.
        self.pipeline = None
        self.load()

    def load(self):
        """Load the tokenizer and model unless they are already loaded."""
        # Explicit None check: ``not self.model`` would reload a loaded model
        # whose object happens to be falsy (e.g. one defining ``__len__``).
        if self.model is None:
            print(f"Loading model {self.model_name}")
            self.tokenizer, self.model = load_model(self.model_name, self.task)

    def generate(
        self, text: str, streamer=None, **generate_kwargs
    ) -> "tuple[str, dict]":
        """Continue ``text`` with the model and return the decoded result.

        Args:
            text: Prompt to continue.
            streamer: Optional streamer forwarded to ``model.generate`` so that
                partial output can be shown while generating.
            **generate_kwargs: Passed through unchanged to ``model.generate``
                (``max_length``, ``num_beams``, ``top_k``, ...).

        Returns:
            ``(text, generate_kwargs)``. ``text`` is the first decoded sequence
            with ``<pad>``/``</s>`` markers removed and the tokenizer's
            ``newline_token`` (if it has one) turned into a newline.
            ``generate_kwargs`` is the keyword arguments that were used.
        """
        # truncation=False means the prompt is never shortened here. Reading
        # max_length with .get() avoids a KeyError when the caller omits it.
        batch_encoded = self.tokenizer(
            text,
            max_length=generate_kwargs.get("max_length"),
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded = batch_encoded.to(f"cuda:{device}")
        # generate() returns token ids (not logits).
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            streamer=streamer,
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )
        # Only the first sequence is returned, so only that one is cleaned.
        return self._clean_prediction(decoded_preds[0]), generate_kwargs

    def _clean_prediction(self, pred):
        """Strip pad/EOS markers and map the tokenizer's newline token to a newline."""
        pred = pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
        newline_token = getattr(self.tokenizer, "newline_token", None)
        # Skip None/empty values: replacing "" would insert a newline between
        # every character, and replacing None raises TypeError.
        if newline_token:
            pred = pred.replace(newline_token, "\n")
        return pred
class GeneratorFactory:
    """Registry of the ``Generator`` instances the app can choose from."""

    # Models offered by default. Each entry holds the keyword arguments for
    # add_generator().
    DEFAULT_GENERATOR_LIST = (
        {
            "model_name": "yhavinga/gpt-neo-125M-dutch-nedd",
            "desc": "GPT-Neo Small Dutch(book finetune)",
            "task": "text-generation",
        },
        {
            "model_name": "yhavinga/gpt2-medium-dutch-nedd",
            "desc": "GPT2 Medium Dutch (book finetune)",
            "task": "text-generation",
        },
        # {
        #     "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        #     "desc": "Dutch<->English T5 small 24 layers",
        #     "task": TRANSLATION_NL_TO_EN,
        # },
    )

    def __init__(self):
        self.generators = []

    def instantiate_generators(self, generator_list=None):
        """Register and load every model in ``generator_list``.

        Args:
            generator_list: Iterable of dicts with ``model_name``, ``desc`` and
                ``task`` keys. Defaults to ``DEFAULT_GENERATOR_LIST``.

        Returns:
            ``self``, so the call can be chained onto the constructor.
        """
        if generator_list is None:
            generator_list = self.DEFAULT_GENERATOR_LIST
        for spec in generator_list:
            with st.spinner(text=f"Loading the model {spec['desc']} ..."):
                self.add_generator(**spec)
        return self

    def add_generator(self, model_name, task, desc):
        """Create, load and register a Generator unless an identical one exists."""
        if self.get_generator(model_name=model_name, task=task, desc=desc) is None:
            generator = Generator(model_name, task, desc)
            # load() is a no-op once the model is present; calling it here keeps
            # the load inside the caller's spinner either way.
            generator.load()
            self.generators.append(generator)

    def get_generator(self, **kwargs):
        """Return the first registered generator whose attributes match ``kwargs``.

        Attributes are read from the instance ``__dict__``; a missing attribute
        compares as None. With no kwargs the first generator is returned.
        Returns None when nothing matches.
        """
        for generator in self.generators:
            if all(generator.__dict__.get(k) == v for k, v in kwargs.items()):
                return generator
        return None

    def gpt_descs(self):
        """Return the descriptions of all text-generation models, in registration order."""
        return [g.desc for g in self.generators if g.task == "text-generation"]
def main():
    """Render the Netherator page and generate text when "Run" is pressed.

    The sidebar collects the model, generation length, repetition controls, seed
    and decoding mode (top-k/top-p sampling or beam search). On "Run" the prompt
    is continued with the selected model. Output is streamed into the page for
    single-beam decoding and written once generation finishes otherwise. Timing,
    memory usage and the parameters used are shown below the result.
    """
    st.set_page_config(  # Alternate names: setup_page, page, layout
        page_title="Netherator",  # String or None. Strings get appended with "• Streamlit".
        layout="wide",  # Can be "centered" or "wide".
        initial_sidebar_state="expanded",  # Can be "auto", "expanded", "collapsed"
        page_icon="📚",  # String, anything supported by st.image, or None.
    )
    # Models are loaded once per browser session and kept in session state.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory().instantiate_generators()
    generators = st.session_state["generators"]

    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    st.sidebar.image("demon-reading-Stewart-Orr.png", width=200)
    st.sidebar.markdown(
        """# Netherator
Nederlandse verhalenverteller"""
    )
    model_descs = generators.gpt_descs()
    # Default to the second model, but do not fail when fewer are configured.
    model_desc = st.sidebar.selectbox(
        "Model", model_descs, index=1 if len(model_descs) > 1 else 0
    )
    st.sidebar.title("Parameters:")
    if "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = "Het was een koude winterdag"
    st.session_state["text"] = st.text_area("Enter text", st.session_state.prompt_box)
    max_length = st.sidebar.number_input(
        "Lengte van de tekst",
        value=200,
        min_value=1,
        max_value=512,
    )
    no_repeat_ngram_size = st.sidebar.number_input(
        "No-repeat NGram size", min_value=1, max_value=5, value=3
    )
    repetition_penalty = st.sidebar.number_input(
        "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
    )
    # Fixed at 1: Generator.generate only returns the first sequence anyway.
    num_return_sequences = 1

    seed_placeholder = st.sidebar.empty()
    if "seed" not in st.session_state:
        print("Session state does not contain seed")
        st.session_state["seed"] = 4162549114
    print(f"Seed is set to: {st.session_state['seed']}")
    seed = seed_placeholder.number_input(
        "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
    )

    def set_random_seed():
        """Draw a new seed, store it in session state and redraw the seed input."""
        st.session_state["seed"] = randint(0, 2**32 - 1)
        new_seed = seed_placeholder.number_input(
            "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
        )
        print(f"New random seed set to: {new_seed}")
        return new_seed

    if st.button("New random seed?"):
        seed = set_random_seed()

    sampling_mode = st.sidebar.selectbox(
        "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
    )
    # Parameters shared by both decoding modes; mode-specific ones are appended.
    params = {
        "max_length": max_length,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "repetition_penalty": repetition_penalty,
        "num_return_sequences": num_return_sequences,
    }
    if sampling_mode == "Beam Search":
        num_beams = st.sidebar.number_input(
            "Num beams", min_value=1, max_value=10, value=4
        )
        length_penalty = st.sidebar.number_input(
            "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
        )
        params.update(
            num_beams=num_beams,
            early_stopping=True,
            length_penalty=length_penalty,
        )
    else:
        top_k = st.sidebar.number_input(
            "Top K", min_value=0, max_value=100, value=50
        )
        top_p = st.sidebar.number_input(
            "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
        )
        temperature = st.sidebar.number_input(
            "Temperature", min_value=0.05, max_value=1.0, value=1.0, step=0.05
        )
        params.update(
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )

    if st.button("Run"):
        memory = psutil.virtual_memory()
        st.subheader("Result")
        container = st.container()
        output_placeholder = container.empty()
        # transformers' generate() rejects a streamer when num_beams > 1, so only
        # single-beam decoding is streamed; otherwise the text is written when done.
        streaming_enabled = params.get("num_beams", 1) == 1
        generator = generators.get_generator(desc=model_desc)
        streamer = (
            StreamlitTextIteratorStreamer(output_placeholder, generator.tokenizer)
            if streaming_enabled
            else None
        )
        set_seed(seed)
        time_start = time.perf_counter()
        generated_text, _ = generator.generate(
            text=st.session_state.text, streamer=streamer, **params
        )
        time_diff = time.perf_counter() - time_start
        if not streaming_enabled:
            # Render the same way the streamer would have.
            output_placeholder.markdown(generated_text, unsafe_allow_html=True)
        # NOTE: translating the generated text (the commented-out T5 model in
        # GeneratorFactory) is currently disabled.
        info = f"""
---
*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
*Text generated using seed {seed} in {time_diff:.5} seconds*
"""
        st.write(info)
        params["seed"] = seed
        params["prompt"] = st.session_state.text
        params["model"] = generator.model_name
        st.json(json.dumps(params))
# Build the page only when this file is executed as the main script, not on import.
if __name__ == "__main__":
    main()