mzimm003 commited on
Commit
a4ba045
·
1 Parent(s): 0d250f3

Add random play bot.

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -5,11 +5,30 @@ from chessmodels import DCMinMax
5
  import streamlit as st
6
  from streamlit_image_coordinates import streamlit_image_coordinates
7
 
8
- BOT_SELECITONS = ["DeepChessReplicationMinMax"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @st.cache_resource(max_entries=4)
11
  def load_model(bot_select):
12
- return DCMinMax.from_pretrained("mzimm003/{}".format(bot_select))
 
 
 
 
 
13
 
14
  class TrackPlay:
15
  def __init__(self):
@@ -73,7 +92,7 @@ def main(kwargs=None):
73
  with placeholder.container():
74
  st.session_state.t_p.bot_select = st.selectbox(
75
  label="bot",
76
- options=BOT_SELECITONS,
77
  )
78
  st.session_state.t_p.active_game = st.button("Play bot!")
79
  if st.session_state.t_p.active_game:
 
5
  import streamlit as st
6
  from streamlit_image_coordinates import streamlit_image_coordinates
7
 
8
+ import torch
9
+
10
+ class RandomPlayer(torch.nn.Module):
11
+ def __init__(self) -> None:
12
+ super().__init__()
13
+
14
+ def forward(self, board, input):
15
+ options = torch.arange(input['action_mask'].size)[input['action_mask'].astype(bool)]
16
+ choice = torch.randint(options.numel(), (1,))
17
+ return options[choice].item()
18
+
19
+ BOT_SELECITONS = {
20
+ "RandomMover":RandomPlayer,
21
+ "DeepChessReplicationMinMax":DCMinMax,
22
+ }
23
 
24
  @st.cache_resource(max_entries=4)
25
  def load_model(bot_select):
26
+ model = None
27
+ if bot_select in ["RandomMover"]:
28
+ model = BOT_SELECITONS[bot_select]()
29
+ else:
30
+ model = BOT_SELECITONS[bot_select].from_pretrained("mzimm003/{}".format(bot_select))
31
+ return model
32
 
33
  class TrackPlay:
34
  def __init__(self):
 
92
  with placeholder.container():
93
  st.session_state.t_p.bot_select = st.selectbox(
94
  label="bot",
95
+ options=list(BOT_SELECITONS.keys()),
96
  )
97
  st.session_state.t_p.active_game = st.button("Play bot!")
98
  if st.session_state.t_p.active_game: