Andrei Cozma committed on
Commit
45dcb54
·
1 Parent(s): e9e7977
Files changed (2) hide show
  1. README.md +1 -1
  2. demo.py +90 -74
README.md CHANGED
@@ -5,8 +5,8 @@ colorFrom: yellow
5
  colorTo: orange
6
  sdk: gradio
7
  app_file: demo.py
 
8
  pinned: true
9
-
10
  ---
11
 
12
  # CS581 Project - Reinforcement Learning: From Dynamic Programming to Monte-Carlo
 
5
  colorTo: orange
6
  sdk: gradio
7
  app_file: demo.py
8
+ fullWidth: true
9
  pinned: true
 
10
  ---
11
 
12
  # CS581 Project - Reinforcement Learning: From Dynamic Programming to Monte-Carlo
demo.py CHANGED
@@ -8,8 +8,11 @@ import cv2
8
 
9
  default_n_test_episodes = 10
10
  default_max_steps = 500
 
 
 
11
 
12
- frame_env_h, frame_env_w = 256, 512
13
  frame_policy_w = 384
14
 
15
  # For the dropdown list of policies
@@ -50,16 +53,25 @@ pause_val_map = {
50
  pause_val_map_inv = {v: k for k, v in pause_val_map.items()}
51
 
52
  # Global variables to allow changing it on the fly
53
- live_render_fps = 5
54
- live_epsilon = 0.0
55
- live_paused = True
 
56
  live_steps_forward = None
57
  should_reset = False
58
 
59
 
60
- # def reset():
61
- # global should_reset
62
- # should_reset = True
 
 
 
 
 
 
 
 
63
 
64
 
65
  def change_render_fps(x):
@@ -78,8 +90,9 @@ def change_paused(x):
78
  print("Changing paused:", x)
79
  global live_paused
80
  live_paused = pause_val_map[x]
81
- next_val = pause_val_map_inv[not live_paused]
82
- return gr.update(value=next_val), gr.update(interactive=live_paused)
 
83
 
84
 
85
  def onclick_btn_forward():
@@ -91,7 +104,8 @@ def onclick_btn_forward():
91
 
92
 
93
  def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
94
- global live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
 
95
  live_render_fps = render_fps
96
  live_epsilon = epsilon
97
  live_steps_forward = None
@@ -116,7 +130,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
116
  agent.load_policy(policy_path)
117
  env_action_map = action_map.get(env_name)
118
 
119
- solved, rgb_array, policy_viz = None, None, None
120
  episode, step, state, action, reward, last_reward = (
121
  None,
122
  None,
@@ -136,9 +150,9 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
136
  return f"{step + 1}"
137
 
138
  for episode in range(n_test_episodes):
139
- time.sleep(0.5)
140
 
141
- for step, (episode_hist, solved, rgb_array) in enumerate(
142
  agent.generate_episode(
143
  max_steps=max_steps, render=True, epsilon_override=live_epsilon
144
  )
@@ -149,58 +163,58 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
149
  state, action, reward = episode_hist[-1]
150
  curr_policy = agent.Pi[state]
151
 
152
- # rgb_array = cv2.resize(
153
- # rgb_array,
154
- # (
155
- # int(rgb_array.shape[1] / rgb_array.shape[0] * frame_env_h),
156
- # frame_env_h,
157
- # ),
158
- # interpolation=cv2.INTER_AREA,
159
- # )
160
-
161
- # if rgb_array.shape[1] < frame_env_w:
162
- # rgb_array_new = np.pad(
163
- # rgb_array,
164
- # (
165
- # (0, 0),
166
- # (
167
- # (frame_env_w - rgb_array.shape[1]) // 2,
168
- # (frame_env_w - rgb_array.shape[1]) // 2,
169
- # ),
170
- # (0, 0),
171
- # ),
172
- # "constant",
173
- # )
174
- # rgb_array = np.uint8(rgb_array_new)
175
-
176
- viz_h = frame_policy_w // len(curr_policy)
177
- policy_viz = np.zeros((viz_h, frame_policy_w))
178
  for i, p in enumerate(curr_policy):
179
- policy_viz[
180
  :,
181
  i
182
  * (frame_policy_w // len(curr_policy)) : (i + 1)
183
  * (frame_policy_w // len(curr_policy)),
184
  ] = p
185
 
186
- policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1.0)
187
- policy_viz = np.clip(
188
- policy_viz * (1.0 - live_epsilon) + live_epsilon / len(curr_policy),
189
  0.0,
190
  1.0,
191
  )
192
 
193
  cv2.putText(
194
- policy_viz,
195
  str(action),
196
  (
197
  int((action + 0.5) * frame_policy_w // len(curr_policy) - 8),
198
- viz_h // 2,
199
  ),
200
  cv2.FONT_HERSHEY_SIMPLEX,
 
201
  1.0,
202
- 1.0,
203
- 2,
204
  cv2.LINE_AA,
205
  )
206
 
@@ -208,19 +222,19 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
208
  action_name = env_action_map.get(action, "")
209
 
210
  cv2.putText(
211
- policy_viz,
212
  action_name,
213
  (
214
  int(
215
  (action + 0.5) * frame_policy_w // len(curr_policy)
216
  - 5 * len(action_name)
217
  ),
218
- viz_h // 2 + 30,
219
  ),
220
  cv2.FONT_HERSHEY_SIMPLEX,
221
- 0.6,
222
  1.0,
223
- 2,
224
  cv2.LINE_AA,
225
  )
226
 
@@ -228,7 +242,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
228
  f"Episode: {ep_str(episode + 1)} - step: {step_str(step)} - state: {state} - action: {action} - reward: {reward} (epsilon: {live_epsilon:.2f}) (frame time: {1 / live_render_fps:.2f}s)"
229
  )
230
 
231
- yield agent_type, env_name, rgb_array, policy_viz, ep_str(
232
  episode + 1
233
  ), ep_str(episodes_solved), step_str(
234
  step
@@ -245,37 +259,39 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
245
  time.sleep(1 / live_render_fps)
246
 
247
  while live_paused and live_steps_forward is None:
248
- yield agent_type, env_name, rgb_array, policy_viz, ep_str(
249
  episode + 1
250
  ), ep_str(episodes_solved), step_str(
251
  step
252
  ), state, action, last_reward, "Paused..."
253
  time.sleep(1 / live_render_fps)
254
- # if should_reset is True:
255
- # break
256
-
257
- # if should_reset is True:
258
- # should_reset = False
259
- # return (
260
- # agent_type,
261
- # env_name,
262
- # rgb_array,
263
- # policy_viz,
264
- # ep_str(episode + 1),
265
- # ep_str(episodes_solved),
266
- # step_str(step),
267
- # state,
268
- # action,
269
- # last_reward,
270
- # "Resetting...",
271
- # )
 
272
 
273
  if solved:
274
  episodes_solved += 1
275
 
276
- time.sleep(0.5)
277
 
278
- yield agent_type, env_name, rgb_array, policy_viz, ep_str(episode + 1), ep_str(
 
279
  episodes_solved
280
  ), step_str(step), state, action, reward, "Done!"
281
 
@@ -376,7 +392,7 @@ with gr.Blocks(title="CS581 Demo") as demo:
376
  label="Status Message",
377
  )
378
 
379
- # input_policy.change(fn=reset)
380
 
381
  btn_run.click(
382
  fn=run,
 
8
 
9
  default_n_test_episodes = 10
10
  default_max_steps = 500
11
+ default_render_fps = 5
12
+ default_epsilon = 0.0
13
+ default_paused = True
14
 
15
+ frame_env_h, frame_env_w = 256, 768
16
  frame_policy_w = 384
17
 
18
  # For the dropdown list of policies
 
53
  pause_val_map_inv = {v: k for k, v in pause_val_map.items()}
54
 
55
  # Global variables to allow changing it on the fly
56
+ is_running = False
57
+ live_render_fps = default_render_fps
58
+ live_epsilon = default_epsilon
59
+ live_paused = default_paused
60
  live_steps_forward = None
61
  should_reset = False
62
 
63
 
64
+ def reset():
65
+ global is_running, live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
66
+ if is_running:
67
+ should_reset = True
68
+ live_paused = default_paused
69
+ live_render_fps = default_render_fps
70
+ live_epsilon = default_epsilon
71
+ live_steps_forward = None
72
+ return gr.update(value=pause_val_map_inv[not live_paused]), gr.update(
73
+ interactive=live_paused
74
+ )
75
 
76
 
77
  def change_render_fps(x):
 
90
  print("Changing paused:", x)
91
  global live_paused
92
  live_paused = pause_val_map[x]
93
+ return gr.update(value=pause_val_map_inv[not live_paused]), gr.update(
94
+ interactive=live_paused
95
+ )
96
 
97
 
98
  def onclick_btn_forward():
 
104
 
105
 
106
  def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
107
+ global is_running, live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
108
+ is_running = True
109
  live_render_fps = render_fps
110
  live_epsilon = epsilon
111
  live_steps_forward = None
 
130
  agent.load_policy(policy_path)
131
  env_action_map = action_map.get(env_name)
132
 
133
+ solved, frame_env, frame_policy = None, None, None
134
  episode, step, state, action, reward, last_reward = (
135
  None,
136
  None,
 
150
  return f"{step + 1}"
151
 
152
  for episode in range(n_test_episodes):
153
+ time.sleep(0.25)
154
 
155
+ for step, (episode_hist, solved, frame_env) in enumerate(
156
  agent.generate_episode(
157
  max_steps=max_steps, render=True, epsilon_override=live_epsilon
158
  )
 
163
  state, action, reward = episode_hist[-1]
164
  curr_policy = agent.Pi[state]
165
 
166
+ frame_env = cv2.resize(
167
+ frame_env,
168
+ (
169
+ int(frame_env.shape[1] / frame_env.shape[0] * frame_env_h),
170
+ frame_env_h,
171
+ ),
172
+ interpolation=cv2.INTER_AREA,
173
+ )
174
+
175
+ if frame_env.shape[1] < frame_env_w:
176
+ rgb_array_new = np.pad(
177
+ frame_env,
178
+ (
179
+ (0, 0),
180
+ (
181
+ (frame_env_w - frame_env.shape[1]) // 2,
182
+ (frame_env_w - frame_env.shape[1]) // 2,
183
+ ),
184
+ (0, 0),
185
+ ),
186
+ "constant",
187
+ )
188
+ frame_env = np.uint8(rgb_array_new)
189
+
190
+ frame_policy_h = frame_policy_w // len(curr_policy)
191
+ frame_policy = np.zeros((frame_policy_h, frame_policy_w))
192
  for i, p in enumerate(curr_policy):
193
+ frame_policy[
194
  :,
195
  i
196
  * (frame_policy_w // len(curr_policy)) : (i + 1)
197
  * (frame_policy_w // len(curr_policy)),
198
  ] = p
199
 
200
+ frame_policy = scipy.ndimage.gaussian_filter(frame_policy, sigma=1.0)
201
+ frame_policy = np.clip(
202
+ frame_policy * (1.0 - live_epsilon) + live_epsilon / len(curr_policy),
203
  0.0,
204
  1.0,
205
  )
206
 
207
  cv2.putText(
208
+ frame_policy,
209
  str(action),
210
  (
211
  int((action + 0.5) * frame_policy_w // len(curr_policy) - 8),
212
+ frame_policy_h // 2 - 5,
213
  ),
214
  cv2.FONT_HERSHEY_SIMPLEX,
215
+ 0.8,
216
  1.0,
217
+ 1,
 
218
  cv2.LINE_AA,
219
  )
220
 
 
222
  action_name = env_action_map.get(action, "")
223
 
224
  cv2.putText(
225
+ frame_policy,
226
  action_name,
227
  (
228
  int(
229
  (action + 0.5) * frame_policy_w // len(curr_policy)
230
  - 5 * len(action_name)
231
  ),
232
+ frame_policy_h // 2 + 25,
233
  ),
234
  cv2.FONT_HERSHEY_SIMPLEX,
235
+ 0.5,
236
  1.0,
237
+ 1,
238
  cv2.LINE_AA,
239
  )
240
 
 
242
  f"Episode: {ep_str(episode + 1)} - step: {step_str(step)} - state: {state} - action: {action} - reward: {reward} (epsilon: {live_epsilon:.2f}) (frame time: {1 / live_render_fps:.2f}s)"
243
  )
244
 
245
+ yield agent_type, env_name, frame_env, frame_policy, ep_str(
246
  episode + 1
247
  ), ep_str(episodes_solved), step_str(
248
  step
 
259
  time.sleep(1 / live_render_fps)
260
 
261
  while live_paused and live_steps_forward is None:
262
+ yield agent_type, env_name, frame_env, frame_policy, ep_str(
263
  episode + 1
264
  ), ep_str(episodes_solved), step_str(
265
  step
266
  ), state, action, last_reward, "Paused..."
267
  time.sleep(1 / live_render_fps)
268
+ if should_reset is True:
269
+ break
270
+
271
+ if should_reset is True:
272
+ should_reset = False
273
+ yield (
274
+ agent_type,
275
+ env_name,
276
+ np.ones((frame_env_h, frame_env_w, 3)),
277
+ np.ones((frame_policy_h, frame_policy_w)),
278
+ ep_str(episode + 1),
279
+ ep_str(episodes_solved),
280
+ step_str(step),
281
+ state,
282
+ action,
283
+ last_reward,
284
+ "Reset...",
285
+ )
286
+ return
287
 
288
  if solved:
289
  episodes_solved += 1
290
 
291
+ time.sleep(0.25)
292
 
293
+ is_running = False
294
+ yield agent_type, env_name, frame_env, frame_policy, ep_str(episode + 1), ep_str(
295
  episodes_solved
296
  ), step_str(step), state, action, reward, "Done!"
297
 
 
392
  label="Status Message",
393
  )
394
 
395
+ input_policy.change(fn=reset, outputs=[btn_pause, btn_forward])
396
 
397
  btn_run.click(
398
  fn=run,