# Hugging Face Spaces status banner captured with this file: "Runtime error".
import streamlit as st | |
import torch | |
from time import perf_counter | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Browser-tab title, icon and wide layout for the page.
st.set_page_config(page_title="Romanian Text Generator", page_icon="🇷🇴", layout="wide")

#############################################
# Plain-Python helpers and constants

# Checkpoint identifiers shown in the "Select model" drop-down and handed to
# `from_pretrained` when a generation is requested.
model_list = [
    "dumitrescustefan/gpt-neo-romanian-780m",
    "readerbench/RoGPT2-base",
    "readerbench/RoGPT2-medium",
    "readerbench/RoGPT2-large",
]
def greedy_search(model, input_ids, attention_mask, no_repeat_ngram_size, max_length):
    """Run ``model.generate`` with only length and n-gram constraints.

    No sampling or beam arguments are passed, so the decoding strategy is
    whatever ``model.generate`` uses by default.

    Args:
        model: Object exposing a ``generate(**kwargs)`` method.
        input_ids: Prompt token ids, forwarded unchanged.
        attention_mask: Mask matching ``input_ids``, forwarded unchanged.
        no_repeat_ngram_size: Forwarded as ``no_repeat_ngram_size``.
        max_length: Forwarded as ``max_length``.

    Returns:
        Whatever ``model.generate`` returns.
    """
    generation_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "max_length": max_length,
    }
    return model.generate(**generation_kwargs)
def beam_search(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, num_beams):
    """Run ``model.generate`` with ``num_beams`` beams and early stopping on.

    Args:
        model: Object exposing a ``generate(**kwargs)`` method.
        input_ids: Prompt token ids, forwarded unchanged.
        attention_mask: Mask matching ``input_ids``, forwarded unchanged.
        no_repeat_ngram_size: Forwarded as ``no_repeat_ngram_size``.
        max_length: Forwarded as ``max_length``.
        num_beams: Forwarded as ``num_beams``.

    Returns:
        Whatever ``model.generate`` returns.
    """
    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length,
    )
    # Beam-specific settings are layered on top of the shared arguments.
    generation_kwargs.update(num_beams=num_beams, early_stopping=True)
    return model.generate(**generation_kwargs)
def sampling(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, temperature, top_k, top_p):
    """Run ``model.generate`` with sampling enabled (temperature, top-k, top-p).

    Args:
        model: Object exposing a ``generate(**kwargs)`` method.
        input_ids: Prompt token ids, forwarded unchanged.
        attention_mask: Mask matching ``input_ids``, forwarded unchanged.
        no_repeat_ngram_size: Forwarded as ``no_repeat_ngram_size``.
        max_length: Forwarded as ``max_length``.
        temperature: Forwarded as ``temperature``.
        top_k: Forwarded as ``top_k``.
        top_p: Forwarded as ``top_p``.

    Returns:
        Whatever ``model.generate`` returns.
    """
    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length,
    )
    # Sampling-specific settings are layered on top of the shared arguments.
    generation_kwargs.update(
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    return model.generate(**generation_kwargs)
def typical_sampling(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, temperature, typical_p):
    """Run ``model.generate`` with sampling enabled and a ``typical_p`` value.

    ``top_k`` is always sent as 0 alongside ``typical_p``.

    Args:
        model: Object exposing a ``generate(**kwargs)`` method.
        input_ids: Prompt token ids, forwarded unchanged.
        attention_mask: Mask matching ``input_ids``, forwarded unchanged.
        no_repeat_ngram_size: Forwarded as ``no_repeat_ngram_size``.
        max_length: Forwarded as ``max_length``.
        temperature: Forwarded as ``temperature``.
        typical_p: Forwarded as ``typical_p``.

    Returns:
        Whatever ``model.generate`` returns.
    """
    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_length=max_length,
    )
    # Typical-sampling settings are layered on top of the shared arguments.
    generation_kwargs.update(
        do_sample=True,
        temperature=temperature,
        typical_p=typical_p,
        top_k=0,
    )
    return model.generate(**generation_kwargs)
def setModel(model_checkpoint):
    """Load the causal-LM weights and the tokenizer for ``model_checkpoint``.

    NOTE(review): a later ``def setModel`` further down this file rebinds the
    name, so that later definition is the one callers actually reach.

    Args:
        model_checkpoint: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Tuple elements are evaluated left to right: model first, then tokenizer.
    return (
        AutoModelForCausalLM.from_pretrained(model_checkpoint),
        AutoTokenizer.from_pretrained(model_checkpoint),
    )
#############################################
# Page layout: a title row holding the three generation buttons, then a
# settings column (col1) beside a wider text column (col2).
col_title, _, col_b1, col_b2, col_b3, _ = st.columns([18, 1, 8, 8, 8, 1])
col_title.markdown("**Playground for text generation with Romanian models**")
button_greedy = col_b1.button("Greedy generation")
button_sampling = col_b2.button("Sampling generation")
button_typical = col_b3.button("Typical sampling generation")

col1, _, col2 = st.columns([10, 1, 16])


def update_prompt():
    """Copy the newly selected prompt into the editable text area's state.

    Streamlit runs ``on_change`` callbacks before the script body is
    re-executed, so a module-level variable assigned from the widget's return
    value still holds the previous selection at that point. The widget's
    current value is therefore read from ``st.session_state`` through its key.
    """
    st.session_state['text'] = st.session_state['prompt']


# NOTE(review): the original indentation was lost; every settings widget below
# is assumed to belong inside `with col1:` -- confirm against the deployed app.
with col1:
    st.markdown("**Step 1: Select model**")
    model_checkpoint = st.selectbox("Select model", model_list)

    st.markdown("**Step 2: Adjust text generation parameters**")
    max_length = st.slider("Number of tokens to generate", value=50, min_value=10, max_value=256)
    top_k = col1.slider("Top-k", min_value=0, max_value=100, step=10, value=0)
    top_p = col1.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
    # The lower bound is 0.1 rather than 0.0: recent transformers releases
    # reject typical_p values that are not strictly greater than 0.
    typical_p = col1.slider("Typical-p", min_value=0.1, max_value=1.0, step=0.1, value=1.0)
    temperature = st.slider("Temperature", value=1.0, min_value=0.1, max_value=1.0, step=0.1)
    no_repeat_ngrams = st.slider("No repeat n-grams", value=2, min_value=0, max_value=3)

    # st.markdown("**Step 4: Select a prompt or input your own text, and click generate in the left panel**")
    # NOTE(review): the prompt choices are the model names; this looks like a
    # placeholder for a real list of prompts -- confirm the intended options.
    prompt = st.selectbox("Select prompt", model_list, key="prompt", on_change=update_prompt)
# Streamlit re-executes the whole script on every interaction, so an uncached
# loader would re-read the full checkpoint on each button press. Keep only the
# most recently used checkpoint to bound memory. `st.cache_resource` is the
# current API; older Streamlit releases only provide `st.cache`.
_cache_one_checkpoint = (
    st.cache_resource(max_entries=1)
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True, max_entries=1)
)


@_cache_one_checkpoint
def setModel(model_checkpoint):
    """Load (once per checkpoint) the causal-LM model and its tokenizer.

    Args:
        model_checkpoint: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple. The same objects are returned on
        repeated calls with the same checkpoint while they remain cached, so
        callers should treat them as shared and not mutate them.
    """
    model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return model, tokenizer
#####################################################
# Show-time: run the requested generation, then render the text panel (col2).

# Upper bound on prompt tokens plus generated tokens used by this script.
MAX_TOTAL_TOKENS = 512

if 'text' not in st.session_state:
    st.session_state['text'] = 'Acesta este un exemplu de text generat de un model de limbă.'

details = ""
tokenized_text = None

if button_greedy or button_sampling or button_typical:
    if not st.session_state['text'].strip():
        col2.warning("Please input some text!")
        # Draw the editor before halting so the user still has somewhere to type.
        text_element = col2.text_area('Text:', height=400, key="text")
        st.stop()

    model, tokenizer = setModel(model_checkpoint)
    tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
    all_ids = tokenized_text.input_ids[0]
    all_mask = tokenized_text.attention_mask[0]

    if len(all_ids) + max_length > MAX_TOTAL_TOKENS:
        # Prompt is too long: feed only its tail to the model.
        keep_last = MAX_TOTAL_TOKENS - max_length
        print(f"keep last: {keep_last}")
        input_ids, attention_mask = all_ids[-keep_last:], all_mask[-keep_last:]
        # The dropped prefix is everything *before* the kept tail. Slicing
        # `[:keep_last]` here would take the first `keep_last` tokens instead
        # and make the re-assembled text repeat or misorder part of the prompt.
        previous_ids = all_ids[:-keep_last]
        st.warning(f"kept last {keep_last}")
    else:
        input_ids, attention_mask = all_ids, all_mask
        previous_ids = None

    length = min(MAX_TOTAL_TOKENS, len(input_ids) + max_length)
    # generate() is called with a leading batch dimension of size one.
    batch_ids = input_ids.unsqueeze(dim=0)
    batch_mask = attention_mask.unsqueeze(dim=0)

    timer_mark = perf_counter()
    if button_greedy:
        output = greedy_search(model, batch_ids, batch_mask, no_repeat_ngrams, length)
        details = f"Text generated using greedy decoding in {perf_counter()-timer_mark:.2f}s"
    elif button_sampling:
        output = sampling(model, batch_ids, batch_mask, no_repeat_ngrams, length, temperature, top_k, top_p)
        details = f"Text generated using sampling, top-p={top_p:.2f}, top-k={top_k}, temperature={temperature:.2f} in {perf_counter()-timer_mark:.2f}s"
    else:
        output = typical_sampling(model, batch_ids, batch_mask, no_repeat_ngrams, length, temperature, typical_p)
        details = f"Text generated using typical sampling, typical-p={typical_p:.2f}, temperature={temperature:.2f} in {perf_counter()-timer_mark:.2f}s"

    if previous_ids is not None:
        print("\nConcat prev id: " + tokenizer.decode(previous_ids, skip_special_tokens=True))
        print("\nWith current decode: " + tokenizer.decode(output[0], skip_special_tokens=True))
        # Re-attach the dropped prefix so the user's full text is preserved.
        new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=-1), skip_special_tokens=True)
    else:
        new_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Written before the text_area with key="text" is created below, so the
    # widget shows the generated text on this run.
    st.session_state['text'] = new_text

text_element = col2.text_area('Text:', height=400, key="text")
col2.markdown("""---""")
col2.text("Statistics and details:")
if details:
    col2.caption(" Generation details: " + details)

if tokenized_text is None:
    # No generation happened on this run, so no tokenizer has been loaded yet.
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

tt = tokenizer(text_element, add_special_tokens=False, return_tensors="pt")
col2.caption(f" Text length is {len(text_element)} characters, {len(tt.input_ids[0])} tokens.")