Extended_GPT_2 / app.py
Last commit not found
raw
history blame
2.43 kB
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
st.title("""
Fine-tuned GPT-2 for New Language with Custom Tokenizer
""")
# Слайдеры для управления температурой и длиной текста
temperature = st.slider("Temperature", 0.1, 2.0, 1.0)
max_len = st.slider("Max Length", 40, 120, 70)
# Кеширование модели и токенизатора
@st.cache_resource
def load_gpt_base():
model_GPT = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer_GPT = GPT2TokenizerFast.from_pretrained("gpt2")
return model_GPT, tokenizer_GPT
# Функция для генерации текста
def generate_text(model_GPT, tokenizer_GPT, prompt, max_len, temperature):
# Преобразование входной строки в токены
input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
# Генерация текста
output = model_GPT.generate(input_ids=input_ids,
max_length=max_len,
do_sample=True,
temperature=temperature,
top_k=50,
top_p=0.6,
no_repeat_ngram_size=3,
num_return_sequences=1)
# Декодирование сгенерированного текста
generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
return generated_text
# Streamlit приложение
def main():
model_GPT, tokenizer_GPT = load_gpt_base()
st.write("""
# Fine-tuned GPT-2 for New Language with Custom Tokenizer
""")
# Ввод строки пользователем для генерации текста
prompt = st.text_area("Введите фразу для генерации:", value="В средние века")
# Создание кнопки для генерации
generate_button = st.button("Сгенерировать текст")
# Обработка события нажатия кнопки
if generate_button:
generated_text = generate_text(model_GPT, tokenizer_GPT, prompt, max_len, temperature)
st.subheader("Сгенерированный текст:")
st.write(generated_text)
if __name__ == "__main__":
main()