Andrei Cozma commited on
Commit
adada5a
·
1 Parent(s): 6bb60fd
Files changed (1) hide show
  1. demo.py +34 -11
demo.py CHANGED
@@ -64,7 +64,7 @@ class RunState:
64
  self.should_reset = False
65
 
66
 
67
- def reset(state, policy_fname):
68
  if state.current_policy is not None and state.current_policy != policy_fname:
69
  state.should_reset = True
70
  state.live_paused = default_paused
@@ -78,6 +78,19 @@ def reset(state, policy_fname):
78
  )
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def change_render_fps(state, x):
82
  print("Changing render fps:", x)
83
  state.live_render_fps = x
@@ -292,16 +305,16 @@ def run(
292
  localstate.should_reset = False
293
  yield (
294
  localstate,
295
- agent_key,
296
- env_key,
297
  np.ones((frame_env_h, frame_env_w, 3)),
298
  np.ones((frame_policy_h, frame_policy_res)),
299
- ep_str(episode + 1),
300
- ep_str(episodes_solved),
301
- step_str(step),
302
- state,
303
- action,
304
- last_reward,
305
  "Reset...",
306
  )
307
  return
@@ -358,7 +371,11 @@ with gr.Blocks(title="CS581 Demo") as demo:
358
  label="Max steps per episode",
359
  )
360
 
361
- btn_run = gr.components.Button("👀 Select & Load", interactive=bool(all_policies))
 
 
 
 
362
 
363
  gr.components.HTML("<h2>Live Visualization & Information:</h2>")
364
  with gr.Row():
@@ -448,11 +465,17 @@ with gr.Blocks(title="CS581 Demo") as demo:
448
  )
449
 
450
  input_policy.change(
451
- fn=reset,
452
  inputs=[localstate, input_policy],
453
  outputs=[localstate, btn_pause, btn_forward],
454
  )
455
 
 
 
 
 
 
 
456
  btn_run.click(
457
  fn=run,
458
  inputs=[
 
64
  self.should_reset = False
65
 
66
 
67
+ def reset_change(state, policy_fname):
68
  if state.current_policy is not None and state.current_policy != policy_fname:
69
  state.should_reset = True
70
  state.live_paused = default_paused
 
78
  )
79
 
80
 
81
+ def reset_click(state):
82
+ state.should_reset = True
83
+ state.live_paused = default_paused
84
+ state.live_render_fps = default_render_fps
85
+ state.live_epsilon = default_epsilon
86
+ state.live_steps_forward = None
87
+ return (
88
+ state,
89
+ gr.update(value=pause_val_map_inv[not state.live_paused]),
90
+ gr.update(interactive=state.live_paused),
91
+ )
92
+
93
+
94
  def change_render_fps(state, x):
95
  print("Changing render fps:", x)
96
  state.live_render_fps = x
 
305
  localstate.should_reset = False
306
  yield (
307
  localstate,
308
+ None,
309
+ None,
310
  np.ones((frame_env_h, frame_env_w, 3)),
311
  np.ones((frame_policy_h, frame_policy_res)),
312
+ None,
313
+ None,
314
+ None,
315
+ None,
316
+ None,
317
+ None,
318
  "Reset...",
319
  )
320
  return
 
371
  label="Max steps per episode",
372
  )
373
 
374
+ with gr.Row():
375
+ btn_run = gr.components.Button(
376
+ "👀 Select & Load", interactive=bool(all_policies)
377
+ )
378
+ btn_clear = gr.components.Button("🗑️ Clear", interactive=bool(all_policies))
379
 
380
  gr.components.HTML("<h2>Live Visualization & Information:</h2>")
381
  with gr.Row():
 
465
  )
466
 
467
  input_policy.change(
468
+ fn=reset_change,
469
  inputs=[localstate, input_policy],
470
  outputs=[localstate, btn_pause, btn_forward],
471
  )
472
 
473
+ btn_clear.click(
474
+ fn=reset_click,
475
+ inputs=[localstate],
476
+ outputs=[localstate, btn_pause, btn_forward],
477
+ )
478
+
479
  btn_run.click(
480
  fn=run,
481
  inputs=[