import streamlit as st

from transformers import AutoTokenizer, AutoModelForCausalLM

st.set_page_config(
    page_title="Romanian Text Generator",
    page_icon="🇷🇴",
    layout="wide"
)

st.write("Type your text here and press Ctrl+Enter to generate the next sequence:")

model_list = [
    "dumitrescustefan/gpt-neo-romanian-780m",
    "readerbench/RoGPT2-base",
    "readerbench/RoGPT2-medium",
    "readerbench/RoGPT2-large"
]

st.sidebar.header("Select model")
model_checkpoint = st.sidebar.radio("", model_list)

st.sidebar.header("Select generation parameters")
max_length = st.sidebar.slider("Max Length", value=20, min_value=10, max_value=200)
# Temperature must stay above 0 when do_sample=True, otherwise generate() raises an error
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05)
# top_k=0 disables top-k filtering, so sampling is then controlled by top_p alone
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=15, step=1, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)

text_element = st.text_input('Text:', 'Acesta este un exemplu,')

@st.cache(allow_output_mutation=True)
def setModel(model_checkpoint):
    # Load and cache the model and tokenizer so they are not reloaded on every Streamlit rerun
    model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return model, tokenizer
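# Note (assumption about the installed Streamlit version): st.cache(allow_output_mutation=True)
# is deprecated in newer Streamlit releases; st.cache_resource is the current equivalent for
# caching models and tokenizers:
#
#   @st.cache_resource
#   def setModel(model_checkpoint):
#       ...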

def infer(model, tokenizer, text, max_length, temperature, top_k, top_p):
    # Tokenize the prompt and sample one continuation with the selected decoding parameters
    encoded_prompt = tokenizer(text, add_special_tokens=False, return_tensors="pt")

    output_sequences = model.generate(
        input_ids=encoded_prompt.input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1
    )

    return output_sequences

model, tokenizer = setModel(model_checkpoint)
output_sequences = infer(model, tokenizer, text_element, max_length, temperature, top_k, top_p)

# Length of the decoded prompt, used below to strip the prompt from the generated text
prompt_length = len(tokenizer.decode(
    tokenizer(text_element, add_special_tokens=False).input_ids,
    clean_up_tokenization_spaces=True,
))

generated_sequences = []
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
    print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")

    # Decode the generated token ids back to text
    text = tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)

    # Optionally, remove all text after a stop token here

    # Add the prompt at the beginning of the sequence, dropping the re-decoded prompt prefix
    total_sequence = text_element + text[prompt_length:]

    generated_sequences.append(total_sequence)
    print(total_sequence)

st.write(generated_sequences[-1])
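# Usage sketch (assumption: this script is saved as app.py; the install names below are the
# standard PyPI packages, versions not pinned):
#
#   pip install streamlit transformers torch
#   streamlit run app.py
#
# Streamlit reruns the whole script whenever a widget changes, so the cached setModel() call
# keeps the selected checkpoint in memory between reruns.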