Extended_GPT_2 / app.py
SaviAnna's picture
Update app.py
c3d7acd verified
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
st.title("Fine-tuned GPT-2 for New Language with Custom Tokenizer")
# Слайдеры для управления температурой и длиной текста
temperature = st.slider("Temperature", 0.1, 2.0, 1.0) # Для обеих моделей
max_len = st.slider("Max Length", 40, 120, 70) # Для обеих моделей
# Кеширование модели и токенизатора GPT-2
@st.cache_resource
def load_gpt2():
model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")
return model_gpt2, tokenizer_gpt2
# Кеширование кастомной модели и токенизатора
@st.cache_resource
def load_custom_model():
# Здесь замените путь на вашу кастомную модель
model_custom = GPT2LMHeadModel.from_pretrained("./rus_gpt2_tuned", from_tf=False, use_safetensors=True)
tokenizer_custom = GPT2TokenizerFast.from_pretrained("./rus_gpt2_tuned/tokenizer")
return model_custom, tokenizer_custom
# Функция для генерации текста
def generate_text(model, tokenizer, prompt, max_len, temperature):
input_ids = tokenizer.encode(prompt, return_tensors='pt')
attention_mask = torch.ones_like(input_ids)
# Генерация текста
output = model.generate(
input_ids,
max_length=max_len,
temperature=temperature, # Управление разнообразием текста
top_k=50, # Ограничение топ-50 самых вероятных слов
top_p=0.9, # Nucleus sampling (суммарная вероятность)
repetition_penalty=1.2, # Штраф за повторение слов или фраз
no_repeat_ngram_size=4, # Запрет на повторение n-грамм (например, биграмм)
do_sample=True, # Включение сэмплинга для большей разнообразности
attention_mask=attention_mask,
pad_token_id=tokenizer.eos_token_id
)
# Декодирование сгенерированных токенов в текст
generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
return generated_text
# Streamlit приложение
def main():
model_gpt2, tokenizer_gpt2 = load_gpt2() # GPT-2 модель
model_custom, tokenizer_custom = load_custom_model() # Кастомная модель
#st.write("Fine-tuned GPT-2 for New Language with Custom Tokenizer")
# # Блок для генерации текста с GPT-2
# st.subheader("GPT-2 Text Generation")
# prompt_gpt2 = st.text_area("Введите фразу для GPT-2 генерации:", value="В средние века")
# generate_button_gpt2 = st.button("Сгенерировать текст с GPT-2")
# if generate_button_gpt2:
# generated_text_gpt2 = generate_text(model_gpt2, tokenizer_gpt2, prompt_gpt2, max_len, temperature)
# st.subheader("Результат генерации GPT-2:")
# st.write(generated_text_gpt2)
# Блок для генерации текста с кастомной моделью
st.subheader("Custom Model Text Generation")
prompt_custom = st.text_area("Enter a phrase to generate with the updated model:", value="Когда-то давно")
generate_button_custom = st.button("Generate!")
if generate_button_custom:
generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature)
st.subheader("Result:")
st.write(generated_text_custom)
if __name__ == "__main__":
main()