# PlayChessVsAI / app.py
# Author: mzimm003
# Commit c609ed6: "Allow multiple, distinct, sessions on server."
# File size: 3.04 kB
from my_chess.scripts.scripts import HumanVsBot
from my_chess.learner.environments import Chess
from chessmodels import DCMinMax
import streamlit as st
from streamlit_image_coordinates import streamlit_image_coordinates
# Model names offered in the bot picker; load_model resolves each one as
# "mzimm003/<name>" via DCMinMax.from_pretrained.
BOT_SELECTIONS = ["DeepChessReplicationMinMax"]
# Backward-compatible alias for the original (misspelled) constant name.
BOT_SELECITONS = BOT_SELECTIONS
@st.cache_resource(max_entries=4)
def load_model(bot_select):
    """Load the pretrained DCMinMax model named ``bot_select``.

    The repository id is ``"mzimm003/<bot_select>"``. The result is cached by
    ``st.cache_resource`` (at most 4 entries). NOTE(review): that cache is
    shared by every session on the server, so all sessions receive the same
    model object; confirm the model holds no per-game state.
    """
    repo_id = f"mzimm003/{bot_select}"
    return DCMinMax.from_pretrained(repo_id)
class TrackPlay:
    """Per-session UI state: whether a game is running and which bot was picked."""

    def __init__(self):
        # active_game: True once the user has pressed "Play bot!".
        # bot_select: the bot name chosen in the selectbox (None until chosen).
        self.active_game, self.bot_select = False, None

    def reset(self):
        """Return to the bot-selection state (no game, no bot chosen)."""
        self.active_game, self.bot_select = False, None
def information(play, t_p):
    """Return the status headline shown above the board.

    Args:
        play: The session's game object. Only its ``get_human_player``,
            ``get_curr_player``, ``is_done`` and ``get_result`` methods are used.
        t_p: Unused; accepted so existing call sites keep working.

    Returns:
        While the game is running: "Your turn!" when the current player is the
        human player, otherwise "Bot is thinking...". Once the game is done:
        the message for the result string ("1-0", "0-1", "1/2-1/2"), or
        "Game Over!" for any other result value.
    """
    if play.is_done():
        result_messages = {
            "1-0": "White Wins!",
            "0-1": "Black Wins!",
            "1/2-1/2": "It's a Draw!",
        }
        # Fall back to a generic message instead of raising KeyError (and
        # crashing the page) if the game reports an unexpected result string.
        return result_messages.get(play.get_result(), "Game Over!")
    if play.get_human_player() == play.get_curr_player():
        return "Your turn!"
    return "Bot is thinking..."
def reset(t_p):
    """Callback for the "Reset" button: end this session's game.

    Args:
        t_p: The session's play tracker; its ``reset()`` is called so the next
            rerun shows the bot-selection screen again.
    """
    t_p.reset()
    # Discard this session's game object. Otherwise ``play`` finds the old
    # (possibly finished) game in session_state and resumes it instead of
    # starting a new game with the newly selected bot.
    if "play" in st.session_state:
        del st.session_state["play"]
    # NOTE(review): st.cache_resource is process-wide, so this evicts cached
    # models for every session on the server, not just this one. Presumably
    # intended to obtain a fresh model instance; confirm it is still needed.
    st.cache_resource.clear()
def play(bot_select, t_p):
    """Render the game screen for this session and let the game take a step.

    On first call in a session, create the HumanVsBot game for ``bot_select``
    and store it in ``st.session_state.play``. Then draw the status headline,
    the click-and-drag board image (its result stored under
    ``st.session_state["board"]``), the Reset button and the instructions.
    Finally call the game's ``run()`` unless it is already done.
    """
    state = st.session_state
    if "play" not in state:
        state.play = HumanVsBot(
            model=load_model(bot_select=bot_select),
            environment=Chess(render_mode="rgb_array"),
            extra_model_environment_context=lambda env: {"board": env.board},
        )
    game = state.play

    board_col, side_col = st.columns([3, 1])

    with board_col:
        headline = information(game, t_p)
        st.markdown(f"# {headline}")
        board_image = game.render_board()
        state["board"] = streamlit_image_coordinates(
            board_image, key="brd", click_and_drag=True
        )

    with side_col:
        st.button("Reset", on_click=reset, args=[t_p])
        colour = "white" if game.get_human_player() == "player_0" else "black"
        st.markdown(f"#### You are playing as {colour}!")
        instructions = (
            "\n"
            "1. Play by dragging the piece you want to move.\n"
            "2. The piece won't move as you drag, but the move will be registered upon release.\n"
            "3. Illegal moves will not be registered.\n"
        )
        st.markdown(instructions)

    # Nothing left to advance once the game has finished.
    if game.is_done():
        return
    game.run()
def main(kwargs=None):
    """Streamlit entry point: show the bot picker or the active game.

    ``kwargs`` is accepted but not used.
    """
    page_css = (
        "\n"
        "<style>\n"
        "section.main > div {max-width:75rem}\n"
        "</style>\n"
    )
    st.markdown(page_css, unsafe_allow_html=True)

    # One tracker per browser session, created on the first run.
    if "t_p" not in st.session_state:
        st.session_state.t_p = TrackPlay()
    tracker = st.session_state.t_p

    placeholder = st.empty()

    if tracker.active_game:
        with placeholder.container():
            play(tracker.bot_select, tracker)
        return

    with placeholder.container():
        tracker.bot_select = st.selectbox(
            label="bot",
            options=BOT_SELECITONS,
        )
        tracker.active_game = st.button("Play bot!")
        # Rerun immediately so the picker is replaced by the game screen.
        if tracker.active_game:
            st.rerun()
if __name__ == "__main__":
    # Run the app when executed as the main script.
    main(kwargs=None)