File size: 3,968 Bytes
460d569
 
b628870
460d569
 
cc91e6e
460d569
b628870
12d19c9
 
b628870
12d19c9
b628870
12d19c9
 
 
 
 
 
 
 
 
63d047a
aabd4e2
12d19c9
460d569
 
12d19c9
 
c3d7acd
460d569
356b987
cc91e6e
 
 
 
 
 
 
 
073629e
cc91e6e
 
 
 
 
460d569
 
 
 
12d19c9
 
9c47ff5
8bdd958
460d569
17a60a4
 
 
 
12d19c9
17a60a4
 
 
 
9c47ff5
12d19c9
 
1c8cc34
 
460d569
12d19c9
 
1c8cc34
12d19c9
460d569
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch

st.title("Fine-tuned GPT-2 for New Language with Custom Tokenizer")

# Слайдеры для управления температурой и длиной текста
temperature = st.slider("Temperature", 0.1, 2.0, 1.0)  # Для обеих моделей
max_len = st.slider("Max Length", 40, 120, 70)  # Для обеих моделей

# Кеширование модели и токенизатора GPT-2
@st.cache_resource
def load_gpt2():
    model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")
    return model_gpt2, tokenizer_gpt2

# Кеширование кастомной модели и токенизатора
@st.cache_resource
def load_custom_model():
    # Здесь замените путь на вашу кастомную модель
    model_custom = GPT2LMHeadModel.from_pretrained("./rus_gpt2_tuned", from_tf=False, use_safetensors=True)
    tokenizer_custom = GPT2TokenizerFast.from_pretrained("./rus_gpt2_tuned/tokenizer")
    return model_custom, tokenizer_custom

# Функция для генерации текста
def generate_text(model, tokenizer, prompt, max_len, temperature):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    attention_mask = torch.ones_like(input_ids)
    # Генерация текста
    output = model.generate(
        input_ids,
        max_length=max_len,
        temperature=temperature,              # Управление разнообразием текста
        top_k=50,                     # Ограничение топ-50 самых вероятных слов
        top_p=0.9,                    # Nucleus sampling (суммарная вероятность)
        repetition_penalty=1.2,       # Штраф за повторение слов или фраз
        no_repeat_ngram_size=4,       # Запрет на повторение n-грамм (например, биграмм)
        do_sample=True,                # Включение сэмплинга для большей разнообразности
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id
    )

    # Декодирование сгенерированных токенов в текст
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return generated_text

# Streamlit приложение
def main():
    model_gpt2, tokenizer_gpt2 = load_gpt2()  # GPT-2 модель
    model_custom, tokenizer_custom = load_custom_model()  # Кастомная модель

    #st.write("Fine-tuned GPT-2 for New Language with Custom Tokenizer")

    # # Блок для генерации текста с GPT-2
    # st.subheader("GPT-2 Text Generation")
    # prompt_gpt2 = st.text_area("Введите фразу для GPT-2 генерации:", value="В средние века")
    # generate_button_gpt2 = st.button("Сгенерировать текст с GPT-2")

    # if generate_button_gpt2:
    #     generated_text_gpt2 = generate_text(model_gpt2, tokenizer_gpt2, prompt_gpt2, max_len, temperature)
    #     st.subheader("Результат генерации GPT-2:")
    #     st.write(generated_text_gpt2)

    # Блок для генерации текста с кастомной моделью
    st.subheader("Custom Model Text Generation")
    prompt_custom = st.text_area("Enter a phrase to generate with the updated model:", value="Когда-то давно")
    generate_button_custom = st.button("Generate!")

    if generate_button_custom:
        generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature)
        st.subheader("Result:")
        st.write(generated_text_custom)

if __name__ == "__main__":
    main()