File size: 3,298 Bytes
1867879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Загрузка модели и токенизатора
@st.cache_resource
def load_model():
    model_name = "models/gpt" 
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer

def generate_text(model, tokenizer, prompt, gen_params):
    inputs = tokenizer(prompt, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=gen_params['max_length'],
            temperature=gen_params['temperature'],
            top_k=gen_params['top_k'],
            top_p=gen_params['top_p'],
            num_return_sequences=gen_params['num_return_sequences'],
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    generated = []
    for i, output in enumerate(outputs):
        text = tokenizer.decode(output, skip_special_tokens=True)
        generated.append(f"Генерация {i+1}:\n{text}\n{'-'*50}")
    
    return generated


def main():
    st.markdown(
    "<h1 style='text-align: center;'>Генератор текста</h1>", 
    unsafe_allow_html=True
    )
    st.markdown(
    "<h3 style='text-align: center;'>(ну почти)</h3>", 
    unsafe_allow_html=True
    )
    st.markdown("---")

    col1, col2, col3 = st.columns([1, 2, 1]) 
    with col2:
        st.image('images/scale_1200.png', width=500)
    
    # Загрузка модели
    model, tokenizer = load_model()
    
    # Параметры генерации
    with st.sidebar:
        st.header("Настройки генерации")
        prompt = st.text_area("Введите начальный текст:", height=100)
        max_length = st.slider("Максимальная длина:", 50, 500, 100)
        num_return_sequences = st.slider("Число генераций:", 1, 5, 1)
        
        st.subheader("Параметры выборки:")
        sampling_method = st.radio("Метод:", ["Temperature", "Top-k & Top-p"])
        
        if sampling_method == "Temperature":
            temperature = st.slider("Temperature:", 0.1, 2.0, 1.0, 0.1)
            top_k = None
            top_p = None
        else:
            temperature = 1.0
            top_k = st.slider("Top-k:", 1, 100, 50)
            top_p = st.slider("Top-p:", 0.1, 1.0, 0.9, 0.05)
    
    # Кнопка генерации
    if st.sidebar.button("Сгенерировать текст"):
        if not prompt:
            st.warning("Введите начальный текст!")
            return
            
        gen_params = {
            'max_length': max_length,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'num_return_sequences': num_return_sequences
        }
        
        with st.spinner("Прибухиваем..."):
            generated = generate_text(model, tokenizer, prompt, gen_params)
            
        st.markdown("---")
        st.subheader("Результаты:")
        for text in generated:
            st.text_area(label="", value=text, height=200)

if __name__ == "__main__":
    main()