Andrei Cozma committed on
Commit
7357801
·
1 Parent(s): df12910
Files changed (1) hide show
  1. demo.py +21 -18
demo.py CHANGED
@@ -9,6 +9,9 @@ import cv2
9
  default_n_test_episodes = 10
10
  default_max_steps = 500
11
 
 
 
 
12
  # For the dropdown list of policies
13
  policies_folder = "policies"
14
  try:
@@ -146,24 +149,23 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
146
  state, action, reward = episode_hist[-1]
147
  curr_policy = agent.Pi[state]
148
 
149
- rgb_array_height, rgb_array_width = 512, 768
150
  rgb_array = cv2.resize(
151
  rgb_array,
152
  (
153
- int(rgb_array.shape[1] / rgb_array.shape[0] * rgb_array_height),
154
- rgb_array_height,
155
  ),
156
  interpolation=cv2.INTER_AREA,
157
  )
158
 
159
- if rgb_array.shape[1] < rgb_array_width:
160
  rgb_array_new = np.pad(
161
  rgb_array,
162
  (
163
  (0, 0),
164
  (
165
- (rgb_array_width - rgb_array.shape[1]) // 2,
166
- (rgb_array_width - rgb_array.shape[1]) // 2,
167
  ),
168
  (0, 0),
169
  ),
@@ -171,15 +173,14 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
171
  )
172
  rgb_array = np.uint8(rgb_array_new)
173
 
174
- viz_w = 384
175
- viz_h = viz_w // len(curr_policy)
176
- policy_viz = np.zeros((viz_h, viz_w))
177
  for i, p in enumerate(curr_policy):
178
  policy_viz[
179
  :,
180
  i
181
- * (viz_w // len(curr_policy)) : (i + 1)
182
- * (viz_w // len(curr_policy)),
183
  ] = p
184
 
185
  policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1.0)
@@ -193,8 +194,8 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
193
  policy_viz,
194
  str(action),
195
  (
196
- int((action + 0.5) * viz_w // len(curr_policy) - 8),
197
- viz_h // 2 - 10,
198
  ),
199
  cv2.FONT_HERSHEY_SIMPLEX,
200
  1.0,
@@ -211,10 +212,10 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
211
  action_name,
212
  (
213
  int(
214
- (action + 0.5) * viz_w // len(curr_policy)
215
  - 5 * len(action_name)
216
  ),
217
- viz_h // 2 + 20,
218
  ),
219
  cv2.FONT_HERSHEY_SIMPLEX,
220
  0.6,
@@ -325,8 +326,7 @@ with gr.Blocks(title="CS581 Demo") as demo:
325
  out_reward = gr.components.Textbox(label="Last Reward")
326
 
327
  out_image_policy = gr.components.Image(
328
- # value=np.ones((16, 128)),
329
- # shape=(16, 128),
330
  label="Action Sampled vs Policy Distribution for Current State",
331
  type="numpy",
332
  image_mode="RGB",
@@ -347,7 +347,10 @@ with gr.Blocks(title="CS581 Demo") as demo:
347
  input_render_fps.change(change_render_fps, inputs=[input_render_fps])
348
 
349
  out_image_frame = gr.components.Image(
350
- label="Environment", type="numpy", image_mode="RGB", shape=(512, 768)
 
 
 
351
  )
352
 
353
  with gr.Row():
 
9
  default_n_test_episodes = 10
10
  default_max_steps = 500
11
 
12
+ frame_env_h, frame_env_w = 256, 512
13
+ frame_policy_w = 384
14
+
15
  # For the dropdown list of policies
16
  policies_folder = "policies"
17
  try:
 
149
  state, action, reward = episode_hist[-1]
150
  curr_policy = agent.Pi[state]
151
 
 
152
  rgb_array = cv2.resize(
153
  rgb_array,
154
  (
155
+ int(rgb_array.shape[1] / rgb_array.shape[0] * frame_env_h),
156
+ frame_env_h,
157
  ),
158
  interpolation=cv2.INTER_AREA,
159
  )
160
 
161
+ if rgb_array.shape[1] < frame_env_w:
162
  rgb_array_new = np.pad(
163
  rgb_array,
164
  (
165
  (0, 0),
166
  (
167
+ (frame_env_w - rgb_array.shape[1]) // 2,
168
+ (frame_env_w - rgb_array.shape[1]) // 2,
169
  ),
170
  (0, 0),
171
  ),
 
173
  )
174
  rgb_array = np.uint8(rgb_array_new)
175
 
176
+ viz_h = frame_policy_w // len(curr_policy)
177
+ policy_viz = np.zeros((viz_h, frame_policy_w))
 
178
  for i, p in enumerate(curr_policy):
179
  policy_viz[
180
  :,
181
  i
182
+ * (frame_policy_w // len(curr_policy)) : (i + 1)
183
+ * (frame_policy_w // len(curr_policy)),
184
  ] = p
185
 
186
  policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1.0)
 
194
  policy_viz,
195
  str(action),
196
  (
197
+ int((action + 0.5) * frame_policy_w // len(curr_policy) - 8),
198
+ viz_h // 2,
199
  ),
200
  cv2.FONT_HERSHEY_SIMPLEX,
201
  1.0,
 
212
  action_name,
213
  (
214
  int(
215
+ (action + 0.5) * frame_policy_w // len(curr_policy)
216
  - 5 * len(action_name)
217
  ),
218
+ viz_h // 2 + 30,
219
  ),
220
  cv2.FONT_HERSHEY_SIMPLEX,
221
  0.6,
 
326
  out_reward = gr.components.Textbox(label="Last Reward")
327
 
328
  out_image_policy = gr.components.Image(
329
+ value=np.ones((16, 128)),
 
330
  label="Action Sampled vs Policy Distribution for Current State",
331
  type="numpy",
332
  image_mode="RGB",
 
347
  input_render_fps.change(change_render_fps, inputs=[input_render_fps])
348
 
349
  out_image_frame = gr.components.Image(
350
+ label="Environment",
351
+ type="numpy",
352
+ image_mode="RGB",
353
+ value=np.ones((frame_env_h, frame_env_w, 3)),
354
  )
355
 
356
  with gr.Row():