import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
st.set_page_config(
    page_title="Romanian Text Generator",
    page_icon="🇷🇴",
    layout="wide"
)
st.write("Type your text here and press Ctrl+Enter to generate the next sequence:")
model_list = [
    "dumitrescustefan/gpt-neo-romanian-780m",
    "readerbench/RoGPT2-base",
    "readerbench/RoGPT2-medium",
    "readerbench/RoGPT2-large"
]
st.sidebar.header("Select model")
model_checkpoint = st.sidebar.radio("", model_list)
st.sidebar.header("Select generation parameters")
max_length = st.sidebar.slider("Max Length", value=20, min_value=10, max_value=200)
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05)  # keep > 0: temperature 0 raises an error when sampling
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=15, step=1, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
text_element = st.text_input('Text:', 'Acesta este un exemplu,')
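# Cache the model and tokenizer so they are downloaded and loaded only once per selected checkpoint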
@st.cache(allow_output_mutation=True)
def setModel(model_checkpoint):
    model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return model, tokenizer
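# Encode the prompt and sample a single continuation with the selected parameters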
def infer(model, tokenizer, text, max_length, temperature, top_k, top_p):
    encoded_prompt = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    output_sequences = model.generate(
        input_ids=encoded_prompt.input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1
    )
    return output_sequences
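# Load the selected checkpoint and generate a continuation for the current prompt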
model, tokenizer = setModel(model_checkpoint)
output_sequences = infer(model, tokenizer, text_element, max_length, temperature, top_k, top_p)
# Re-encode the prompt so its decoded length can be stripped from the output
encoded_prompt = tokenizer(text_element, add_special_tokens=False, return_tensors="pt")
generated_sequences = []
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
    print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
    generated_sequence = generated_sequence.tolist()
    # Decode text
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    # Remove all text after the stop token
    # text = text[: text.find(args.stop_token) if args.stop_token else None]
    # Add the prompt at the beginning of the sequence and drop the decoded prompt tokens
    total_sequence = (
        text_element + text[len(tokenizer.decode(encoded_prompt.input_ids[0], clean_up_tokenization_spaces=True)):]
    )
    generated_sequences.append(total_sequence)
    print(total_sequence)
st.write(generated_sequences[-1])