jokerai / app.py
dokster's picture
Update app.py
ea3334f
import streamlit as st
import pandas
import torch
from transformers import (
GPT2Tokenizer, GPT2LMHeadModel, TextDataset, LineByLineTextDataset,
DataCollatorForLanguageModeling, TrainingArguments, Trainer, get_cosine_schedule_with_warmup
)
# output_dir = '/dokster/jokerai/model'
bos_token = '<joke>'
eos_token = '<end>'
model = GPT2LMHeadModel.from_pretrained('./model')
tokenizer = GPT2Tokenizer.from_pretrained('./model')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
st.set_page_config(
page_title="JokerAI", page_icon="🎈", layout="centered"
)
def _max_width_():
max_width_str = f"max-width: 1400px;"
st.markdown(
f"""
<style>
.reportview-container .main .block-container{{
{max_width_str}
}}
</style>
""",
unsafe_allow_html=True,
)
def generate(prompt, temperature, max_length, top_k, top_p, num_return_sequences):
model.eval()
answers = []
input_ids = tokenizer.encode(prompt, return_tensors='pt')
input_ids = input_ids.to(device) #переносим инпут на GPU, где находится наша модель
eos_id = tokenizer.encode(eos_token)[0]
sample_outputs = model.generate(
input_ids,
max_length=max_length,
do_sample=True,
top_k=top_k,
top_p=top_p,
temperature=temperature,
eos_token_id=eos_id,
num_return_sequences=num_return_sequences
)
for i, sample_output in enumerate(sample_outputs):
answers.append(tokenizer.decode(sample_output, skip_special_tokens=True).replace('\n', ' '))
return answers
def main():
st.title("🤖 JokerAI")
st.write("""---""")
with st.sidebar.expander("ℹ️ - О приложении", expanded=True):
st.write(
"""
- *JokerAI* стремится сочинять стендап с помощью нейросетей. И это не шутки!
- Модель была натренирована на корпусе русскоязычных шуток и обучена при помощи архитекуры ruGPT3-large
"""
)
# * *temperature* — параметр сглаживания; чем выше, тем сильнее сглаживание вероятностного распределения токенов при предсказании
# * *top_k* — техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов после k-го токена
# * *top_p* — техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов, как только суммарная вероятность предыдущих токенов превысит p
# * *max_length* — максимальная длина генерируемого текста
# * *repetition_penalty* — «штрафование» слов, которые уже были сгенерированы или относятся к исходной фразе
# * *num_return_sequences* - количество вариантов последовательностей, которые вернёт модель
col1, col2, col3 = st.columns(3)
with col1:
temperature = st.number_input(
"Выберите параметр temperature", min_value=0.0, max_value=1.0, value=0.75, step=0.01,
help='Параметр сглаживания; чем выше, тем сильнее сглаживание вероятностного распределения токенов при предсказании'
)
max_length = st.number_input(
'Выберите параметр max_length', min_value=32, max_value=64, value=40, step=1,
help='Максимальная длина генерируемого текста'
)
with col2:
top_p = st.number_input(
"Выберите параметр top_p", min_value=0.0, max_value=1.0, value=0.92, step=0.01,
help='Техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов, как только суммарная вероятность предыдущих токенов превысит p'
)
top_k = st.number_input(
"Выберите параметр top_k", min_value=0, max_value=100, value=50, step=1,
help='Техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов после k-го токена'
)
with col3:
num_return_sequences = st.number_input(
'Выберите параметр num_return_sequences', min_value=1, max_value=4, value=3, step=1,
help='Количество вариантов последовательностей, которые вернёт модель'
)
st.write("""---""")
a, b = st.columns([4, 1])
user_input = a.text_input(
label="Напишите затравку для шутки или скетча...",
placeholder="Напишите затравку для шутки или скетча...",
label_visibility="collapsed",
value='Зашла улитка в бар'
)
button = b.button("Отправить", use_container_width=True)
if user_input == '':
user_input = '<joke>'
if button:
answers = generate(user_input, temperature, max_length, top_k, top_p, num_return_sequences)
# st.write(answers)
for answer in answers:
st.info(answer)
if __name__ == '__main__':
_max_width_()
main()