Spaces:
Runtime error
Runtime error
File size: 6,811 Bytes
e816bb5 2f0ed55 c90ce91 f4a3863 76d0859 0957f7e 76d0859 9506fdb e816bb5 f4a3863 b26e605 f4a3863 e816bb5 9506fdb 50f4f41 9506fdb e816bb5 9506fdb 0957f7e 9506fdb 0957f7e c44f938 f4a3863 c44f938 0957f7e 2f0ed55 9506fdb 2f0ed55 0957f7e 9506fdb c485299 9506fdb 2f0ed55 e816bb5 9506fdb e816bb5 831de4c e816bb5 831de4c 9506fdb e816bb5 c90ce91 e816bb5 9506fdb 19c9e19 2f0ed55 9506fdb 2f0ed55 9506fdb 2f0ed55 9506fdb 2f0ed55 c90ce91 9506fdb 2f0ed55 9506fdb 2f0ed55 19c9e19 2f0ed55 c90ce91 2f0ed55 19c9e19 2f0ed55 9506fdb 2f0ed55 9506fdb 19c9e19 2f0ed55 19c9e19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import streamlit as st
import torch
from time import perf_counter
from transformers import AutoTokenizer, AutoModelForCausalLM
# Page-wide settings; Streamlit requires this before any other st.* call.
st.set_page_config(layout="wide", page_icon="🇷🇴", page_title="Romanian Text Generator")

#############################################
# Checkpoint IDs offered in the "Select model" picker; each one is passed
# to `from_pretrained` when the user triggers a generation.
model_list = [
    "dumitrescustefan/gpt-neo-romanian-780m",
    "readerbench/RoGPT2-base",
    "readerbench/RoGPT2-medium",
    "readerbench/RoGPT2-large",
]
def greedy_search(model, input_ids, attention_mask, no_repeat_ngram_size, max_length):
    """Run greedy decoding (no sampling, single beam) through ``model.generate``.

    Args:
        model: object exposing a Hugging Face-style ``generate`` method.
        input_ids: batched prompt token ids.
        attention_mask: mask aligned with ``input_ids``.
        no_repeat_ngram_size: forwarded to ``generate`` unchanged.
        max_length: forwarded to ``generate`` as the sequence-length cap.

    Returns:
        Whatever ``model.generate`` returns.
    """
    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    return model.generate(**generation_kwargs)
def beam_search(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, num_beams):
    """Run beam-search decoding through ``model.generate``.

    Args:
        model: object exposing a Hugging Face-style ``generate`` method.
        input_ids: batched prompt token ids.
        attention_mask: mask aligned with ``input_ids``.
        no_repeat_ngram_size: forwarded to ``generate`` unchanged.
        max_length: forwarded to ``generate`` as the sequence-length cap.
        num_beams: number of beams to keep.

    Returns:
        Whatever ``model.generate`` returns.
    """
    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    # Beam-specific settings; early_stopping ends the search once enough
    # finished candidates exist.
    generation_kwargs.update(num_beams=num_beams, early_stopping=True)
    return model.generate(**generation_kwargs)
