Spaces:
Sleeping
Sleeping
import transformers | |
import streamlit as st | |
from transformers import GPT2LMHeadModel, GPT2TokenizerFast | |
import torch | |
st.title(""" | |
Fine-tuned GPT-2 for New Language with Custom Tokenizer | |
""") | |
# Слайдеры для управления температурой и длиной текста | |
temperature = st.slider("Temperature", 0.1, 2.0, 1.0) # Для обеих моделей | |
max_len = st.slider("Max Length", 40, 120, 70) # Для обеих моделей | |
# Кеширование модели и токенизатора GPT-2 | |
def load_gpt2(): | |
model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2") | |
tokenizer_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2") | |
return model_gpt2, tokenizer_gpt2 | |
# Кеширование кастомной модели и токенизатора | |
def load_custom_model(): | |
# Здесь замените путь на вашу кастомную модель | |
model_custom = GPT2LMHeadModel.from_pretrained("rus_gpt2_tuned") | |
tokenizer_custom = GPT2TokenizerFast.from_pretrained("rus_gpt2_tuned") | |
return model_custom, tokenizer_custom | |
# Функция для генерации текста | |
def generate_text(model, tokenizer, prompt, max_len, temperature): | |
input_ids = tokenizer.encode(prompt, return_tensors='pt') | |
# Генерация текста | |
output = model.generate(input_ids=input_ids, | |
max_length=max_len, | |
do_sample=True, | |
temperature=temperature, | |
top_k=50, | |
top_p=0.6, | |
no_repeat_ngram_size=3, | |
num_return_sequences=1) | |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True) | |
return generated_text | |
# Streamlit приложение | |
def main(): | |
model_gpt2, tokenizer_gpt2 = load_gpt2() # GPT-2 модель | |
model_custom, tokenizer_custom = load_custom_model() # Кастомная модель | |
st.write(""" | |
# Fine-tuned GPT-2 for New Language with Custom Tokenizer | |
""") | |
# Блок для генерации текста с GPT-2 | |
st.subheader("GPT-2 Text Generation") | |
prompt_gpt2 = st.text_area("Введите фразу для GPT-2 генерации:", value="В средние века") | |
generate_button_gpt2 = st.button("Сгенерировать текст с GPT-2") | |
if generate_button_gpt2: | |
generated_text_gpt2 = generate_text(model_gpt2, tokenizer_gpt2, prompt_gpt2, max_len, temperature) | |
st.subheader("Результат генерации GPT-2:") | |
st.write(generated_text_gpt2) | |
# Блок для генерации текста с кастомной моделью | |
st.subheader("Custom Model Text Generation") | |
prompt_custom = st.text_area("Введите фразу для генерации с кастомной моделью:", value="Когда-то давно") | |
generate_button_custom = st.button("Сгенерировать текст с кастомной моделью") | |
if generate_button_custom: | |
generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature) | |
st.subheader("Результат генерации с кастомной моделью:") | |
st.write(generated_text_custom) | |
if __name__ == "__main__": | |
main() | |