Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas | |
import torch | |
from transformers import ( | |
GPT2Tokenizer, GPT2LMHeadModel, TextDataset, LineByLineTextDataset, | |
DataCollatorForLanguageModeling, TrainingArguments, Trainer, get_cosine_schedule_with_warmup | |
) | |
# output_dir = '/dokster/jokerai/model' | |
bos_token = '<joke>' | |
eos_token = '<end>' | |
model = GPT2LMHeadModel.from_pretrained('./model') | |
tokenizer = GPT2Tokenizer.from_pretrained('./model') | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
model.to(device) | |
st.set_page_config( | |
page_title="JokerAI", page_icon="🎈", layout="centered" | |
) | |
def _max_width_(): | |
max_width_str = f"max-width: 1400px;" | |
st.markdown( | |
f""" | |
<style> | |
.reportview-container .main .block-container{{ | |
{max_width_str} | |
}} | |
</style> | |
""", | |
unsafe_allow_html=True, | |
) | |
def generate(prompt, temperature, max_length, top_k, top_p, num_return_sequences): | |
model.eval() | |
answers = [] | |
input_ids = tokenizer.encode(prompt, return_tensors='pt') | |
input_ids = input_ids.to(device) #переносим инпут на GPU, где находится наша модель | |
eos_id = tokenizer.encode(eos_token)[0] | |
sample_outputs = model.generate( | |
input_ids, | |
max_length=max_length, | |
do_sample=True, | |
top_k=top_k, | |
top_p=top_p, | |
temperature=temperature, | |
eos_token_id=eos_id, | |
num_return_sequences=num_return_sequences | |
) | |
for i, sample_output in enumerate(sample_outputs): | |
answers.append(tokenizer.decode(sample_output, skip_special_tokens=True).replace('\n', ' ')) | |
return answers | |
def main(): | |
st.title("🤖 JokerAI") | |
st.write("""---""") | |
with st.sidebar.expander("ℹ️ - О приложении", expanded=True): | |
st.write( | |
""" | |
- *JokerAI* стремится сочинять стендап с помощью нейросетей. И это не шутки! | |
- Модель была натренирована на корпусе русскоязычных шуток и обучена при помощи архитекуры ruGPT3-large | |
""" | |
) | |
# * *temperature* — параметр сглаживания; чем выше, тем сильнее сглаживание вероятностного распределения токенов при предсказании | |
# * *top_k* — техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов после k-го токена | |
# * *top_p* — техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов, как только суммарная вероятность предыдущих токенов превысит p | |
# * *max_length* — максимальная длина генерируемого текста | |
# * *repetition_penalty* — «штрафование» слов, которые уже были сгенерированы или относятся к исходной фразе | |
# * *num_return_sequences* - количество вариантов последовательностей, которые вернёт модель | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
temperature = st.number_input( | |
"Выберите параметр temperature", min_value=0.0, max_value=1.0, value=0.75, step=0.01, | |
help='Параметр сглаживания; чем выше, тем сильнее сглаживание вероятностного распределения токенов при предсказании' | |
) | |
max_length = st.number_input( | |
'Выберите параметр max_length', min_value=32, max_value=64, value=40, step=1, | |
help='Максимальная длина генерируемого текста' | |
) | |
with col2: | |
top_p = st.number_input( | |
"Выберите параметр top_p", min_value=0.0, max_value=1.0, value=0.92, step=0.01, | |
help='Техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов, как только суммарная вероятность предыдущих токенов превысит p' | |
) | |
top_k = st.number_input( | |
"Выберите параметр top_k", min_value=0, max_value=100, value=50, step=1, | |
help='Техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов после k-го токена' | |
) | |
with col3: | |
num_return_sequences = st.number_input( | |
'Выберите параметр num_return_sequences', min_value=1, max_value=4, value=3, step=1, | |
help='Количество вариантов последовательностей, которые вернёт модель' | |
) | |
st.write("""---""") | |
a, b = st.columns([4, 1]) | |
user_input = a.text_input( | |
label="Напишите затравку для шутки или скетча...", | |
placeholder="Напишите затравку для шутки или скетча...", | |
label_visibility="collapsed", | |
value='Зашла улитка в бар' | |
) | |
button = b.button("Отправить", use_container_width=True) | |
if user_input == '': | |
user_input = '<joke>' | |
if button: | |
answers = generate(user_input, temperature, max_length, top_k, top_p, num_return_sequences) | |
# st.write(answers) | |
for answer in answers: | |
st.info(answer) | |
if __name__ == '__main__': | |
_max_width_() | |
main() |