Andrei Cozma committed on
Commit
d922c89
·
1 Parent(s): 9c2fd5e
Files changed (1) hide show
  1. demo.py +25 -19
demo.py CHANGED
@@ -33,10 +33,17 @@ action_map = {
33
  },
34
  }
35
 
 
 
 
 
 
 
 
36
  # Global variables to allow changing it on the fly
37
  live_render_fps = 5
38
  live_epsilon = 0.0
39
- live_paused = False
40
  live_steps_forward = None
41
 
42
 
@@ -54,14 +61,10 @@ def change_epsilon(x):
54
 
55
  def change_paused(x):
56
  print("Changing paused:", x)
57
- val_map = {
58
- "▶️ Resume": False,
59
- "⏸️ Pause": True,
60
- }
61
- val_map_inv = {v: k for k, v in val_map.items()}
62
  global live_paused
63
- live_paused = val_map[x]
64
- next_val = val_map_inv[not live_paused]
65
  return gr.update(value=next_val), gr.update(interactive=live_paused)
66
 
67
 
@@ -77,6 +80,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
77
  global live_render_fps, live_epsilon, live_paused, live_steps_forward
78
  live_render_fps = render_fps
79
  live_epsilon = epsilon
 
80
  print("=" * 80)
81
  print("Running...")
82
  print(f"- policy_fname: {policy_fname}")
@@ -123,14 +127,6 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
123
  max_steps=max_steps, render=True, override_epsilon=True
124
  )
125
  ):
126
- if live_steps_forward is not None:
127
- if live_steps_forward > 0:
128
- live_steps_forward -= 1
129
-
130
- if live_steps_forward == 0:
131
- live_steps_forward = None
132
- live_paused = True
133
-
134
  _, _, last_reward = (
135
  episode_hist[-2] if len(episode_hist) > 1 else (None, None, None)
136
  )
@@ -201,7 +197,15 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
201
  step
202
  ), state, action, last_reward, "Running..."
203
 
204
- time.sleep(1 / live_render_fps)
 
 
 
 
 
 
 
 
205
 
206
  while live_paused and live_steps_forward is None:
207
  yield agent_type, env_name, rgb_array, policy_viz, ep_str(
@@ -250,7 +254,7 @@ with gr.Blocks(title="CS581 Demo") as demo:
250
  label="Max steps per episode",
251
  )
252
 
253
- btn_run = gr.components.Button("▶️ Start", interactive=bool(all_policies))
254
 
255
  gr.components.HTML("<h2>Live Statistics & Policy Visualization:</h2>")
256
  with gr.Row():
@@ -292,7 +296,9 @@ with gr.Blocks(title="CS581 Demo") as demo:
292
  )
293
 
294
  with gr.Row():
295
- btn_pause = gr.components.Button("⏸️ Pause", interactive=True)
 
 
296
  btn_forward = gr.components.Button("⏩ Step", interactive=False)
297
 
298
  btn_pause.click(
 
33
  },
34
  }
35
 
36
+
37
+ pause_val_map = {
38
+ "▶️ Resume": False,
39
+ "⏸️ Pause": True,
40
+ }
41
+ pause_val_map_inv = {v: k for k, v in pause_val_map.items()}
42
+
43
  # Global variables to allow changing it on the fly
44
  live_render_fps = 5
45
  live_epsilon = 0.0
46
+ live_paused = True
47
  live_steps_forward = None
48
 
49
 
 
61
 
62
  def change_paused(x):
63
  print("Changing paused:", x)
64
+
 
 
 
 
65
  global live_paused
66
+ live_paused = pause_val_map[x]
67
+ next_val = pause_val_map_inv[not live_paused]
68
  return gr.update(value=next_val), gr.update(interactive=live_paused)
69
 
70
 
 
80
  global live_render_fps, live_epsilon, live_paused, live_steps_forward
81
  live_render_fps = render_fps
82
  live_epsilon = epsilon
83
+ live_steps_forward = None
84
  print("=" * 80)
85
  print("Running...")
86
  print(f"- policy_fname: {policy_fname}")
 
127
  max_steps=max_steps, render=True, override_epsilon=True
128
  )
129
  ):
 
 
 
 
 
 
 
 
130
  _, _, last_reward = (
131
  episode_hist[-2] if len(episode_hist) > 1 else (None, None, None)
132
  )
 
197
  step
198
  ), state, action, last_reward, "Running..."
199
 
200
+ if live_steps_forward is not None:
201
+ if live_steps_forward > 0:
202
+ live_steps_forward -= 1
203
+
204
+ if live_steps_forward == 0:
205
+ live_steps_forward = None
206
+ live_paused = True
207
+ else:
208
+ time.sleep(1 / live_render_fps)
209
 
210
  while live_paused and live_steps_forward is None:
211
  yield agent_type, env_name, rgb_array, policy_viz, ep_str(
 
254
  label="Max steps per episode",
255
  )
256
 
257
+ btn_run = gr.components.Button("👀 Select", interactive=bool(all_policies))
258
 
259
  gr.components.HTML("<h2>Live Statistics & Policy Visualization:</h2>")
260
  with gr.Row():
 
296
  )
297
 
298
  with gr.Row():
299
+ btn_pause = gr.components.Button(
300
+ pause_val_map_inv[not live_paused], interactive=True
301
+ )
302
  btn_forward = gr.components.Button("⏩ Step", interactive=False)
303
 
304
  btn_pause.click(