def sampling(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, temperature, top_k, top_p):
    """Run stochastic decoding with temperature, top-k and nucleus (top-p) filtering.

    Args:
        model: object exposing a Hugging Face-style ``generate`` method.
        input_ids: batched prompt token ids.
        attention_mask: mask aligned with ``input_ids``.
        no_repeat_ngram_size: forwarded to ``generate`` unchanged.
        max_length: forwarded to ``generate`` as the sequence-length cap.
        temperature: softmax temperature.
        top_k: forwarded as ``top_k``.
        top_p: forwarded as ``top_p``.

    Returns:
        Whatever ``model.generate`` returns.
    """
    common = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    sampling_options = dict(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
    return model.generate(**common, **sampling_options)
def typical_sampling(model, input_ids, attention_mask, no_repeat_ngram_size, max_length, temperature, typical_p):
    """Run stochastic decoding restricted by typical-p (locally typical) filtering.

    Args:
        model: object exposing a Hugging Face-style ``generate`` method.
        input_ids: batched prompt token ids.
        attention_mask: mask aligned with ``input_ids``.
        no_repeat_ngram_size: forwarded to ``generate`` unchanged.
        max_length: forwarded to ``generate`` as the sequence-length cap.
        temperature: softmax temperature.
        typical_p: forwarded as ``typical_p``.

    Returns:
        Whatever ``model.generate`` returns.
    """
    common = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    # top_k=0 is passed explicitly, presumably to switch off generate's
    # default top-k filtering so typical-p alone shapes the distribution.
    typical_options = dict(do_sample=True, temperature=temperature, top_k=0, typical_p=typical_p)
    return model.generate(**common, **typical_options)
# `st.cache` is deprecated; `st.cache_resource` is its replacement for
# shared, unhashable objects such as models and tokenizers. Fall back to the
# legacy decorator on Streamlit versions that predate `cache_resource`.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def setModel(model_checkpoint):
    """Load (once per checkpoint) the causal LM and tokenizer for ``model_checkpoint``.

    Args:
        model_checkpoint: model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The cached objects are shared across
        reruns, so callers should not mutate them.
    """
    model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return model, tokenizer
#############################################
# Layout: a title plus one button per decoding strategy across the top;
# generation controls in the left column (col1), text area in the right (col2).
col_title, _, col_b1, col_b2, col_b3, _ = st.columns([18, 1, 8, 8, 8, 1])
col_title.markdown("**Playground for text generation with Romanian models**")
button_greedy = col_b1.button("Greedy generation")
button_sampling = col_b2.button("Sampling generation")
button_typical = col_b3.button("Typical sampling generation")

col1, _, col2 = st.columns([10, 1, 16])

with col1:
    st.markdown("**Step 1: Select model**")
    model_checkpoint = st.selectbox("Select model", model_list)

    st.markdown("**Step 2: Adjust text generation parameters**")
    max_length = st.slider("Number of tokens to generate", value=50, min_value=10, max_value=256)
    top_k = col1.slider("Top-k", min_value=0, max_value=100, step=10, value=0)
    top_p = col1.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
    # Transformers' typical-p filter rejects values <= 0 (it requires
    # 0 < typical_p < 1; 1.0 leaves it disabled), so the slider must not
    # offer 0.0 — picking it made generation fail.
    typical_p = col1.slider("Typical-p", min_value=0.1, max_value=1.0, step=0.1, value=1.0)
    temperature = st.slider("Temperature", value=1.0, min_value=0.1, max_value=1.0, step=0.1)
    no_repeat_ngrams = st.slider("No repeat n-grams", value=2, min_value=0, max_value=3)

    def update_prompt():
        """Copy the newly selected prompt into the text area's state."""
        # Read the fresh value from session state. The module-level `prompt`
        # visible to this callback still holds the previous run's selection,
        # so using it pasted the *previously* chosen option.
        st.session_state['text'] = st.session_state['prompt']

    # NOTE(review): the options are the model checkpoints, which looks like a
    # copy-paste slip — presumably a list of example prompts was intended.
    prompt = st.selectbox("Select prompt", model_list, key="prompt", on_change=update_prompt)
# `setModel` (the cached model/tokenizer loader) is defined once, earlier in
# this file; an identical second definition here would only shadow it.
#####################################################
# show-time

# Hard cap on prompt + generated tokens handed to `generate`; longer prompts
# are left-truncated so the requested generation length still fits.
# NOTE(review): presumably chosen to fit every model's context window — confirm.
MAX_CONTEXT_TOKENS = 512

if 'text' not in st.session_state:
    st.session_state['text'] = 'Acesta este un exemplu de text generat de un model de limbă.'

details = ""
tokenized_text = None

if button_greedy or button_sampling or button_typical:
    if not st.session_state['text'].strip():
        col2.warning("Please input some text!")
        text_element = col2.text_area('Text:', height=400, key="text")
        st.stop()

    model, tokenizer = setModel(model_checkpoint)
    tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
    all_ids = tokenized_text.input_ids[0]
    all_mask = tokenized_text.attention_mask[0]

    if len(all_ids) + max_length > MAX_CONTEXT_TOKENS:  # need to keep less words
        # Feed only the last `keep_last` tokens to the model. Clamp to >= 1:
        # a slice like `[-0:]` would silently keep the whole prompt.
        keep_last = max(MAX_CONTEXT_TOKENS - max_length, 1)
        print(f"keep last: {keep_last}")
        input_ids, attention_mask = all_ids[-keep_last:], all_mask[-keep_last:]
        # The dropped prefix is everything *before* the kept tail. Slicing
        # `[:keep_last]` instead would overlap the kept window (duplicating
        # text) or skip tokens when the result is reassembled below.
        previous_ids = all_ids[:-keep_last]
        st.warning(f"kept last {keep_last}")
    else:
        input_ids, attention_mask = all_ids, all_mask
        previous_ids = None

    length = min(MAX_CONTEXT_TOKENS, len(input_ids) + max_length)
    batch_ids = input_ids.unsqueeze(dim=0)
    batch_mask = attention_mask.unsqueeze(dim=0)

    # Exactly one button is True per rerun, so a single branch runs.
    timer_mark = perf_counter()
    if button_greedy:
        output = greedy_search(model, batch_ids, batch_mask, no_repeat_ngrams, length)
        details = f"Text generated using greedy decoding in {perf_counter()-timer_mark:.2f}s"
    elif button_sampling:
        output = sampling(model, batch_ids, batch_mask, no_repeat_ngrams, length, temperature, top_k, top_p)
        details = f"Text generated using sampling, top-p={top_p:.2f}, top-k={top_k}, temperature={temperature:.2f} in {perf_counter()-timer_mark:.2f}s"
    else:  # button_typical
        output = typical_sampling(model, batch_ids, batch_mask, no_repeat_ngrams, length, temperature, typical_p)
        details = f"Text generated using typical sampling, typical-p={typical_p:.2f}, temperature={temperature:.2f} in {perf_counter()-timer_mark:.2f}s"

    if previous_ids is not None:
        print("\nConcat prev id: " + tokenizer.decode(previous_ids, skip_special_tokens=True))
        print("\nWith current decode: " + tokenizer.decode(output[0], skip_special_tokens=True))
        # For a causal LM, `generate` returns the prompt tokens followed by the
        # continuation, so prefix + output[0] rebuilds the full text.
        new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=-1), skip_special_tokens=True)
    else:
        new_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Must be written before the keyed text_area below is created in this run.
    st.session_state['text'] = new_text

text_element = col2.text_area('Text:', height=400, key="text")
col2.markdown("""---""")
col2.text("Statistics and details:")
if details:
    col2.caption(" Generation details: " + details)
if tokenized_text is None:
    # No generation happened this run, so no tokenizer has been loaded yet.
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
tt = tokenizer(text_element, add_special_tokens=False, return_tensors="pt")
col2.caption(f" Text length is {len(text_element)} characters, {len(tt.input_ids[0])} tokens.")