|
import streamlit as st |
|
from Zmaker import Zmaker |
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
with st.spinner(text = "loading GPT-2..."): |
|
if not ("AI" in st.session_state.keys()): |
|
st.session_state["AI"] = Zmaker( |
|
ft_path = "./models/" |
|
) |
|
|
|
|
|
with st.sidebar: |
|
st.title("GPT-2のパラメータ") |
|
|
|
|
|
sld_max_len = st.sidebar.slider( |
|
"length of the sentence", min_value = 0, max_value = 256, |
|
value = (25, 75), step = 1, key = "length" |
|
) |
|
|
|
|
|
sld_temp = st.sidebar.slider( |
|
"temperature", min_value = 0.1, max_value = 1.5, |
|
value = 0.1, step = 0.1, key = "temp" |
|
) |
|
|
|
|
|
sld_top_k = st.sidebar.slider( |
|
"top_k", min_value = 0, max_value = 500, |
|
value = 40, step = 1, key = "top_k" |
|
) |
|
|
|
|
|
sld_top_p = st.sidebar.slider( |
|
"top_p", min_value = 0.01, max_value = 1.0, |
|
value = 0.95, step = 0.01, key = "top_p" |
|
) |
|
|
|
|
|
sld_top_p = st.sidebar.slider( |
|
"repeat_ngram_size ", min_value = 1, max_value = 10, |
|
value = 1, step = 1, key = "repeat_ngram_size" |
|
) |
|
|
|
|
|
with st.form(key = "Letter Form", clear_on_submit = False): |
|
st.title("おてがみ 入力欄") |
|
body = st.empty() |
|
if ("letter_body" in st.session_state.keys()): |
|
ret = body.text_area( |
|
label = "お手紙を途中まで漢字+ひらがなで書いてください。続きをAIが生成します。\n"\ |
|
"本アプリで生成できるのは本文のみです。", |
|
value = st.session_state["letter_body"] |
|
) |
|
else: |
|
ret = body.text_area( |
|
label = "お手紙を途中まで漢字+ひらがなで書いてください。\n"\ |
|
"続きをAIが生成します。", |
|
value = "ズッポシ村へようこそ!" |
|
) |
|
sub = st.form_submit_button("Generate") |
|
|
|
|
|
with st.expander("注意事項"): |
|
st.text( |
|
"※このAIは「どうぶつの森e+実況プレイ」"\ |
|
" (https://www.nicovideo.jp/mylist/45062007)において"\ |
|
" 稲葉百万鉄氏により作成された文章を学習データに用いております。\n" |
|
" また,教師データの作成においてmintmama氏の作成した"\ |
|
" 「ズッポシむら手紙集」(https://www.nicovideo.jp/series/85494)\n"\ |
|
"を用いております。" |
|
) |
|
|
|
|
|
|
|
if sub == True: |
|
|
|
st.session_state["AI"].min_len = st.session_state["length"][0] |
|
st.session_state["AI"].max_len = st.session_state["length"][-1] |
|
st.session_state["AI"].top_k = st.session_state["top_k"] |
|
st.session_state["AI"].top_p = st.session_state["top_p"] |
|
st.session_state["AI"].temp = st.session_state["temp"] |
|
st.session_state["AI"].repeat_ngram_size = st.session_state["repeat_ngram_size"] |
|
|
|
|
|
with st.spinner(text = "generating..."): |
|
prompt = ret |
|
text = str(st.session_state["AI"].GenLetter("<s>"+prompt)[0]) |
|
text = text.replace('<s>', '') |
|
text = text.replace('</s>', '') |
|
st.session_state["letter_body"] = text |
|
st.experimental_rerun() |