Spaces:
Sleeping
Sleeping
File size: 3,684 Bytes
d288960 770f950 0154f58 e7988f7 a529397 b3c5de3 a4ba045 dea0b2e c609ed6 dea0b2e a4ba045 770f950 c609ed6 22beef3 74c4804 22beef3 c609ed6 22beef3 c609ed6 a529397 22beef3 c609ed6 22beef3 c609ed6 22beef3 c609ed6 770f950 e7988f7 22beef3 c609ed6 22beef3 c609ed6 22beef3 c609ed6 22beef3 a4ba045 22beef3 c609ed6 22beef3 c609ed6 e7988f7 770f950 22beef3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from my_chess.scripts.scripts import HumanVsBot
from my_chess.learner.environments import Chess
from chessmodels import DCMinMax
import streamlit as st
from s_i_c import streamlit_image_coordinates
import torch
class RandomPlayer(torch.nn.Module):
    """Bot that plays a randomly chosen legal move.

    The module has no parameters; it only exists so it can be used wherever
    a model object with a ``forward(board, input)`` method is expected.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, board, input):
        """Return the index of one randomly selected permitted action.

        Args:
            board: Unused by this player.
            input: Mapping holding an ``'action_mask'`` entry. The mask must
                expose ``.size`` and ``.astype`` (presumably a NumPy array
                with one flag per action -- confirm against the caller).

        Returns:
            A Python int: the position of a truthy mask entry, drawn with
            ``torch.randint``.
        """
        mask = input['action_mask']
        # Positions of every action the mask marks as allowed.
        legal_actions = torch.arange(mask.size)[mask.astype(bool)]
        # Draw one position inside the list of allowed actions.
        pick = torch.randint(legal_actions.numel(), (1,))
        return legal_actions[pick].item()
# Registry of selectable opponents: display name -> model class.
# The keys are shown in the bot selection box and are also used to build
# the pretrained-model identifier in ``load_model``.
BOT_SELECITONS = {
    "RandomMover": RandomPlayer,
    "DeepChessReplicationMinMax": DCMinMax,
}
@st.cache_resource(max_entries=4)
def load_model(bot_select):
    """Build the bot model registered under ``bot_select``.

    Results are cached by Streamlit's resource cache (at most 4 entries).

    Args:
        bot_select: A key of ``BOT_SELECITONS``.

    Returns:
        For ``"RandomMover"`` a freshly constructed instance; for every other
        key the result of the class's ``from_pretrained`` called with
        ``"mzimm003/<bot_select>"``.

    Raises:
        KeyError: If ``bot_select`` is not a registered bot name.
    """
    bot_cls = BOT_SELECITONS[bot_select]
    if bot_select == "RandomMover":
        # The random bot has no weights to fetch.
        return bot_cls()
    return bot_cls.from_pretrained("mzimm003/{}".format(bot_select))
class TrackPlay:
    """Small state holder for the current session's game.

    Attributes are ``active_game`` (whether a game is in progress) and
    ``bot_select`` (the name of the chosen bot, or ``None``).
    """

    def __init__(self):
        # A new tracker starts in the same state as a reset one.
        self.reset()

    def reset(self):
        """Return to the initial state: no active game, no bot chosen."""
        self.active_game = False
        self.bot_select = None
def information(play, t_p):
    """Return the headline message describing the state of the game.

    Args:
        play: Game object providing ``get_human_player()``,
            ``get_curr_player()``, ``is_done()`` and ``get_result()``.
        t_p: Unused.

    Returns:
        While the game is running, ``"Your turn!"`` if the human is the
        current player and ``"Bot is thinking..."`` otherwise. Once the game
        is done, the message matching the result string.

    Raises:
        KeyError: If the game is done and its result is not one of
            ``"1-0"``, ``"0-1"`` or ``"1/2-1/2"``.
    """
    humans_move = play.get_human_player() == play.get_curr_player()
    if play.is_done():
        outcome_text = {
            "1-0": "White Wins!",
            "0-1": "Black Wins!",
            "1/2-1/2": "It's a Draw!",
        }
        return outcome_text[play.get_result()]
    return "Your turn!" if humans_move else "Bot is thinking..."
def reset(t_p):
    """Abandon the current game.

    Resets the tracker and removes the stored ``"play"`` entry from the
    Streamlit session state. Used as the ``on_click`` callback of the
    "Reset" button.

    Args:
        t_p: Tracker object exposing ``reset()``.
    """
    t_p.reset()
    session = st.session_state
    del session["play"]
def play(bot_select, t_p):
    """Render the game screen and advance the game by one step.

    On first use within a session a ``HumanVsBot`` game is created and kept
    in the session state under ``"play"``. The board image is drawn in the
    wide left column and its click/drag result is stored under ``"board"``;
    the narrow right column holds the reset button and the instructions.

    Args:
        bot_select: Name of the bot to play against (passed to
            ``load_model``).
        t_p: Session tracker, forwarded to ``information`` and ``reset``.
    """
    state = st.session_state
    if "play" not in state:
        state.play = HumanVsBot(
            model=load_model(bot_select=bot_select),
            environment=Chess(render_mode="rgb_array"),
            extra_model_environment_context=lambda env: {"board":env.board},
        )
    game = state.play

    board_col, side_col = st.columns([3,1])
    with board_col:
        headline = information(game, t_p)
        st.markdown("# {}".format(headline))
        state["board"] = streamlit_image_coordinates(
            game.render_board(),
            key="brd",
            use_column_width=True,
            click_and_drag=True)
        print(state.board)
    with side_col:
        st.button("Reset", on_click=reset, args=[t_p])
        colour = "white" if game.get_human_player() == "player_0" else "black"
        st.markdown("#### You are playing as {}!".format(colour))
        st.markdown("""
1. Play by dragging the piece you want to move.
2. The piece won't move as you drag, but the move will be registered upon release.
3. Illegal moves will not be registered.
""")
    # Only keep stepping the game while it is still in progress.
    if not game.is_done():
        game.run()
def main(kwargs=None):
    """Entry point of the Streamlit app.

    Widens the main page section through injected CSS, makes sure a
    ``TrackPlay`` instance lives in the session state, and then shows either
    the bot selection menu or the running game.

    Args:
        kwargs: Unused.
    """
    css='''
<style>
section.main > div {max-width:75rem}
</style>
'''
    st.markdown(css, unsafe_allow_html=True)

    if "t_p" not in st.session_state:
        st.session_state.t_p = TrackPlay()
    tracker = st.session_state.t_p

    placeholder = st.empty()
    if tracker.active_game:
        with placeholder.container():
            play(tracker.bot_select, tracker)
        return

    # No game yet: let the user choose an opponent and start one.
    with placeholder.container():
        tracker.bot_select = st.selectbox(
            label="bot",
            options=list(BOT_SELECITONS.keys()),
        )
        tracker.active_game = st.button("Play bot!")
        if tracker.active_game:
            # Rerun so the game screen is drawn in place of the menu.
            st.rerun()
if __name__ == "__main__":
    # Start the app when this file is executed as a script.
    main()