Andrei Cozma commited on
Commit
f221e91
·
1 Parent(s): 261964b
Files changed (1) hide show
  1. demo.py +127 -89
demo.py CHANGED
@@ -15,7 +15,7 @@ default_epsilon = 0.0
15
  default_paused = True
16
 
17
  frame_env_h, frame_env_w = 512, 768
18
- frame_policy_res = 384
19
 
20
  # For the dropdown list of policies
21
  policies_folder = "policies"
@@ -51,75 +51,80 @@ pause_val_map = {
51
  pause_val_map_inv = {v: k for k, v in pause_val_map.items()}
52
 
53
  # Global variables to allow changing it on the fly
54
- current_policy = None
55
- live_render_fps = default_render_fps
56
- live_epsilon = default_epsilon
57
- live_paused = default_paused
58
- live_steps_forward = None
59
- should_reset = False
60
-
61
-
62
- def reset(policy_fname):
63
- global current_policy, live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
64
- if current_policy is not None and current_policy != policy_fname:
65
- should_reset = True
66
- live_paused = default_paused
67
- live_render_fps = default_render_fps
68
- live_epsilon = default_epsilon
69
- live_steps_forward = None
70
- return gr.update(value=pause_val_map_inv[not live_paused]), gr.update(
71
- interactive=live_paused
 
 
 
72
  )
73
 
74
 
75
- def change_render_fps(x):
76
  print("Changing render fps:", x)
77
- global live_render_fps
78
- live_render_fps = x
79
 
80
 
81
- def change_epsilon(x):
82
  print("Changing greediness:", x)
83
- global live_epsilon
84
- live_epsilon = x
85
 
86
 
87
- def change_paused(x):
88
  print("Changing paused:", x)
89
- global live_paused
90
- live_paused = pause_val_map[x]
91
- return gr.update(value=pause_val_map_inv[not live_paused]), gr.update(
92
- interactive=live_paused
 
93
  )
94
 
95
 
96
- def onclick_btn_forward():
97
  print("Step forward")
98
- global live_steps_forward
99
- if live_steps_forward is None:
100
- live_steps_forward = 0
101
- live_steps_forward += 1
102
-
103
-
104
- def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
105
- global current_policy, live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
106
- current_policy = policy_fname
107
- live_render_fps = render_fps
108
- live_epsilon = epsilon
109
- live_steps_forward = None
 
110
  print("=" * 80)
111
  print("Running...")
112
- print(f"- policy_fname: {policy_fname}")
113
  print(f"- n_test_episodes: {n_test_episodes}")
114
  print(f"- max_steps: {max_steps}")
115
- print(f"- render_fps: {live_render_fps}")
116
- print(f"- epsilon: {live_epsilon}")
117
 
118
  policy_path = os.path.join(policies_folder, policy_fname)
119
  props = policy_fname.split("_")
120
 
121
  if len(props) < 2:
122
- yield None, None, None, None, None, None, None, None, None, None, "🚫 Please select a valid policy file."
123
  return
124
 
125
  agent_type, env_name = props[0], props[1]
@@ -152,7 +157,9 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
152
 
153
  for step, (episode_hist, solved, frame_env) in enumerate(
154
  agent.generate_episode(
155
- max_steps=max_steps, render=True, epsilon_override=live_epsilon
 
 
156
  )
157
  ):
158
  _, _, last_reward = (
@@ -173,26 +180,34 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
173
 
174
  frame_policy = scipy.ndimage.gaussian_filter(frame_policy, sigma=1.0)
175
  frame_policy = np.clip(
176
- frame_policy * (1.0 - live_epsilon) + live_epsilon / len(curr_policy),
 
177
  0.0,
178
  1.0,
179
  )
180
-
181
- label_loc_h, label_loc_w =frame_policy_h // 2, int((action + 0.5) * frame_policy_res // len(curr_policy))
182
-
 
 
183
  frame_policy_label_color = 1.0 - frame_policy[label_loc_h, label_loc_w]
184
  frame_policy_label_font = cv2.FONT_HERSHEY_SIMPLEX
185
  frame_policy_label_thicc = 1
186
- action_text_scale, action_text_label_scale = 0.8, 0.5
187
-
188
- (label_width, _), _ = cv2.getTextSize(str(action), frame_policy_label_font, action_text_scale, frame_policy_label_thicc)
 
 
 
 
 
189
 
190
  cv2.putText(
191
  frame_policy,
192
  str(action),
193
  (
194
  label_loc_w - label_width // 2,
195
- label_loc_h,
196
  ),
197
  frame_policy_label_font,
198
  action_text_scale,
@@ -203,14 +218,21 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
203
 
204
  if env_action_map:
205
  action_name = env_action_map.get(action, "")
206
- (label_width, _), _ = cv2.getTextSize(action_name, frame_policy_label_font, action_text_label_scale, frame_policy_label_thicc)
207
-
 
 
 
 
 
208
  cv2.putText(
209
  frame_policy,
210
  action_name,
211
  (
212
  int(label_loc_w - label_width / 2),
213
- label_loc_h + 25,
 
 
214
  ),
215
  frame_policy_label_font,
216
  action_text_label_scale,
@@ -220,38 +242,39 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
220
  )
221
 
222
  print(
223
- f"Episode: {ep_str(episode + 1)} - step: {step_str(step)} - state: {state} - action: {action} - reward: {reward} (epsilon: {live_epsilon:.2f}) (frame time: {1 / live_render_fps:.2f}s)"
224
  )
225
 
226
- yield agent_type, env_name, frame_env, frame_policy, ep_str(
227
  episode + 1
228
  ), ep_str(episodes_solved), step_str(
229
  step
230
  ), state, action, last_reward, "Running..."
231
 
232
- if live_steps_forward is not None:
233
- if live_steps_forward > 0:
234
- live_steps_forward -= 1
235
 
236
- if live_steps_forward == 0:
237
- live_steps_forward = None
238
- live_paused = True
239
  else:
240
- time.sleep(1 / live_render_fps)
241
 
242
- while live_paused and live_steps_forward is None:
243
- yield agent_type, env_name, frame_env, frame_policy, ep_str(
244
  episode + 1
245
  ), ep_str(episodes_solved), step_str(
246
  step
247
  ), state, action, last_reward, "Paused..."
248
- time.sleep(1 / live_render_fps)
249
- if should_reset is True:
250
  break
251
 
252
- if should_reset is True:
253
- should_reset = False
254
  yield (
 
255
  agent_type,
256
  env_name,
257
  np.ones((frame_env_h, frame_env_w, 3)),
@@ -271,10 +294,10 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
271
 
272
  time.sleep(0.25)
273
 
274
- current_policy = None
275
- yield agent_type, env_name, frame_env, frame_policy, ep_str(episode + 1), ep_str(
276
- episodes_solved
277
- ), step_str(step), state, action, reward, "Done!"
278
 
279
 
280
  with gr.Blocks(title="CS581 Demo") as demo:
@@ -282,6 +305,8 @@ with gr.Blocks(title="CS581 Demo") as demo:
282
  "<h1>CS581 Final Project Demo - Dynamic Programming & Monte-Carlo RL Methods (<a href='https://huggingface.co/spaces/acozma/CS581-Algos-Demo'>HF Space</a>)</h1>"
283
  )
284
 
 
 
285
  gr.components.HTML("<h2>Select Configuration:</h2>")
286
  with gr.Row():
287
  input_policy = gr.components.Dropdown(
@@ -333,17 +358,26 @@ with gr.Blocks(title="CS581 Demo") as demo:
333
  input_epsilon = gr.components.Slider(
334
  minimum=0,
335
  maximum=1,
336
- value=live_epsilon,
337
- step=1/200,
338
  label="Epsilon (0 = greedy, 1 = random)",
339
  )
340
- input_epsilon.change(change_epsilon, inputs=[input_epsilon])
 
 
341
 
342
  input_render_fps = gr.components.Slider(
343
- minimum=1, maximum=60, value=live_render_fps, step=1,
344
- label="Simulation speed (fps)"
 
 
 
 
 
 
 
 
345
  )
346
- input_render_fps.change(change_render_fps, inputs=[input_render_fps])
347
 
348
  out_image_frame = gr.components.Image(
349
  label="Environment",
@@ -354,18 +388,18 @@ with gr.Blocks(title="CS581 Demo") as demo:
354
 
355
  with gr.Row():
356
  btn_pause = gr.components.Button(
357
- pause_val_map_inv[not live_paused], interactive=True
358
  )
359
  btn_forward = gr.components.Button("⏩ Step")
360
 
361
  btn_pause.click(
362
  fn=change_paused,
363
- inputs=[btn_pause],
364
- outputs=[btn_pause, btn_forward],
365
  )
366
 
367
  btn_forward.click(
368
- fn=onclick_btn_forward,
369
  )
370
 
371
  out_msg = gr.components.Textbox(
@@ -376,12 +410,15 @@ with gr.Blocks(title="CS581 Demo") as demo:
376
  )
377
 
378
  input_policy.change(
379
- fn=reset, inputs=[input_policy], outputs=[btn_pause, btn_forward]
 
 
380
  )
381
 
382
  btn_run.click(
383
  fn=run,
384
  inputs=[
 
385
  input_policy,
386
  input_n_test_episodes,
387
  input_max_steps,
@@ -389,6 +426,7 @@ with gr.Blocks(title="CS581 Demo") as demo:
389
  input_epsilon,
390
  ],
391
  outputs=[
 
392
  out_agent,
393
  out_environment,
394
  out_image_frame,
 
15
  default_paused = True
16
 
17
  frame_env_h, frame_env_w = 512, 768
18
+ frame_policy_res = 256
19
 
20
  # For the dropdown list of policies
21
  policies_folder = "policies"
 
51
  pause_val_map_inv = {v: k for k, v in pause_val_map.items()}
52
 
53
  # Global variables to allow changing it on the fly
54
+
55
+
56
+ class RunState:
57
+ def __init__(self) -> None:
58
+ self.current_policy = None
59
+ self.live_render_fps = default_render_fps
60
+ self.live_epsilon = default_epsilon
61
+ self.live_paused = default_paused
62
+ self.live_steps_forward = None
63
+ self.should_reset = False
64
+
65
+
66
+ def reset(state, policy_fname):
67
+ if state.current_policy is not None and state.current_policy != policy_fname:
68
+ state.should_reset = True
69
+ state.live_paused = default_paused
70
+ state.live_render_fps = default_render_fps
71
+ state.live_epsilon = default_epsilon
72
+ state.live_steps_forward = None
73
+ return gr.update(value=pause_val_map_inv[not state.live_paused]), gr.update(
74
+ interactive=state.live_paused
75
  )
76
 
77
 
78
+ def change_render_fps(state, x):
79
  print("Changing render fps:", x)
80
+ state.live_render_fps = x
81
+ return state
82
 
83
 
84
+ def change_epsilon(state, x):
85
  print("Changing greediness:", x)
86
+ state.live_epsilon = x
87
+ return state
88
 
89
 
90
+ def change_paused(state, x):
91
  print("Changing paused:", x)
92
+ state.live_paused = pause_val_map[x]
93
+ return (
94
+ state,
95
+ gr.update(value=pause_val_map_inv[not state.live_paused]),
96
+ gr.update(interactive=state.live_paused),
97
  )
98
 
99
 
100
+ def onclick_btn_forward(state):
101
  print("Step forward")
102
+ if state.live_steps_forward is None:
103
+ state.live_steps_forward = 0
104
+ state.live_steps_forward += 1
105
+ return state
106
+
107
+
108
+ def run(
109
+ localstate: RunState, policy_fname, n_test_episodes, max_steps, render_fps, epsilon
110
+ ):
111
+ localstate.current_policy = policy_fname
112
+ localstate.live_render_fps = render_fps
113
+ localstate.live_epsilon = epsilon
114
+ localstate.live_steps_forward = None
115
  print("=" * 80)
116
  print("Running...")
117
+ print(f"- policy_fname: {localstate.current_policy}")
118
  print(f"- n_test_episodes: {n_test_episodes}")
119
  print(f"- max_steps: {max_steps}")
120
+ print(f"- render_fps: {localstate.live_render_fps}")
121
+ print(f"- epsilon: {localstate.live_steps_forward}")
122
 
123
  policy_path = os.path.join(policies_folder, policy_fname)
124
  props = policy_fname.split("_")
125
 
126
  if len(props) < 2:
127
+ yield localstate, None, None, None, None, None, None, None, None, None, None, "🚫 Please select a valid policy file."
128
  return
129
 
130
  agent_type, env_name = props[0], props[1]
 
157
 
158
  for step, (episode_hist, solved, frame_env) in enumerate(
159
  agent.generate_episode(
160
+ max_steps=max_steps,
161
+ render=True,
162
+ epsilon_override=localstate.live_epsilon,
163
  )
164
  ):
165
  _, _, last_reward = (
 
180
 
181
  frame_policy = scipy.ndimage.gaussian_filter(frame_policy, sigma=1.0)
182
  frame_policy = np.clip(
183
+ frame_policy * (1.0 - localstate.live_epsilon)
184
+ + localstate.live_epsilon / len(curr_policy),
185
  0.0,
186
  1.0,
187
  )
188
+
189
+ label_loc_h, label_loc_w = frame_policy_h // 2, int(
190
+ (action + 0.5) * frame_policy_res // len(curr_policy)
191
+ )
192
+
193
  frame_policy_label_color = 1.0 - frame_policy[label_loc_h, label_loc_w]
194
  frame_policy_label_font = cv2.FONT_HERSHEY_SIMPLEX
195
  frame_policy_label_thicc = 1
196
+ action_text_scale, action_text_label_scale = 0.6, 0.3
197
+
198
+ (label_width, label_height), _ = cv2.getTextSize(
199
+ str(action),
200
+ frame_policy_label_font,
201
+ action_text_scale,
202
+ frame_policy_label_thicc,
203
+ )
204
 
205
  cv2.putText(
206
  frame_policy,
207
  str(action),
208
  (
209
  label_loc_w - label_width // 2,
210
+ label_loc_h + label_height // 2,
211
  ),
212
  frame_policy_label_font,
213
  action_text_scale,
 
218
 
219
  if env_action_map:
220
  action_name = env_action_map.get(action, "")
221
+ (label_width, label_height), _ = cv2.getTextSize(
222
+ action_name,
223
+ frame_policy_label_font,
224
+ action_text_label_scale,
225
+ frame_policy_label_thicc,
226
+ )
227
+
228
  cv2.putText(
229
  frame_policy,
230
  action_name,
231
  (
232
  int(label_loc_w - label_width / 2),
233
+ frame_policy_h
234
+ - (frame_policy_h - label_loc_h) // 2
235
+ + label_height // 2,
236
  ),
237
  frame_policy_label_font,
238
  action_text_label_scale,
 
242
  )
243
 
244
  print(
245
+ f"Episode: {ep_str(episode + 1)} - step: {step_str(step)} - state: {state} - action: {action} - reward: {reward} (epsilon: {localstate.live_epsilon:.2f}) (frame time: {1 / localstate.live_render_fps:.2f}s)"
246
  )
247
 
248
+ yield localstate, agent_type, env_name, frame_env, frame_policy, ep_str(
249
  episode + 1
250
  ), ep_str(episodes_solved), step_str(
251
  step
252
  ), state, action, last_reward, "Running..."
253
 
254
+ if localstate.live_steps_forward is not None:
255
+ if localstate.live_steps_forward > 0:
256
+ localstate.live_steps_forward -= 1
257
 
258
+ if localstate.live_steps_forward == 0:
259
+ localstate.live_steps_forward = None
260
+ localstate.live_paused = True
261
  else:
262
+ time.sleep(1 / localstate.live_render_fps)
263
 
264
+ while localstate.live_paused and localstate.live_steps_forward is None:
265
+ yield localstate, agent_type, env_name, frame_env, frame_policy, ep_str(
266
  episode + 1
267
  ), ep_str(episodes_solved), step_str(
268
  step
269
  ), state, action, last_reward, "Paused..."
270
+ time.sleep(1 / localstate.live_render_fps)
271
+ if localstate.should_reset is True:
272
  break
273
 
274
+ if localstate.should_reset is True:
275
+ localstate.should_reset = False
276
  yield (
277
+ localstate,
278
  agent_type,
279
  env_name,
280
  np.ones((frame_env_h, frame_env_w, 3)),
 
294
 
295
  time.sleep(0.25)
296
 
297
+ localstate.current_policy = None
298
+ yield localstate, agent_type, env_name, frame_env, frame_policy, ep_str(
299
+ episode + 1
300
+ ), ep_str(episodes_solved), step_str(step), state, action, reward, "Done!"
301
 
302
 
303
  with gr.Blocks(title="CS581 Demo") as demo:
 
305
  "<h1>CS581 Final Project Demo - Dynamic Programming & Monte-Carlo RL Methods (<a href='https://huggingface.co/spaces/acozma/CS581-Algos-Demo'>HF Space</a>)</h1>"
306
  )
307
 
308
+ localstate = gr.State(RunState())
309
+
310
  gr.components.HTML("<h2>Select Configuration:</h2>")
311
  with gr.Row():
312
  input_policy = gr.components.Dropdown(
 
358
  input_epsilon = gr.components.Slider(
359
  minimum=0,
360
  maximum=1,
361
+ value=default_epsilon,
362
+ step=1 / 200,
363
  label="Epsilon (0 = greedy, 1 = random)",
364
  )
365
+ input_epsilon.change(
366
+ change_epsilon, inputs=[localstate, input_epsilon], outputs=[localstate]
367
+ )
368
 
369
  input_render_fps = gr.components.Slider(
370
+ minimum=1,
371
+ maximum=60,
372
+ value=default_render_fps,
373
+ step=1,
374
+ label="Simulation speed (fps)",
375
+ )
376
+ input_render_fps.change(
377
+ change_render_fps,
378
+ inputs=[localstate, input_render_fps],
379
+ outputs=[localstate],
380
  )
 
381
 
382
  out_image_frame = gr.components.Image(
383
  label="Environment",
 
388
 
389
  with gr.Row():
390
  btn_pause = gr.components.Button(
391
+ pause_val_map_inv[not default_paused], interactive=True
392
  )
393
  btn_forward = gr.components.Button("⏩ Step")
394
 
395
  btn_pause.click(
396
  fn=change_paused,
397
+ inputs=[localstate, btn_pause],
398
+ outputs=[localstate, btn_pause, btn_forward],
399
  )
400
 
401
  btn_forward.click(
402
+ fn=onclick_btn_forward, inputs=[localstate], outputs=[localstate]
403
  )
404
 
405
  out_msg = gr.components.Textbox(
 
410
  )
411
 
412
  input_policy.change(
413
+ fn=reset,
414
+ inputs=[localstate, input_policy],
415
+ outputs=[localstate, btn_pause, btn_forward],
416
  )
417
 
418
  btn_run.click(
419
  fn=run,
420
  inputs=[
421
+ localstate,
422
  input_policy,
423
  input_n_test_episodes,
424
  input_max_steps,
 
426
  input_epsilon,
427
  ],
428
  outputs=[
429
+ localstate,
430
  out_agent,
431
  out_environment,
432
  out_image_frame,