Spaces:
Sleeping
Sleeping
from my_chess.scripts.scripts import HumanVsBot | |
from my_chess.learner.environments import Chess | |
from chessmodels import DCMinMax | |
import streamlit as st | |
from s_i_c import streamlit_image_coordinates | |
import torch | |
class RandomPlayer(torch.nn.Module):
    """Baseline opponent that plays a uniformly random legal action.

    It has no learnable parameters; ``forward`` only inspects the action mask
    supplied in ``input``.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, board, input):
        """Return the index of a randomly chosen legal action.

        Args:
            board: Unused here. Presumably supplied through HumanVsBot's
                ``extra_model_environment_context`` so every bot shares one
                call signature (TODO confirm).
            input: Mapping with an ``"action_mask"`` entry. This is any
                array-like (numpy array, list, tensor) whose nonzero entries
                mark legal actions. It is flattened before use.

        Returns:
            int: position of the chosen action within the flattened mask.

        Raises:
            ValueError: if the mask marks no action as legal.
        """
        # ``input`` shadows the builtin, but the name is kept because callers
        # may pass it by keyword.
        mask = torch.as_tensor(input["action_mask"]).reshape(-1).to(torch.bool)
        options = mask.nonzero().flatten()
        if options.numel() == 0:
            # torch.randint(0, ...) would fail with an opaque error instead.
            raise ValueError("action_mask contains no legal actions")
        choice = torch.randint(options.numel(), (1,))
        return options[choice].item()
# Opponents offered in the bot picker: display key -> bot class.
BOT_SELECTIONS = {
    "RandomMover": RandomPlayer,
    "DeepChessReplicationMinMax": DCMinMax,
}
# Alias kept under the original (misspelled) name so existing references keep
# working.
BOT_SELECITONS = BOT_SELECTIONS
def load_model(bot_select):
    """Build the bot registered under ``bot_select`` in ``BOT_SELECITONS``.

    ``"RandomMover"`` is constructed directly with no arguments. Every other
    entry is loaded with its class's ``from_pretrained`` using the id
    ``"mzimm003/<bot_select>"`` (presumably a Hugging Face Hub repo — confirm).

    Args:
        bot_select: registry key naming the bot.

    Returns:
        The instantiated model.

    Raises:
        KeyError: if ``bot_select`` is not a registry key.
    """
    bot_cls = BOT_SELECITONS[bot_select]
    if bot_select not in ["RandomMover"]:
        return bot_cls.from_pretrained("mzimm003/{}".format(bot_select))
    return bot_cls()
class TrackPlay:
    """Per-session flags: is a game running, and which bot was picked.

    Attributes:
        active_game: True once the user has pressed "Play bot!".
        bot_select: key of the chosen bot, or None before a choice.
    """

    def __init__(self):
        # A fresh tracker starts in the same state that reset() restores.
        self.reset()

    def reset(self):
        """Return to the "no game, no bot chosen" state."""
        self.active_game = False
        self.bot_select = None
def information(play, t_p):
    """Return the status headline shown above the board.

    Args:
        play: the running game object (the ``HumanVsBot`` kept in session
            state). Its ``is_done``, ``get_result``, ``get_human_player`` and
            ``get_curr_player`` methods are used.
        t_p: unused; kept so existing callers still match the signature.

    Returns:
        str: while the game runs, whose turn it is. Once it is over, the
        outcome for results "1-0", "0-1" and "1/2-1/2", or "Game over!" for
        any other result string.
    """
    if play.is_done():
        outcomes = {
            "1-0": "White Wins!",
            "0-1": "Black Wins!",
            "1/2-1/2": "It's a Draw!",
        }
        # Fall back rather than raise KeyError, so an unexpected result
        # string cannot crash the page.
        return outcomes.get(play.get_result(), "Game over!")
    if play.get_human_player() == play.get_curr_player():
        return "Your turn!"
    return "Bot is thinking..."
def reset(t_p):
    """Abandon the current game; on_click callback of the "Reset" button.

    Clears the session tracker and drops the stored game object, so the next
    rerun shows the bot picker again.

    Args:
        t_p: the session's TrackPlay instance.
    """
    t_p.reset()
    # pop() with a default instead of `del`: a repeated callback, or a reset
    # before any game was stored, must not raise KeyError.
    st.session_state.pop("play", None)
def play(bot_select, t_p):
    """Render one Streamlit pass of a game against the bot ``bot_select``.

    On the first call in a session, a ``HumanVsBot`` game is created and stored
    in ``st.session_state.play``; later reruns reuse it.

    Layout:
        - Left column: the status headline and the board image, which
          accepts drag input.
        - Right column: the Reset button, the human's colour and usage notes.

    Finally, if the game is not finished, the game object's ``run`` is called.

    Args:
        bot_select: registry key naming the opponent; passed to load_model.
        t_p: the session's TrackPlay instance, handed to the Reset callback.
    """
    if "play" not in st.session_state:
        st.session_state.play = HumanVsBot(
            model=load_model(bot_select=bot_select),
            environment=Chess(render_mode="rgb_array"),
            # Extra context for the model: the environment's board object
            # (presumably forwarded as the ``board`` argument — confirm).
            extra_model_environment_context=lambda env: {"board": env.board},
        )
    game = st.session_state.play

    board_col, side_col = st.columns([3, 1])
    with board_col:
        st.markdown("# {}".format(information(game, t_p)))
        # The drag result on the board image is stored under "board" in
        # session state; presumably the game object reads it from there to
        # obtain the human's move (TODO confirm).
        st.session_state["board"] = streamlit_image_coordinates(
            game.render_board(),
            key="brd",
            use_column_width=True,
            click_and_drag=True,
        )
    with side_col:
        st.button("Reset", on_click=reset, args=[t_p])
        side = "white" if game.get_human_player() == "player_0" else "black"
        st.markdown("#### You are playing as {}!".format(side))
        st.markdown(
            "1. Play by dragging the piece you want to move.\n"
            "2. The piece won't move as you drag, but the move will be registered upon release.\n"
            "3. Illegal moves will not be registered.\n"
        )
    if not game.is_done():
        game.run()
def main(kwargs=None):
    """Streamlit entry point: show the bot picker, or the board once a game starts.

    Args:
        kwargs: unused; accepted for callers that pass it.
    """
    # Widen the main content column to 75rem.
    st.markdown(
        """
<style>
section.main > div {max-width:75rem}
</style>
""",
        unsafe_allow_html=True,
    )
    if "t_p" not in st.session_state:
        st.session_state.t_p = TrackPlay()
    tracker = st.session_state.t_p

    placeholder = st.empty()
    if tracker.active_game:
        with placeholder.container():
            play(tracker.bot_select, tracker)
        return

    with placeholder.container():
        tracker.bot_select = st.selectbox(
            label="bot",
            options=list(BOT_SELECITONS),
        )
        tracker.active_game = st.button("Play bot!")
        if tracker.active_game:
            # Rerun right away so the picker is replaced by the game board.
            st.rerun()
if __name__ == "__main__":
    # Render the app when this file is executed as a script (e.g. by
    # `streamlit run`).
    main()