from my_chess.scripts.scripts import HumanVsBot
from my_chess.learner.environments import Chess
from chessmodels import DCMinMax
import streamlit as st
from s_i_c import streamlit_image_coordinates
import torch


class RandomPlayer(torch.nn.Module):
    """Baseline bot that picks uniformly at random among the legal actions."""

    def forward(self, board, input):
        """Return the index of a randomly chosen legal action.

        Args:
            board: Unused; accepted so this bot has the same call signature as
                the other bots (presumably it is the value supplied through
                ``extra_model_environment_context`` — TODO confirm).
            input: Observation mapping containing an ``'action_mask'`` entry.
                The mask appears to be a NumPy array (``.size`` / ``.astype``
                are used) whose truthy positions mark legal actions.

        Returns:
            int: index of the chosen action.

        Raises:
            ValueError: if the mask marks no action as legal.
        """
        mask = input['action_mask']
        options = torch.arange(mask.size)[mask.astype(bool)]
        if options.numel() == 0:
            # torch.randint(0, ...) would fail with an opaque RuntimeError.
            raise ValueError("No legal actions available in 'action_mask'.")
        choice = torch.randint(options.numel(), (1,))
        return options[choice].item()


# Display name -> bot class. Keys are shown in the bot selector and, for
# pretrained bots, also form the Hugging Face repo name under "mzimm003/".
BOT_SELECTIONS = {
    "RandomMover": RandomPlayer,
    "DeepChessReplicationMinMax": DCMinMax,
}
# Backward-compatible alias for the original (misspelled) constant name.
BOT_SELECITONS = BOT_SELECTIONS

# Bots that are constructed directly instead of via ``from_pretrained``.
_LOCAL_BOTS = frozenset({"RandomMover"})

# Headline shown once the game has ended, keyed by ``play.get_result()``.
_GAME_OVER_MESSAGES = {
    "1-0": "White Wins!",
    "0-1": "Black Wins!",
    "1/2-1/2": "It's a Draw!",
}

# Newline-separated so markdown renders a real numbered list.
_INSTRUCTIONS = "\n".join([
    "1. Play by dragging the piece you want to move.",
    "2. The piece won't move as you drag, but the move will be registered upon release.",
    "3. Illegal moves will not be registered.",
])


@st.cache_resource(max_entries=4)
def load_model(bot_select):
    """Build (or fetch from Streamlit's resource cache) the bot named *bot_select*.

    Local bots are instantiated directly; all others are loaded with
    ``from_pretrained("mzimm003/<bot_select>")``.

    Raises:
        KeyError: if *bot_select* is not a key of ``BOT_SELECTIONS``.
    """
    model_cls = BOT_SELECTIONS[bot_select]
    if bot_select in _LOCAL_BOTS:
        return model_cls()
    return model_cls.from_pretrained("mzimm003/{}".format(bot_select))


class TrackPlay:
    """Per-session UI state: whether a game is running and which bot was picked."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Return to the bot-selection screen state."""
        self.active_game = False  # True once "Play bot!" has been pressed
        self.bot_select = None  # key into BOT_SELECTIONS chosen by the user


def information(play, t_p):
    """Return the status headline for the current game.

    Args:
        play: The active ``HumanVsBot`` session.
        t_p: Unused; kept so existing call sites keep working.

    Returns:
        str: the game-over message when the game has ended, otherwise whose
        turn it is.
    """
    if play.is_done():
        result = play.get_result()
        # Fall back gracefully for result strings not in the table.
        return _GAME_OVER_MESSAGES.get(result, "Game over ({})".format(result))
    if play.get_human_player() == play.get_curr_player():
        return "Your turn!"
    return "Bot is thinking..."


def reset(t_p):
    """``on_click`` callback for the Reset button: drop the current game.

    Uses ``pop`` with a default so a repeated callback does not raise
    ``KeyError`` when the game has already been removed.
    """
    t_p.reset()
    st.session_state.pop("play", None)


def play(bot_select, t_p):
    """Render the game screen and advance the game by one ``run()`` step.

    A ``HumanVsBot`` session is created on first use and stored in
    ``st.session_state.play`` so it survives Streamlit reruns.

    Args:
        bot_select: Key into ``BOT_SELECTIONS`` naming the opponent.
        t_p: The session's ``TrackPlay``; passed to the Reset callback.
    """
    if "play" not in st.session_state:
        st.session_state.play = HumanVsBot(
            model=load_model(bot_select=bot_select),
            environment=Chess(render_mode="rgb_array"),
            extra_model_environment_context=lambda env: {"board": env.board},
        )
    game = st.session_state.play

    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown("# {}".format(information(game, t_p)))
        # The click/drag coordinates are stored in session state; presumably
        # HumanVsBot reads them from there to register the human's move —
        # TODO confirm.
        st.session_state["board"] = streamlit_image_coordinates(
            game.render_board(),
            key="brd",
            use_column_width=True,
            click_and_drag=True,
        )
    with col2:
        st.button("Reset", on_click=reset, args=[t_p])
        colour = "white" if game.get_human_player() == "player_0" else "black"
        st.markdown("#### You are playing as {}!".format(colour))
        st.markdown(_INSTRUCTIONS)

    if not game.is_done():
        game.run()


def main(kwargs=None):
    """Streamlit entry point: show the bot selector or the active game.

    Args:
        kwargs: Unused; kept for call-site compatibility.
    """
    # Placeholder for custom page CSS; currently only whitespace.
    css = ''' '''
    st.markdown(css, unsafe_allow_html=True)

    if "t_p" not in st.session_state:
        st.session_state.t_p = TrackPlay()
    t_p = st.session_state.t_p

    placeholder = st.empty()
    if not t_p.active_game:
        with placeholder.container():
            t_p.bot_select = st.selectbox(
                label="bot",
                options=list(BOT_SELECTIONS),
            )
            t_p.active_game = st.button("Play bot!")
        if t_p.active_game:
            # Rerun immediately so the selector is replaced by the game view.
            st.rerun()
    else:
        with placeholder.container():
            play(t_p.bot_select, t_p)


if __name__ == "__main__":
    main()