mzimm003 commited on
Commit
dea0b2e
·
1 Parent(s): 22beef3

Cache bots instead of game

Browse files
Files changed (2) hide show
  1. app.py +10 -18
  2. launch.py +1 -1
app.py CHANGED
@@ -5,19 +5,11 @@ from chessmodels import DCMinMax
5
  import streamlit as st
6
  from streamlit_image_coordinates import streamlit_image_coordinates
7
 
8
- class HvB(HumanVsBot):
9
- @staticmethod
10
- @st.cache_resource()
11
- def get_hvb_manager(
12
- _model,
13
- _environment,
14
- _extra_model_environment_context,
15
- **kwargs):
16
- return HvB(
17
- model = _model,
18
- environment = _environment,
19
- extra_model_environment_context = _extra_model_environment_context,
20
- **kwargs)
21
 
22
  @st.cache_resource
23
  def track_play():
@@ -46,10 +38,10 @@ def reset(t_p):
46
  st.cache_resource.clear()
47
 
48
  def play(bot_select, t_p):
49
- play = HvB.get_hvb_manager(
50
- _model=DCMinMax.from_pretrained("mzimm003/{}".format(bot_select)),
51
- _environment=Chess(render_mode="rgb_array"),
52
- _extra_model_environment_context=lambda env: {"board":env.board}
53
  )
54
  col1, col2 = st.columns([3,1])
55
  with col1:
@@ -81,7 +73,7 @@ def main(kwargs=None):
81
  with placeholder.container():
82
  t_p.bot_select = st.selectbox(
83
  label="bot",
84
- options=["DeepChessReplicationMinMax"],
85
  )
86
  t_p.active_game = st.button("Play bot!")
87
  if t_p.active_game:
 
5
  import streamlit as st
6
  from streamlit_image_coordinates import streamlit_image_coordinates
7
 
8
+ BOT_SELECITONS = ["DeepChessReplicationMinMax"]
9
+
10
+ @st.cache_resource
11
+ def load_model(bot_select):
12
+ return DCMinMax.from_pretrained("mzimm003/{}".format(bot_select))
 
 
 
 
 
 
 
 
13
 
14
  @st.cache_resource
15
  def track_play():
 
38
  st.cache_resource.clear()
39
 
40
  def play(bot_select, t_p):
41
+ play = HumanVsBot(
42
+ model=load_model(bot_select=bot_select),
43
+ environment=Chess(render_mode="rgb_array"),
44
+ extra_model_environment_context=lambda env: {"board":env.board},
45
  )
46
  col1, col2 = st.columns([3,1])
47
  with col1:
 
73
  with placeholder.container():
74
  t_p.bot_select = st.selectbox(
75
  label="bot",
76
+ options=BOT_SELECITONS,
77
  )
78
  t_p.active_game = st.button("Play bot!")
79
  if t_p.active_game:
launch.py CHANGED
@@ -1,4 +1,4 @@
1
  from streamlit.testing.v1 import AppTest
2
 
3
- at = AppTest.from_file("/app/app.py", default_timeout=100)
4
  at.run()
 
1
  from streamlit.testing.v1 import AppTest
2
 
3
+ at = AppTest.from_file("app.py", default_timeout=100)
4
  at.run()