Spaces:
Runtime error
Runtime error
Upload main.py
Browse files
main.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas
|
3 |
+
|
4 |
+
st.set_page_config(
|
5 |
+
page_title="JokerAI", page_icon="🎈", layout="centered"
|
6 |
+
)
|
7 |
+
|
8 |
+
def _max_width_():
|
9 |
+
max_width_str = f"max-width: 1400px;"
|
10 |
+
st.markdown(
|
11 |
+
f"""
|
12 |
+
<style>
|
13 |
+
.reportview-container .main .block-container{{
|
14 |
+
{max_width_str}
|
15 |
+
}}
|
16 |
+
</style>
|
17 |
+
""",
|
18 |
+
unsafe_allow_html=True,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
def main():
|
23 |
+
st.title("🤖 JokerAI")
|
24 |
+
|
25 |
+
st.write("""---""")
|
26 |
+
|
27 |
+
with st.sidebar.expander("ℹ️ - О приложении", expanded=True):
|
28 |
+
st.write(
|
29 |
+
"""
|
30 |
+
- *JokerAI* стремится сочинять стендап с помощью нейросетей. И это не шутки!
|
31 |
+
- Модель была натренирована на корпусе русскоязычных шуток и обучена при помощи архитекуры ruGPT3-large
|
32 |
+
"""
|
33 |
+
)
|
34 |
+
|
35 |
+
# * *temperature* — параметр сглаживания; чем выше, тем сильнее сглаживание вероятностного распределения токенов при предсказании
|
36 |
+
# * *top_k* — техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов после k-го токена
|
37 |
+
# * *top_p* — техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов, как только суммарная вероятность предыдущих токенов превысит p
|
38 |
+
# * *max_length* — максимальная длина генерируемого текста
|
39 |
+
# * *repetition_penalty* — «штрафование» слов, которые уже были сгенерированы или относятся к исходной фразе
|
40 |
+
# * *num_return_sequences* - количество вариантов последовательностей, которые вернёт модель
|
41 |
+
|
42 |
+
|
43 |
+
col1, col2, col3 = st.columns(3)
|
44 |
+
with col1:
|
45 |
+
temperature = st.number_input(
|
46 |
+
"Выберите параметр temperature", min_value=0.0, max_value=1.0, value=0.75, step=0.01,
|
47 |
+
help='Параметр сглаживания; чем выше, тем сильнее сглаживание вероятностного распределения токенов при предсказании'
|
48 |
+
)
|
49 |
+
max_length = st.number_input(
|
50 |
+
'Выберите параметр max_length', min_value=16, max_value=128, value=120, step=1,
|
51 |
+
help='Максимальная длина генерируемого текста'
|
52 |
+
)
|
53 |
+
with col2:
|
54 |
+
top_p = st.number_input(
|
55 |
+
"Выберите параметр top_p", min_value=0.0, max_value=1.0, value=0.92, step=0.01,
|
56 |
+
help='Техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов, как только суммарная вероятность предыдущих токенов превысит p'
|
57 |
+
)
|
58 |
+
# repeatition_penalty = st.number_input(
|
59 |
+
# "Выберите параметр repeatition_penalty", min_value=-30, max_value=30, step=1, value=0,
|
60 |
+
# help='«Штрафование» слов, которые уже были сгенерированы или относятся к исходной фразе'
|
61 |
+
# )
|
62 |
+
top_k = st.number_input(
|
63 |
+
"Выберите параметр top_k", min_value=0, max_value=100, value=50, step=1,
|
64 |
+
help='техника сэмплирования: сортировка предсказаний каждого следующего слова по вероятностям и отсекание вариантов после k-го токена'
|
65 |
+
)
|
66 |
+
with col3:
|
67 |
+
num_return_sequences = st.number_input(
|
68 |
+
'Выберите параметр num_return_sequences', min_value=0, max_value=7, value=3, step=1,
|
69 |
+
help='Количество вариантов последовательностей, которые вернёт модель'
|
70 |
+
)
|
71 |
+
|
72 |
+
st.write("""---""")
|
73 |
+
|
74 |
+
a, b = st.columns([4, 1])
|
75 |
+
user_input = a.text_input(
|
76 |
+
label="Your message:",
|
77 |
+
placeholder="Напишите затравку для шутки или скетча...",
|
78 |
+
label_visibility="collapsed",
|
79 |
+
)
|
80 |
+
button = b.button("Отправить", use_container_width=True)
|
81 |
+
|
82 |
+
if button:
|
83 |
+
st.write('serverstate message')
|
84 |
+
|
85 |
+
if __name__ == '__main__':
|
86 |
+
_max_width_()
|
87 |
+
main()
|