zupposhi-maker / app.py
John Doe
app.pyのアップデート
0be3887
raw
history blame
3.9 kB
import streamlit as st
from Zmaker import Zmaker
if __name__ == "__main__":
#ファインチューニング済みモデルの読み込み
with st.spinner(text = "loading GPT-2..."):
if not ("AI" in st.session_state.keys()):
st.session_state["AI"] = Zmaker(
ft_path = "./models/"
)
#設定用サイドバーの設定
with st.sidebar:
st.title("GPT-2のパラメータ")
#max_lenの設定用スライダ
sld_max_len = st.sidebar.slider(
"length of the sentence", min_value = 0, max_value = 256,
value = (25, 75), step = 1, key = "length"
)
#temperatureの設定用スライダ
sld_temp = st.sidebar.slider(
"temperature", min_value = 0.1, max_value = 1.5,
value = 0.1, step = 0.1, key = "temp"
)
#top_kの設定用スライダ
sld_top_k = st.sidebar.slider(
"top_k", min_value = 0, max_value = 500,
value = 40, step = 1, key = "top_k"
)
#top_pの設定用スライダ
sld_top_p = st.sidebar.slider(
"top_p", min_value = 0.01, max_value = 1.0,
value = 0.95, step = 0.01, key = "top_p"
)
#repeat_ngram_sizeの設定用スライダ
sld_top_p = st.sidebar.slider(
"repeat_ngram_size ", min_value = 1, max_value = 10,
value = 1, step = 1, key = "repeat_ngram_size"
)
#メインフォームの設定
with st.form(key = "Letter Form", clear_on_submit = False):
st.title("おてがみ 入力欄")
body = st.empty()
if ("letter_body" in st.session_state.keys()):
ret = body.text_area(
label = "お手紙を途中まで漢字+ひらがなで書いてください。続きをAIが生成します。\n"\
"本アプリで生成できるのは本文のみです。",
value = st.session_state["letter_body"]
)
else:
ret = body.text_area(
label = "お手紙を途中まで漢字+ひらがなで書いてください。\n"\
"続きをAIが生成します。",
value = "ズッポシ村へようこそ!"
)
sub = st.form_submit_button("Generate")
#注意事項
with st.expander("注意事項"):
st.text(
"※このAIは「どうぶつの森e+実況プレイ」"\
" (https://www.nicovideo.jp/mylist/45062007)において"\
" 稲葉百万鉄氏により作成された文章を学習データに用いております。\n"
" また,教師データの作成においてmintmama氏の作成した"\
" 「ズッポシむら手紙集」(https://www.nicovideo.jp/series/85494)\n"\
"を用いております。"
)
#submitボタンが押された
if sub == True:
#predictに必要な条件をGUIで設定した値に更新
st.session_state["AI"].min_len = st.session_state["length"][0]
st.session_state["AI"].max_len = st.session_state["length"][-1]
st.session_state["AI"].top_k = st.session_state["top_k"]
st.session_state["AI"].top_p = st.session_state["top_p"]
st.session_state["AI"].temp = st.session_state["temp"]
st.session_state["AI"].repeat_ngram_size = st.session_state["repeat_ngram_size"]
#AIによる予測を実行
with st.spinner(text = "generating..."):
prompt = ret
text = str(st.session_state["AI"].GenLetter("<s>"+prompt)[0])
text = text.replace('<s>', '')
text = text.replace('</s>', '')
st.session_state["letter_body"] = text
st.experimental_rerun()