Spaces:
Runtime error
Runtime error
import streamlit as st | |
from Zmaker import Zmaker | |
if __name__ == "__main__": | |
#ファインチューニング済みモデルの読み込み | |
with st.spinner(text = "loading GPT-2..."): | |
if not ("AI" in st.session_state.keys()): | |
st.session_state["AI"] = Zmaker( | |
ft_path = "./models/" | |
) | |
#設定用サイドバーの設定 | |
with st.sidebar: | |
st.title("GPT-2のパラメータ") | |
#max_lenの設定用スライダ | |
sld_max_len = st.sidebar.slider( | |
"length of the sentence", min_value = 0, max_value = 256, | |
value = (25, 75), step = 1, key = "length" | |
) | |
#temperatureの設定用スライダ | |
sld_temp = st.sidebar.slider( | |
"temperature", min_value = 0.1, max_value = 1.5, | |
value = 0.1, step = 0.1, key = "temp" | |
) | |
#top_kの設定用スライダ | |
sld_top_k = st.sidebar.slider( | |
"top_k", min_value = 0, max_value = 500, | |
value = 40, step = 1, key = "top_k" | |
) | |
#top_pの設定用スライダ | |
sld_top_p = st.sidebar.slider( | |
"top_p", min_value = 0.01, max_value = 1.0, | |
value = 0.95, step = 0.01, key = "top_p" | |
) | |
#repeat_ngram_sizeの設定用スライダ | |
sld_top_p = st.sidebar.slider( | |
"repeat_ngram_size ", min_value = 1, max_value = 10, | |
value = 1, step = 1, key = "repeat_ngram_size" | |
) | |
#メインフォームの設定 | |
with st.form(key = "Letter Form", clear_on_submit = False): | |
st.title("おてがみ 入力欄") | |
body = st.empty() | |
if ("letter_body" in st.session_state.keys()): | |
ret = body.text_area( | |
label = "お手紙を途中まで漢字+ひらがなで書いてください。続きをAIが生成します。\n"\ | |
"本アプリで生成できるのは本文のみです。", | |
value = st.session_state["letter_body"] | |
) | |
else: | |
ret = body.text_area( | |
label = "お手紙を途中まで漢字+ひらがなで書いてください。\n"\ | |
"続きをAIが生成します。", | |
value = "ズッポシ村へようこそ!" | |
) | |
sub = st.form_submit_button("Generate") | |
#注意事項 | |
with st.expander("注意事項"): | |
st.text( | |
"※このAIは「どうぶつの森e+実況プレイ」"\ | |
" (https://www.nicovideo.jp/mylist/45062007)において"\ | |
" 稲葉百万鉄氏により作成された文章を学習データに用いております。\n" | |
" また,教師データの作成においてmintmama氏の作成した"\ | |
" 「ズッポシむら手紙集」(https://www.nicovideo.jp/series/85494)\n"\ | |
"を用いております。" | |
) | |
#submitボタンが押された | |
if sub == True: | |
#predictに必要な条件をGUIで設定した値に更新 | |
st.session_state["AI"].min_len = st.session_state["length"][0] | |
st.session_state["AI"].max_len = st.session_state["length"][-1] | |
st.session_state["AI"].top_k = st.session_state["top_k"] | |
st.session_state["AI"].top_p = st.session_state["top_p"] | |
st.session_state["AI"].temp = st.session_state["temp"] | |
st.session_state["AI"].repeat_ngram_size = st.session_state["repeat_ngram_size"] | |
#AIによる予測を実行 | |
with st.spinner(text = "generating..."): | |
prompt = ret | |
text = str(st.session_state["AI"].GenLetter("<s>"+prompt)[0]) | |
text = text.replace('<s>', '') | |
text = text.replace('</s>', '') | |
st.session_state["letter_body"] = text | |
st.experimental_rerun() |