Extended_GPT_2 / app.py
SaviAnna's picture
Create app.py
460d569 verified
raw
history blame
3.95 kB
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
import torch
st.title("""
Fine-tuned GPT-2 for New Language with Custom Tokenizer
""")
# Добавление слайдера
temperature = st.slider("Temerature", 1, 20, 1)
max_len = st.slider("Length", 40, 120, 2)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
@st.cache
# def load_gpt():
# model_GPT = GPT2LMHeadModel.from_pretrained(
# 'sberbank-ai/rugpt3small_based_on_gpt2',
# output_attentions = False,
# output_hidden_states = False,
# )
# tokenizer_GPT = GPT2Tokenizer.from_pretrained(
# 'sberbank-ai/rugpt3small_based_on_gpt2',
# output_attentions = False,
# output_hidden_states = False,
# )
# gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu')))
# return model_GPT, tokenizer_GPT
def load_gpt_base():
model_GPT = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer_GPT = GPT2TokenizerFast.from_pretrained("gpt2")
return model_GPT, tokenizer_GPT
# # Вешаем сохраненные веса на нашу модель
# Функция для генерации текста
def generate_text(model_GPT, tokenizer_GPT, prompt):
# Преобразование входной строки в токены
input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
# Генерация текста
output = model_GPT.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
temperature=1., top_k=50, top_p=0.6, no_repeat_ngram_size=3,
num_return_sequences=3)
# Декодирование сгенерированного текста
generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
return generated_text
# Streamlit приложение
def main():
model_GPT, tokenizer_GPT = load_gpt()
st.write("""
# Fine-tuned GPT-2 for New Language with Custom Tokenizer
""")
# Ввод строки пользователем
prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века")
# # Генерация текста по введенной строке
# generated_text = generate_text(prompt)
# Создание кнопки "Сгенерировать"
generate_button = st.button("Complete!")
# Обработка события нажатия кнопки
if generate_button:
# Вывод сгенерированного текста
#generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
generated_text = 'test'
st.subheader("Completed prompt:")
st.write(generated_text)
# Ввод строки пользователем
prompt1 = st.text_area("Какую фразу нужно продолжить:", value="В средние века")
# # Генерация текста по введенной строке
# generated_text = generate_text(prompt)
# Создание кнопки "Сгенерировать"
generate_button1 = st.button("Complete!")
# Обработка события нажатия кнопки
if generate_button1:
# Вывод сгенерированного текста
#generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
generated_text = 'test'
st.subheader("Completed prompt:")
st.write(generated_text)
if __name__ == "__main__":
main()