File size: 3,684 Bytes
d288960
770f950
 
0154f58
e7988f7
a529397
b3c5de3
a4ba045
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dea0b2e
c609ed6
dea0b2e
a4ba045
 
 
 
 
 
770f950
c609ed6
 
 
 
 
 
 
22beef3
 
 
 
 
 
 
 
 
 
 
 
 
 
74c4804
22beef3
 
c609ed6
 
 
 
 
 
22beef3
 
c609ed6
a529397
 
 
 
 
 
22beef3
 
c609ed6
22beef3
 
c609ed6
 
22beef3
c609ed6
 
770f950
e7988f7
22beef3
 
 
 
 
 
 
c609ed6
 
22beef3
 
c609ed6
22beef3
 
c609ed6
22beef3
a4ba045
22beef3
c609ed6
 
22beef3
 
 
c609ed6
e7988f7
770f950
22beef3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from my_chess.scripts.scripts import HumanVsBot
from my_chess.learner.environments import Chess

from chessmodels import DCMinMax
import streamlit as st
from s_i_c import streamlit_image_coordinates

import torch

class RandomPlayer(torch.nn.Module):
    """Baseline bot that plays a uniformly random legal action.

    It has no parameters, so ``torch.nn.Module``'s own constructor is all
    that is needed.
    """

    def forward(self, board, input):
        """Pick one legal action index at random.

        Args:
            board: Ignored. It is accepted so the call matches the other bots.
            input: Mapping whose ``'action_mask'`` entry exposes ``.size`` and
                ``.astype`` (a NumPy-style array, presumably 1-D). Non-zero
                entries mark legal actions.

        Returns:
            The chosen action index as a Python ``int``.
        """
        mask = input['action_mask']
        all_actions = torch.arange(mask.size)
        legal_actions = all_actions[mask.astype(bool)]
        # One draw from torch's global RNG, in [0, number of legal actions).
        drawn = torch.randint(legal_actions.numel(), (1,))
        return legal_actions[drawn].item()

# Bot names offered in the UI selectbox, mapped to the classes that build them.
# (The identifier's spelling is kept as-is because other code refers to it.)
BOT_SELECITONS = dict(
    RandomMover=RandomPlayer,
    DeepChessReplicationMinMax=DCMinMax,
)

@st.cache_resource(max_entries=4)
def load_model(bot_select):
    """Construct the bot named ``bot_select``.

    The result is memoised by ``st.cache_resource`` (at most 4 entries), so
    the same model object is reused across reruns and sessions.

    "RandomMover" is instantiated directly. Every other entry is loaded via
    its class's ``from_pretrained`` with the id ``"mzimm003/<bot_select>"``
    (presumably a Hugging Face Hub repo id). A name missing from
    ``BOT_SELECITONS`` raises ``KeyError``.
    """
    bot_cls = BOT_SELECITONS[bot_select]
    if bot_select == "RandomMover":
        return bot_cls()
    return bot_cls.from_pretrained(f"mzimm003/{bot_select}")

class TrackPlay:
    """Per-session UI state: whether a game is running and which bot was picked."""

    def __init__(self):
        # A fresh tracker starts in the same state a reset produces.
        self.reset()

    def reset(self):
        """Go back to the bot-selection screen: no active game, no bot chosen."""
        self.active_game, self.bot_select = False, None

def information(play, t_p):
    """Return the headline status message for the current game.

    Args:
        play: The running game. Only ``get_human_player``, ``get_curr_player``,
            ``is_done`` and ``get_result`` are called on it.
        t_p: Unused. It is kept so existing call sites keep working.

    Returns:
        While the game is running: "Your turn!" when the current player is the
        human, otherwise "Bot is thinking...". Once it is over: the winner or
        draw message for results "1-0", "0-1" and "1/2-1/2". Any other result
        string gives "Game over!" instead of raising ``KeyError``.
    """
    outcome_messages = {
        "1-0": "White Wins!",
        "0-1": "Black Wins!",
        "1/2-1/2": "It's a Draw!",
    }
    hum_turn = play.get_human_player() == play.get_curr_player()
    if play.is_done():
        # NOTE(review): results presumably use PGN notation, where "*" means
        # undecided. Fall back to a generic message rather than crash the page
        # on an unexpected value.
        return outcome_messages.get(play.get_result(), "Game over!")
    return "Your turn!" if hum_turn else "Bot is thinking..."

def reset(t_p):
    """Reset-button callback: clear the tracker and discard the current game.

    Args:
        t_p: The session's ``TrackPlay``. Its ``reset()`` returns the UI to
            bot selection.

    Deleting ``st.session_state["play"]`` makes the next game start fresh.
    """
    state = st.session_state
    t_p.reset()
    del state["play"]

def play(bot_select, t_p):
    """Render the game screen, then let the game advance.

    On the first call in a session, a ``HumanVsBot`` game against the model
    for ``bot_select`` is created and stored in ``st.session_state.play``.
    Later reruns reuse that stored game.

    Args:
        bot_select: Key of ``BOT_SELECITONS`` naming the opponent. It is only
            used when the game is first created.
        t_p: The session's ``TrackPlay``. It is passed to ``information`` and
            to the Reset button's callback.
    """
    if "play" not in st.session_state:
        st.session_state.play = HumanVsBot(
            model=load_model(bot_select=bot_select),
            environment=Chess(render_mode="rgb_array"),
            # Extra context for the model built from the environment.
            # Presumably this is what fills the ``board`` argument of the
            # bots' forward().
            extra_model_environment_context=lambda env: {"board":env.board},
        )
    col1, col2 = st.columns([3,1])
    with col1:
        st.markdown("# {}".format(information(st.session_state.play, t_p)))
        # The rendered board doubles as the input widget. Its click-and-drag
        # result is stored in session_state["board"].
        # NOTE(review): presumably HumanVsBot.run() reads it to build the
        # human's move; confirm.
        st.session_state["board"] = streamlit_image_coordinates(
            st.session_state.play.render_board(),
            key="brd",
            use_column_width=True,
            click_and_drag=True)
    with col2:
        st.button("Reset", on_click=reset, args=[t_p])
        st.markdown("#### You are playing as {}!".format("white" if st.session_state.play.get_human_player() == "player_0" else "black"))
        st.markdown("""
                    1. Play by dragging the piece you want to move.
                    2. The piece won't move as you drag, but the move will be registered upon release.
                    3. Illegal moves will not be registered.
                    """)
    # Advance the game unless it has finished.
    # NOTE(review): what a single run() call does isn't visible here.
    if not st.session_state.play.is_done():
        st.session_state.play.run()

def main(kwargs=None):
    """Streamlit page entry point.

    Widens the main content area, then shows one of two screens. Without an
    active game it shows the bot-selection screen. With one it shows the game
    screen rendered by ``play``.

    Args:
        kwargs: Unused. It is kept for callers that pass it.
    """
    css='''
    <style>
        section.main > div {max-width:75rem}
    </style>
    '''
    st.markdown(css, unsafe_allow_html=True)

    state = st.session_state
    if "t_p" not in state:
        state.t_p = TrackPlay()
    tracker = state.t_p
    placeholder = st.empty()

    if tracker.active_game:
        with placeholder.container():
            play(tracker.bot_select, tracker)
        return

    with placeholder.container():
        tracker.bot_select = st.selectbox(
            label="bot",
            options=list(BOT_SELECITONS),
        )
        tracker.active_game = st.button("Play bot!")
        # Rerun right away so the game screen replaces the selection screen.
        if tracker.active_game:
            st.rerun()

if __name__ == "__main__":
    # Render the app when this file is executed as the main script.
    main(kwargs=None)