Andrei Cozma committed
Commit 1ac9ba4
1 Parent(s): 18d81a3
Files changed (1)
  1. demo.py +35 -37
demo.py CHANGED
@@ -12,8 +12,8 @@ default_render_fps = 5
 default_epsilon = 0.0
 default_paused = True
 
-frame_env_h, frame_env_w = 256, 768
-frame_policy_w = 384
+frame_env_h, frame_env_w = 512, 768
+frame_policy_res = 384
 
 # For the dropdown list of policies
 policies_folder = "policies"
@@ -163,38 +163,38 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
         state, action, reward = episode_hist[-1]
         curr_policy = agent.Pi[state]
 
-        frame_env = cv2.resize(
-            frame_env,
-            (
-                int(frame_env.shape[1] / frame_env.shape[0] * frame_env_h),
-                frame_env_h,
-            ),
-            interpolation=cv2.INTER_AREA,
-        )
-
-        if frame_env.shape[1] < frame_env_w:
-            rgb_array_new = np.pad(
-                frame_env,
-                (
-                    (0, 0),
-                    (
-                        (frame_env_w - frame_env.shape[1]) // 2,
-                        (frame_env_w - frame_env.shape[1]) // 2,
-                    ),
-                    (0, 0),
-                ),
-                "constant",
-            )
-            frame_env = np.uint8(rgb_array_new)
-
-        frame_policy_h = frame_policy_w // len(curr_policy)
-        frame_policy = np.zeros((frame_policy_h, frame_policy_w))
+        # frame_env = cv2.resize(
+        #     frame_env,
+        #     (
+        #         int(frame_env.shape[1] / frame_env.shape[0] * frame_env_h),
+        #         frame_env_h,
+        #     ),
+        #     interpolation=cv2.INTER_AREA,
+        # )
+
+        # if frame_env.shape[1] < frame_env_w:
+        #     rgb_array_new = np.pad(
+        #         frame_env,
+        #         (
+        #             (0, 0),
+        #             (
+        #                 (frame_env_w - frame_env.shape[1]) // 2,
+        #                 (frame_env_w - frame_env.shape[1]) // 2,
+        #             ),
+        #             (0, 0),
+        #         ),
+        #         "constant",
+        #     )
+        #     frame_env = np.uint8(rgb_array_new)
+
+        frame_policy_h = frame_policy_res // len(curr_policy)
+        frame_policy = np.zeros((frame_policy_h, frame_policy_res))
         for i, p in enumerate(curr_policy):
             frame_policy[
                 :,
                 i
-                * (frame_policy_w // len(curr_policy)) : (i + 1)
-                * (frame_policy_w // len(curr_policy)),
+                * (frame_policy_res // len(curr_policy)) : (i + 1)
+                * (frame_policy_res // len(curr_policy)),
             ] = p
 
         frame_policy = scipy.ndimage.gaussian_filter(frame_policy, sigma=1.0)
@@ -208,7 +208,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
             frame_policy,
             str(action),
             (
-                int((action + 0.5) * frame_policy_w // len(curr_policy) - 8),
+                int((action + 0.5) * frame_policy_res // len(curr_policy) - 8),
                 frame_policy_h // 2 - 5,
             ),
             cv2.FONT_HERSHEY_SIMPLEX,
@@ -226,7 +226,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
             action_name,
             (
                 int(
-                    (action + 0.5) * frame_policy_w // len(curr_policy)
+                    (action + 0.5) * frame_policy_res // len(curr_policy)
                     - 5 * len(action_name)
                 ),
                 frame_policy_h // 2 + 25,
@@ -274,7 +274,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
             agent_type,
             env_name,
             np.ones((frame_env_h, frame_env_w, 3)),
-            np.ones((frame_policy_h, frame_policy_w)),
+            np.ones((frame_policy_h, frame_policy_res)),
             ep_str(episode + 1),
             ep_str(episodes_solved),
             step_str(step),
@@ -345,9 +345,8 @@ with gr.Blocks(title="CS581 Demo") as demo:
             label="Action Sampled vs Policy Distribution for Current State",
            type="numpy",
            image_mode="RGB",
-            value=np.ones((16, frame_policy_w)),
-            shape=(16, frame_policy_w),
         )
+        out_image_policy.style(height=200)
 
     with gr.Row():
         input_epsilon = gr.components.Slider(
@@ -367,9 +366,8 @@ with gr.Blocks(title="CS581 Demo") as demo:
            label="Environment",
            type="numpy",
            image_mode="RGB",
-            value=np.ones((frame_env_h, frame_env_w, 3)),
-            shape=(frame_env_h, frame_env_w),
         )
+        out_image_frame.style(height=frame_env_h)
 
     with gr.Row():
         btn_pause = gr.components.Button(
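
For reference, a minimal self-contained sketch of the policy-strip rendering that this commit switches from frame_policy_w to the single frame_policy_res constant. The band slicing, the strip height, and the Gaussian smoothing mirror demo.py; the toy action distribution and the band_w name are hypothetical stand-ins for agent.Pi[state] and the repeated frame_policy_res // len(curr_policy) expression.

    import numpy as np
    import scipy.ndimage

    frame_policy_res = 384                        # same constant the commit introduces
    curr_policy = np.array([0.1, 0.7, 0.1, 0.1])  # hypothetical stand-in for agent.Pi[state]

    band_w = frame_policy_res // len(curr_policy)  # horizontal band per action
    frame_policy_h = band_w                        # strip height, as in demo.py
    frame_policy = np.zeros((frame_policy_h, frame_policy_res))
    for i, p in enumerate(curr_policy):
        frame_policy[:, i * band_w : (i + 1) * band_w] = p  # paint band i with P(action i)

    # soften the band edges before the action labels are drawn on top
    frame_policy = scipy.ndimage.gaussian_filter(frame_policy, sigma=1.0)
    print(frame_policy.shape)  # (96, 384)

The remaining changes drop the fixed value=/shape= placeholders on the two gr.components.Image outputs in favour of .style(height=...) calls (200 px for the policy strip, frame_env_h for the environment frame), comment out the manual cv2.resize/np.pad letterboxing, and raise frame_env_h from 256 to 512, presumably so the rendered frames keep their native aspect ratio.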