Extended_GPT_2 / app.py
SaviAnna's picture
Update app.py
12d19c9 verified
raw
history blame
3.45 kB
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
st.title("""
Fine-tuned GPT-2 for New Language with Custom Tokenizer
""")
# Слайдеры для управления температурой и длиной текста
temperature = st.slider("Temperature", 0.1, 2.0, 1.0) # Для обеих моделей
max_len = st.slider("Max Length", 40, 120, 70) # Для обеих моделей
# Кеширование модели и токенизатора GPT-2
@st.cache_resource
def load_gpt2():
model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")
return model_gpt2, tokenizer_gpt2
# Кеширование кастомной модели и токенизатора
@st.cache_resource
def load_custom_model():
# Здесь замените путь на вашу кастомную модель
model_custom = GPT2LMHeadModel.from_pretrained("rus_gpt2_tuned")
tokenizer_custom = GPT2TokenizerFast.from_pretrained("rus_gpt2_tuned")
return model_custom, tokenizer_custom
# Функция для генерации текста
def generate_text(model, tokenizer, prompt, max_len, temperature):
input_ids = tokenizer.encode(prompt, return_tensors='pt')
# Генерация текста
output = model.generate(input_ids=input_ids,
max_length=max_len,
do_sample=True,
temperature=temperature,
top_k=50,
top_p=0.6,
no_repeat_ngram_size=3,
num_return_sequences=1)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
return generated_text
# Streamlit приложение
def main():
model_gpt2, tokenizer_gpt2 = load_gpt2() # GPT-2 модель
model_custom, tokenizer_custom = load_custom_model() # Кастомная модель
st.write("""
# Fine-tuned GPT-2 for New Language with Custom Tokenizer
""")
# Блок для генерации текста с GPT-2
st.subheader("GPT-2 Text Generation")
prompt_gpt2 = st.text_area("Введите фразу для GPT-2 генерации:", value="В средние века")
generate_button_gpt2 = st.button("Сгенерировать текст с GPT-2")
if generate_button_gpt2:
generated_text_gpt2 = generate_text(model_gpt2, tokenizer_gpt2, prompt_gpt2, max_len, temperature)
st.subheader("Результат генерации GPT-2:")
st.write(generated_text_gpt2)
# Блок для генерации текста с кастомной моделью
st.subheader("Custom Model Text Generation")
prompt_custom = st.text_area("Введите фразу для генерации с кастомной моделью:", value="Когда-то давно")
generate_button_custom = st.button("Сгенерировать текст с кастомной моделью")
if generate_button_custom:
generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature)
st.subheader("Результат генерации с кастомной моделью:")
st.write(generated_text_custom)
if __name__ == "__main__":
main()