Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
# Загрузка модели и токенизатора | |
def load_model(): | |
model_name = "models/gpt" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForCausalLM.from_pretrained(model_name) | |
return model, tokenizer | |
def generate_text(model, tokenizer, prompt, gen_params): | |
inputs = tokenizer(prompt, return_tensors="pt") | |
with torch.no_grad(): | |
outputs = model.generate( | |
inputs.input_ids, | |
max_length=gen_params['max_length'], | |
temperature=gen_params['temperature'], | |
top_k=gen_params['top_k'], | |
top_p=gen_params['top_p'], | |
num_return_sequences=gen_params['num_return_sequences'], | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
generated = [] | |
for i, output in enumerate(outputs): | |
text = tokenizer.decode(output, skip_special_tokens=True) | |
generated.append(f"Генерация {i+1}:\n{text}\n{'-'*50}") | |
return generated | |
def main(): | |
st.markdown( | |
"<h1 style='text-align: center;'>Генератор текста</h1>", | |
unsafe_allow_html=True | |
) | |
st.markdown( | |
"<h3 style='text-align: center;'>(ну почти)</h3>", | |
unsafe_allow_html=True | |
) | |
st.markdown("---") | |
col1, col2, col3 = st.columns([1, 2, 1]) | |
with col2: | |
st.image('images/scale_1200.png', width=500) | |
# Загрузка модели | |
model, tokenizer = load_model() | |
# Параметры генерации | |
with st.sidebar: | |
st.header("Настройки генерации") | |
prompt = st.text_area("Введите начальный текст:", height=100) | |
max_length = st.slider("Максимальная длина:", 50, 500, 100) | |
num_return_sequences = st.slider("Число генераций:", 1, 5, 1) | |
st.subheader("Параметры выборки:") | |
sampling_method = st.radio("Метод:", ["Temperature", "Top-k & Top-p"]) | |
if sampling_method == "Temperature": | |
temperature = st.slider("Temperature:", 0.1, 2.0, 1.0, 0.1) | |
top_k = None | |
top_p = None | |
else: | |
temperature = 1.0 | |
top_k = st.slider("Top-k:", 1, 100, 50) | |
top_p = st.slider("Top-p:", 0.1, 1.0, 0.9, 0.05) | |
# Кнопка генерации | |
if st.sidebar.button("Сгенерировать текст"): | |
if not prompt: | |
st.warning("Введите начальный текст!") | |
return | |
gen_params = { | |
'max_length': max_length, | |
'temperature': temperature, | |
'top_k': top_k, | |
'top_p': top_p, | |
'num_return_sequences': num_return_sequences | |
} | |
with st.spinner("Прибухиваем..."): | |
generated = generate_text(model, tokenizer, prompt, gen_params) | |
st.markdown("---") | |
st.subheader("Результаты:") | |
for text in generated: | |
st.text_area(label="", value=text, height=200) | |
if __name__ == "__main__": | |
main() |