Spaces:
Sleeping
Sleeping
Add random play bot.
Browse files
app.py
CHANGED
@@ -5,11 +5,30 @@ from chessmodels import DCMinMax
|
|
5 |
import streamlit as st
|
6 |
from streamlit_image_coordinates import streamlit_image_coordinates
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
@st.cache_resource(max_entries=4)
|
11 |
def load_model(bot_select):
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class TrackPlay:
|
15 |
def __init__(self):
|
@@ -73,7 +92,7 @@ def main(kwargs=None):
|
|
73 |
with placeholder.container():
|
74 |
st.session_state.t_p.bot_select = st.selectbox(
|
75 |
label="bot",
|
76 |
-
options=BOT_SELECITONS,
|
77 |
)
|
78 |
st.session_state.t_p.active_game = st.button("Play bot!")
|
79 |
if st.session_state.t_p.active_game:
|
|
|
5 |
import streamlit as st
|
6 |
from streamlit_image_coordinates import streamlit_image_coordinates
|
7 |
|
8 |
+
import torch
|
9 |
+
|
10 |
+
class RandomPlayer(torch.nn.Module):
|
11 |
+
def __init__(self) -> None:
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
def forward(self, board, input):
|
15 |
+
options = torch.arange(input['action_mask'].size)[input['action_mask'].astype(bool)]
|
16 |
+
choice = torch.randint(options.numel(), (1,))
|
17 |
+
return options[choice].item()
|
18 |
+
|
19 |
+
BOT_SELECITONS = {
|
20 |
+
"RandomMover":RandomPlayer,
|
21 |
+
"DeepChessReplicationMinMax":DCMinMax,
|
22 |
+
}
|
23 |
|
24 |
@st.cache_resource(max_entries=4)
|
25 |
def load_model(bot_select):
|
26 |
+
model = None
|
27 |
+
if bot_select in ["RandomMover"]:
|
28 |
+
model = BOT_SELECITONS[bot_select]()
|
29 |
+
else:
|
30 |
+
model = BOT_SELECITONS[bot_select].from_pretrained("mzimm003/{}".format(bot_select))
|
31 |
+
return model
|
32 |
|
33 |
class TrackPlay:
|
34 |
def __init__(self):
|
|
|
92 |
with placeholder.container():
|
93 |
st.session_state.t_p.bot_select = st.selectbox(
|
94 |
label="bot",
|
95 |
+
options=list(BOT_SELECITONS.keys()),
|
96 |
)
|
97 |
st.session_state.t_p.active_game = st.button("Play bot!")
|
98 |
if st.session_state.t_p.active_game:
|