import streamlit as st import pandas import torch from transformers import ( GPT2Tokenizer, GPT2LMHeadModel, TextDataset, LineByLineTextDataset, DataCollatorForLanguageModeling, TrainingArguments, Trainer, get_cosine_schedule_with_warmup ) output_dir = '/model' model = GPT2LMHeadModel.from_pretrained(output_dir) tokenizer = GPT2Tokenizer.from_pretrained(output_dir) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.to(device) st.set_page_config( page_title="JokerAI", page_icon="🎈", layout="centered" ) def _max_width_(): max_width_str = f"max-width: 1400px;" st.markdown( f""" """, unsafe_allow_html=True, ) def generate(prompt, temperature, max_length, top_k, top_p, num_return_sequences): model.eval() answers = [] input_ids = tokenizer.encode(prompt, return_tensors='pt') input_ids = input_ids.to(device) #переносим инпут на GPU, где находится наша модель eos_id = tokenizer.encode(eos_token)[0] sample_outputs = model.generate( input_ids, max_length=max_length, do_sample=True, top_k=top_k, top_p=top_p, temperature=temperature, eos_token_id=eos_id, num_return_sequences=num_return_sequences ) for i, sample_output in enumerate(sample_outputs): answers.append("{}".format(tokenizer.decode(sample_output, skip_special_tokens=True))) return answers def main(): st.title("🤖 JokerAI") st.write("""---""") with st.sidebar.expander("ℹ️ - О приложении", expanded=True): st.write( """ - *JokerAI* стремится сочинять стендап с помощью нейросетей. И это не шутки! - Модель была натренирована на корпусе русскоязычных шуток и обучена при помощи архитекуры ruGPT3-large """ ) # * *temperature* — параметр сглаживания; чем выше, тем сильнее сглаживание вероятностного распределения токенов при предсказании # * *top_k* — техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов после k-го токена # * *top_p* — техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов, как только суммарная вероятность предыдущих токенов превысит p # * *max_length* — максимальная длина генерируемого текста # * *repetition_penalty* — «штрафование» слов, которые уже были сгенерированы или относятся к исходной фразе # * *num_return_sequences* - количество вариантов последовательностей, которые вернёт модель col1, col2, col3 = st.columns(3) with col1: temperature = st.number_input( "Выберите параметр temperature", min_value=0.0, max_value=1.0, value=0.75, step=0.01, help='Параметр сглаживания; чем выше, тем сильнее сглаживание вероятностного распределения токенов при предсказании' ) max_length = st.number_input( 'Выберите параметр max_length', min_value=16, max_value=128, value=120, step=1, help='Максимальная длина генерируемого текста' ) with col2: top_p = st.number_input( "Выберите параметр top_p", min_value=0.0, max_value=1.0, value=0.92, step=0.01, help='Техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов, как только суммарная вероятность предыдущих токенов превысит p' ) top_k = st.number_input( "Выберите параметр top_k", min_value=0, max_value=100, value=50, step=1, help='техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов после k-го токена' ) with col3: num_return_sequences = st.number_input( 'Выберите параметр num_return_sequences', min_value=0, max_value=7, value=3, step=1, help='Количество вариантов последовательностей, которые вернёт модель' ) st.write("""---""") a, b = st.columns([4, 1]) user_input = a.text_input( label="Your message:", placeholder="Напишите затравку для шутки или скетча...", label_visibility="collapsed", value='' ) button = b.button("Отправить", use_container_width=True) if button: answers = generate(user_input, temperature, num_return_sequences, top_k, top_p, max_length) for answer in answers: st.write(answer) if __name__ == '__main__': _max_width_() main()