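"""NeuroKorzh: a Streamlit demo that continues a song prompt in the style of
Max Korzh, using ruGPT-3 (medium and large) models with fine-tuned weights.

Expects the checkpoint files 'korzh-medium_best_eval_loss.bin' and
'korzh-large_best_eval_loss.bin' plus 'korzh.jpg' in the working directory.
"""
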
import transformers
import torch
import tokenizers
import streamlit as st
import re

from PIL import Image


# Cache the loaded model across Streamlit reruns. Tokenizer internals
# (tokenizers.Tokenizer, AddedToken, compiled regexes) are unhashable for
# st.cache, so hash_funcs tells the cache to skip them.
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None,
                      tokenizers.AddedToken: lambda _: None,
                      re.Pattern: lambda _: None},
          allow_output_mutation=True, suppress_st_warning=True)
def get_model(model_name, model_path):
    # Base ruGPT-3 tokenizer plus a custom end-of-song token.
    tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
    tokenizer.add_special_tokens({
        'eos_token': '[EOS]'
    })
    # Load the base architecture, grow the embedding matrix to cover the new
    # token, then overwrite the weights with the fine-tuned checkpoint (CPU).
    model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    return model, tokenizer


def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, length_of_generated=300):
    # Append a newline so the generated lyrics start on a fresh line after the prompt.
    text += '\n'
    input_ids = tokenizer.encode(text, return_tensors="pt")
    length_of_prompt = len(input_ids[0])
    # Beam-sampling: top-p sampling across n_beams beams; generation stops at
    # the custom [EOS] token or at the length cap.
    with torch.no_grad():
        out = model.generate(input_ids,
                             do_sample=True,
                             num_beams=n_beams,
                             temperature=temperature,
                             top_p=top_p,
                             max_length=length_of_prompt + length_of_generated,
                             eos_token_id=tokenizer.eos_token_id
                             )

    # generate() returns a batch of sequences; decode the single best one.
    return tokenizer.decode(out[0])
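
# Standalone usage sketch (outside Streamlit), assuming the checkpoint file
# from this repo is available locally:
#
#     model, tok = get_model('sberbank-ai/rugpt3medium_based_on_gpt2',
#                            'korzh-medium_best_eval_loss.bin')
#     print(predict('Что делать, Макс?', model, tok, temperature=2.5))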


# Load both fine-tuned checkpoints up front; selection happens in the UI below.
medium_model, medium_tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_best_eval_loss.bin')
large_model, large_tokenizer = get_model('sberbank-ai/rugpt3large_based_on_gpt2', 'korzh-large_best_eval_loss.bin')

# st.title("NeuroKorzh")

image = Image.open('korzh.jpg')
st.image(image, caption='НейроКорж')  # "NeuroKorzh"

# 'Выберите своего Коржа' = "Choose your Korzh": 'Быстрый' ("Fast") runs the
# medium model, 'Глубокий' ("Deep") runs the large one.
option = st.selectbox('Выберите своего Коржа', ('Быстрый', 'Глубокий'))
# 'Абсурдность' = "Craziness": forwarded to predict() as the sampling temperature.
craziness = st.slider(label='Абсурдность', min_value=0.5, max_value=4., value=2.5, step=0.1)

st.markdown("\n")

# 'Напишите начало песни' = "Write the beginning of the song";
# default prompt: 'Что делать, Макс?' ("What to do, Max?").
text = st.text_area(label='Напишите начало песни', value='Что делать, Макс?', height=100)
button = st.button('Старт')  # "Start"

if button:
    try:
        with st.spinner("Пушечка пишется"):  # "The banger is being written"
            if option == 'Быстрый':
                result = predict(text, medium_model, medium_tokenizer, temperature=craziness)
            elif option == 'Глубокий':
                result = predict(text, large_model, large_tokenizer, temperature=craziness)
            else:
                # st.error() returns a widget, not an exception, so it cannot
                # be raised; raise a real exception for the handler below.
                raise ValueError('Unexpected selectbox option: %s' % option)

        st.text_area(label='', value=result, height=1100)

    except Exception:
        st.error("Ooooops, something went wrong. Please try again and report to me, tg: @vladyur")