Spaces:
Sleeping
Sleeping
import transformers | |
import streamlit as st | |
from transformers import GPT2LMHeadModel, GPT2TokenizerFast | |
import torch | |
st.title("Fine-tuned GPT-2 for New Language with Custom Tokenizer") | |
# Слайдеры для управления температурой и длиной текста | |
temperature = st.slider("Temperature", 0.1, 2.0, 1.0) # Для обеих моделей | |
max_len = st.slider("Max Length", 40, 120, 70) # Для обеих моделей | |
# Кеширование модели и токенизатора GPT-2 | |
def load_gpt2(): | |
model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2") | |
tokenizer_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2") | |
return model_gpt2, tokenizer_gpt2 | |
# Кеширование кастомной модели и токенизатора | |
def load_custom_model(): | |
# Здесь замените путь на вашу кастомную модель | |
model_custom = GPT2LMHeadModel.from_pretrained("./rus_gpt2_tuned", from_tf=False, use_safetensors=True) | |
tokenizer_custom = GPT2TokenizerFast.from_pretrained("./rus_gpt2_tuned/tokenizer") | |
return model_custom, tokenizer_custom | |
# Функция для генерации текста | |
def generate_text(model, tokenizer, prompt, max_len, temperature): | |
input_ids = tokenizer.encode(prompt, return_tensors='pt') | |
attention_mask = torch.ones_like(input_ids) | |
# Генерация текста | |
output = model.generate( | |
input_ids, | |
max_length=max_len, | |
temperature=temperature, # Управление разнообразием текста | |
top_k=50, # Ограничение топ-50 самых вероятных слов | |
top_p=0.9, # Nucleus sampling (суммарная вероятность) | |
repetition_penalty=1.2, # Штраф за повторение слов или фраз | |
no_repeat_ngram_size=4, # Запрет на повторение n-грамм (например, биграмм) | |
do_sample=True, # Включение сэмплинга для большей разнообразности | |
attention_mask=attention_mask, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
# Декодирование сгенерированных токенов в текст | |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) | |
return generated_text | |
# Streamlit приложение | |
def main(): | |
model_gpt2, tokenizer_gpt2 = load_gpt2() # GPT-2 модель | |
model_custom, tokenizer_custom = load_custom_model() # Кастомная модель | |
#st.write("Fine-tuned GPT-2 for New Language with Custom Tokenizer") | |
# # Блок для генерации текста с GPT-2 | |
# st.subheader("GPT-2 Text Generation") | |
# prompt_gpt2 = st.text_area("Введите фразу для GPT-2 генерации:", value="В средние века") | |
# generate_button_gpt2 = st.button("Сгенерировать текст с GPT-2") | |
# if generate_button_gpt2: | |
# generated_text_gpt2 = generate_text(model_gpt2, tokenizer_gpt2, prompt_gpt2, max_len, temperature) | |
# st.subheader("Результат генерации GPT-2:") | |
# st.write(generated_text_gpt2) | |
# Блок для генерации текста с кастомной моделью | |
st.subheader("Custom Model Text Generation") | |
prompt_custom = st.text_area("Enter a phrase to generate with the updated model:", value="Когда-то давно") | |
generate_button_custom = st.button("Generate!") | |
if generate_button_custom: | |
generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature) | |
st.subheader("Result:") | |
st.write(generated_text_custom) | |
if __name__ == "__main__": | |
main() | |