File size: 5,759 Bytes
b609ad3
6977ab7
6482a95
eb576ee
7cbbb3d
b609ad3
7cbbb3d
b0c0555
 
 
7cbbb3d
 
 
 
a09694a
f4a6637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8458e2d
fadbf0b
a09694a
09bf3c9
1bb26c5
 
78208b9
3cab9f7
41169e8
e45098e
 
78208b9
05d24d0
 
ee28d22
05d24d0
 
8fec5fc
c7c1267
d3a1231
 
 
 
 
 
c7c1267
8fec5fc
d3a1231
 
c7c1267
e45098e
 
9f2284d
ac6b9a8
 
 
 
5ce7c60
ac6b9a8
467b498
5884829
a93f461
d3a1231
5705b27
d3a1231
e45098e
 
 
5705b27
e45098e
9e6846d
5705b27
e45098e
 
5705b27
 
e55e238
 
 
 
 
 
 
 
d6bf0b8
ff73010
4a92690
ff73010
467b498
 
 
5e1c7d5
ff73010
467b498
 
ac6b9a8
 
 
d6bf0b8
467b498
ff73010
467b498
ff73010
b713ddf
4076728
5e1c7d5
 
 
 
7683f73
d3a1231
f5909c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import streamlit as st
import cshogi
from IPython.display import display
from transformers import T5ForConditionalGeneration, T5Tokenizer
import pandas as pd

