Spaces:
Sleeping
Sleeping
File size: 3,968 Bytes
460d569 b628870 460d569 cc91e6e 460d569 b628870 12d19c9 b628870 12d19c9 b628870 12d19c9 63d047a aabd4e2 12d19c9 460d569 12d19c9 c3d7acd 460d569 356b987 cc91e6e 073629e cc91e6e 460d569 12d19c9 9c47ff5 8bdd958 460d569 17a60a4 12d19c9 17a60a4 9c47ff5 12d19c9 1c8cc34 460d569 12d19c9 1c8cc34 12d19c9 460d569 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
st.title("Fine-tuned GPT-2 for New Language with Custom Tokenizer")
# Слайдеры для управления температурой и длиной текста
temperature = st.slider("Temperature", 0.1, 2.0, 1.0) # Для обеих моделей
max_len = st.slider("Max Length", 40, 120, 70) # Для обеих моделей
# Кеширование модели и токенизатора GPT-2
@st.cache_resource
def load_gpt2():
model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")
return model_gpt2, tokenizer_gpt2
# Кеширование кастомной модели и токенизатора
@st.cache_resource
def load_custom_model():
# Здесь замените путь на вашу кастомную модель
model_custom = GPT2LMHeadModel.from_pretrained("./rus_gpt2_tuned", from_tf=False, use_safetensors=True)
tokenizer_custom = GPT2TokenizerFast.from_pretrained("./rus_gpt2_tuned/tokenizer")
return model_custom, tokenizer_custom
# Функция для генерации текста
def generate_text(model, tokenizer, prompt, max_len, temperature):
input_ids = tokenizer.encode(prompt, return_tensors='pt')
attention_mask = torch.ones_like(input_ids)
# Генерация текста
output = model.generate(
input_ids,
max_length=max_len,
temperature=temperature, # Управление разнообразием текста
top_k=50, # Ограничение топ-50 самых вероятных слов
top_p=0.9, # Nucleus sampling (суммарная вероятность)
repetition_penalty=1.2, # Штраф за повторение слов или фраз
no_repeat_ngram_size=4, # Запрет на повторение n-грамм (например, биграмм)
do_sample=True, # Включение сэмплинга для большей разнообразности
attention_mask=attention_mask,
pad_token_id=tokenizer.eos_token_id
)
# Декодирование сгенерированных токенов в текст
generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
return generated_text
# Streamlit приложение
def main():
model_gpt2, tokenizer_gpt2 = load_gpt2() # GPT-2 модель
model_custom, tokenizer_custom = load_custom_model() # Кастомная модель
#st.write("Fine-tuned GPT-2 for New Language with Custom Tokenizer")
# # Блок для генерации текста с GPT-2
# st.subheader("GPT-2 Text Generation")
# prompt_gpt2 = st.text_area("Введите фразу для GPT-2 генерации:", value="В средние века")
# generate_button_gpt2 = st.button("Сгенерировать текст с GPT-2")
# if generate_button_gpt2:
# generated_text_gpt2 = generate_text(model_gpt2, tokenizer_gpt2, prompt_gpt2, max_len, temperature)
# st.subheader("Результат генерации GPT-2:")
# st.write(generated_text_gpt2)
# Блок для генерации текста с кастомной моделью
st.subheader("Custom Model Text Generation")
prompt_custom = st.text_area("Enter a phrase to generate with the updated model:", value="Когда-то давно")
generate_button_custom = st.button("Generate!")
if generate_button_custom:
generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature)
st.subheader("Result:")
st.write(generated_text_custom)
if __name__ == "__main__":
main()
|