# Model loading
def _load_model_and_tokenizer():
    """Load the tokenizer and the T5 model, and put the model in eval mode.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): `is_fast=True` is kept exactly as in the original call;
    # confirm it is an option the tokenizer actually recognises.
    loaded_tokenizer = T5Tokenizer.from_pretrained("pizzagatakasugi/shogi_t5", is_fast=True)
    loaded_model = T5ForConditionalGeneration.from_pretrained("pizzagatakasugi/shogi_t5")
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


# Streamlit re-executes this script on every widget interaction. Without
# caching, the model and tokenizer would be reloaded on each rerun, so wrap the
# loader in Streamlit's resource cache when the installed version offers one
# (`cache_resource`, or `experimental_singleton` on older releases).
_cache_resource = getattr(st, "cache_resource", None) or getattr(st, "experimental_singleton", None)
if _cache_resource is not None:
    _load_model_and_tokenizer = _cache_resource(_load_model_and_tokenizer)

tokenizer, model = _load_model_and_tokenizer()

st.title("将棋解説文の自動生成")
# Game records; the script later selects one row by the number typed below.
df = pd.read_csv("./dataset10.csv")
num = st.text_input("0から9の数字を入力")

# Square labels written as an ASCII digit followed by a kanji numeral
# ('1一' ... '9九'). The 81 entries are ordered by digit 1-9 first, then by
# kanji numeral 一-九, which keeps them index-aligned with
# KIFU_FROM_SQUARE_NAMES.
KIFU_TO_SQUARE_NAMES = [
    str(column) + kanji_row
    for column in range(1, 10)
    for kanji_row in "一二三四五六七八九"
]
# Square labels written as two ASCII digits ('11' ... '99'). The 81 entries are
# ordered by first digit 1-9, then by second digit 1-9, which keeps them
# index-aligned with KIFU_TO_SQUARE_NAMES.
KIFU_FROM_SQUARE_NAMES = [
    str(column) + str(row)
    for column in range(1, 10)
    for row in range(1, 10)
]

# Record entries that end the game instead of moving a piece. They are kept in
# the move list but are not offered in the move selector.
_TERMINAL_MOVES = ("投了", "千日手", "持将棋", "反則勝ち")

# Two-digit square label ('76') -> digit + kanji label ('7六').
_SQUARE_LOOKUP = dict(zip(KIFU_FROM_SQUARE_NAMES, KIFU_TO_SQUARE_NAMES))


def _normalize_move(move):
    """Normalize the square prefix and promoted-piece names of one move.

    The first two characters are translated through ``_SQUARE_LOOKUP``. When
    they are not a known square the text is left as it is, instead of silently
    reusing the square of a previous move. Promoted-piece names are then
    replaced by their single-character forms.
    """
    square = _SQUARE_LOOKUP.get(move[:2])
    word = move if square is None else square + move[2:]
    return (
        word.replace("竜", "龍")
        .replace("成銀", "全")
        .replace("成桂", "圭")
        .replace("成香", "杏")
    )


def _parse_record(lines):
    """Split the lines of one game record into labels, positions and moves.

    The first two lines are skipped here; the caller uses them verbatim as the
    head of the model input. Lines longer than 30 characters are collected as
    position strings (later passed to ``cshogi.Board(sfen=...)``). Every other
    line is read as ``"<move number> <move>"``; lines with fewer than two
    whitespace-separated fields are ignored.

    Returns:
        ``(labels, positions, moves)`` where ``labels`` are the
        ``"<move number> <move>"`` strings offered in the selector (terminal
        entries excluded), ``positions`` are the position strings, and
        ``moves`` are the normalized moves including terminal entries.
    """
    labels, positions, moves = [], [], []
    for line in lines[2:]:
        if len(line) > 30:
            positions.append(line)
            continue
        fields = line.split()
        if len(fields) < 2:
            # Blank or malformed line: nothing to parse.
            continue
        move_number, move = fields[0], fields[1]
        if move in _TERMINAL_MOVES:
            moves.append(move)
            continue
        word = _normalize_move(move)
        labels.append(move_number + " " + word)
        moves.append(word)
    return labels, positions, moves


def _join_without_parens(moves):
    """Concatenate *moves*, keeping only the text before the first "(" of each."""
    return "".join(move.split("(")[0] for move in moves)


# Only the single digits "0".."9" select a record; surrounding whitespace in
# the text box is ignored.
record_index = num.strip()
if record_index in {str(digit) for digit in range(10)}:
    record = df.iloc[int(record_index)]
    st.write(record["game_type"], record["precedence_name"], record["follower_name"])
    sfen = record["sfen"].split("\n")
    # NOTE(review): eval() executes arbitrary code found in the CSV cells. This
    # is only acceptable while dataset10.csv is a trusted, locally shipped
    # file; if the cells are plain list literals, ast.literal_eval is the safe
    # replacement.
    bestlist = eval(record["bestlist"])
    best2list = eval(record["best2list"])

    # Text normalization
    te, te_sf, movelist = _parse_record(sfen)

    # Board display
    s = st.selectbox(label="手数を選択", options=te)

    with st.expander("parameter"):
        temperature = st.slider("temperature", min_value=0.0, max_value=1.0, step=0.01, value=0.3, key=1)
        beams = st.slider("num_beams", min_value=1, max_value=5, step=1, value=1, key=2)
        tokens = st.slider("min_new_tokens", min_value=0, max_value=50, value=20, key=3)

    reload = st.button('盤面生成', key=0)
    if s in te and reload:
        idx = te.index(s)
        # Assumes te_sf[0] is the position before the first move, so that
        # te_sf[idx + 1] is the position after the selected move - TODO confirm
        # against the dataset.
        board = cshogi.Board(sfen=te_sf[idx + 1])
        st.markdown(board.to_svg(), unsafe_allow_html=True)

        # Build the model input: moves up to and including the selected one,
        # followed by the two predicted lines stored for that move.
        kifs = _join_without_parens(movelist[:idx + 1])
        best = _join_without_parens(bestlist[idx])
        best2 = _join_without_parens(best2list[idx])

        with st.spinner("推論中です..."):
            input_text = sfen[0] + sfen[1] + kifs + "最善手の予測手順は" + best + "次善手の予測手順は" + best2
            tokenized_inputs = tokenizer.encode(
                input_text, max_length=512, truncation=True,
                padding="max_length", return_tensors="pt"
            )

            generation_options = {
                "max_length": 512,
                "repetition_penalty": 10.0,  # penalty for repeating the same sentence
                "num_beams": beams,
                "min_new_tokens": tokens,
            }
            # Temperature only influences generation when sampling is enabled,
            # and it must be strictly positive; a slider value of 0 therefore
            # falls back to deterministic (greedy / beam) decoding.
            if temperature > 0.0:
                generation_options["do_sample"] = True
                generation_options["temperature"] = temperature

            output_ids = model.generate(input_ids=tokenized_inputs, **generation_options)

            output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True,
                                           clean_up_tokenization_spaces=False)
            st.write(output_text)