Andrei Cozma committed · Commit 7357801
Parent(s): df12910
Updates
demo.py
CHANGED
@@ -9,6 +9,9 @@ import cv2
 default_n_test_episodes = 10
 default_max_steps = 500
 
+frame_env_h, frame_env_w = 256, 512
+frame_policy_w = 384
+
 # For the dropdown list of policies
 policies_folder = "policies"
 try:
@@ -146,24 +149,23 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
             state, action, reward = episode_hist[-1]
             curr_policy = agent.Pi[state]
 
-            rgb_array_height, rgb_array_width = 512, 768
             rgb_array = cv2.resize(
                 rgb_array,
                 (
-                    int(rgb_array.shape[1] / rgb_array.shape[0] * rgb_array_height),
-                    rgb_array_height,
+                    int(rgb_array.shape[1] / rgb_array.shape[0] * frame_env_h),
+                    frame_env_h,
                 ),
                 interpolation=cv2.INTER_AREA,
             )
 
-            if rgb_array.shape[1] < rgb_array_width:
+            if rgb_array.shape[1] < frame_env_w:
                 rgb_array_new = np.pad(
                     rgb_array,
                     (
                         (0, 0),
                         (
-                            (rgb_array_width - rgb_array.shape[1]) // 2,
-                            (rgb_array_width - rgb_array.shape[1]) // 2,
+                            (frame_env_w - rgb_array.shape[1]) // 2,
+                            (frame_env_w - rgb_array.shape[1]) // 2,
                         ),
                         (0, 0),
                     ),
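The block above resizes the rendered frame to a fixed height and then letterboxes it to `frame_env_w` with `np.pad`. A minimal standalone sketch of the same idea, assuming a white fill value and a made-up input frame (neither is specified by the diff):

```python
import cv2
import numpy as np

frame_env_h, frame_env_w = 256, 512  # target canvas size, as in the new constants


def letterbox(rgb_array: np.ndarray) -> np.ndarray:
    """Resize to the target height, then pad left/right up to the target width."""
    new_w = int(rgb_array.shape[1] / rgb_array.shape[0] * frame_env_h)
    rgb_array = cv2.resize(rgb_array, (new_w, frame_env_h), interpolation=cv2.INTER_AREA)
    if rgb_array.shape[1] < frame_env_w:
        pad = frame_env_w - rgb_array.shape[1]
        # pad // 2 on both sides (as in the diff) drops one pixel when pad is odd,
        # so this sketch gives the remainder to the right edge instead
        rgb_array = np.pad(
            rgb_array,
            ((0, 0), (pad // 2, pad - pad // 2), (0, 0)),
            constant_values=255,  # assumed fill value
        )
    return np.uint8(rgb_array)


frame = np.random.randint(0, 255, (400, 600, 3), dtype=np.uint8)  # placeholder frame
print(letterbox(frame).shape)  # (256, 512, 3)
```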
@@ -171,15 +173,14 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
                 )
                 rgb_array = np.uint8(rgb_array_new)
 
-
-
-            policy_viz = np.zeros((viz_h, viz_w))
+            viz_h = frame_policy_w // len(curr_policy)
+            policy_viz = np.zeros((viz_h, frame_policy_w))
             for i, p in enumerate(curr_policy):
                 policy_viz[
                     :,
                     i
-                    * (viz_w // len(curr_policy)) : (i + 1)
-                    * (viz_w // len(curr_policy)),
+                    * (frame_policy_w // len(curr_policy)) : (i + 1)
+                    * (frame_policy_w // len(curr_policy)),
                 ] = p
 
             policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1.0)
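In the new code each action's probability is painted into an equal-width block of a `viz_h x frame_policy_w` strip, which is then blurred before the action labels are drawn on top. A rough standalone version, with an illustrative policy vector:

```python
import numpy as np
import scipy.ndimage

frame_policy_w = 384                           # strip width, as in the new constants
curr_policy = np.array([0.1, 0.7, 0.1, 0.1])   # made-up action distribution

block_w = frame_policy_w // len(curr_policy)
viz_h = block_w                                # one square block per action
policy_viz = np.zeros((viz_h, frame_policy_w))
for i, p in enumerate(curr_policy):
    policy_viz[:, i * block_w : (i + 1) * block_w] = p

# Soften the hard block edges, as the demo does before drawing labels
policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1.0)
print(policy_viz.shape)  # (96, 384)
```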
@@ -193,8 +194,8 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
                 policy_viz,
                 str(action),
                 (
-                    int((action + 0.5) *
-                    viz_h // 2
+                    int((action + 0.5) * frame_policy_w // len(curr_policy) - 8),
+                    viz_h // 2,
                 ),
                 cv2.FONT_HERSHEY_SIMPLEX,
                 1.0,
@@ -211,10 +212,10 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
                 action_name,
                 (
                     int(
-                        (action + 0.5) *
+                        (action + 0.5) * frame_policy_w // len(curr_policy)
                         - 5 * len(action_name)
                     ),
-                    viz_h // 2 +
+                    viz_h // 2 + 30,
                 ),
                 cv2.FONT_HERSHEY_SIMPLEX,
                 0.6,
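Both `cv2.putText` calls above center a label on its action block by subtracting a fixed pixel offset (8 px for the action index, 5 px per character for the action name). If exact centering is wanted, `cv2.getTextSize` can measure the rendered text instead; this is an alternative sketch, not what the commit does:

```python
import cv2
import numpy as np


def put_centered_text(img, text, center_x, center_y,
                      font=cv2.FONT_HERSHEY_SIMPLEX, scale=0.6, thickness=1):
    """Draw `text` so its bounding box is centered on (center_x, center_y)."""
    (w, h), _baseline = cv2.getTextSize(text, font, scale, thickness)
    # putText's origin is the bottom-left corner of the text
    org = (int(center_x - w / 2), int(center_y + h / 2))
    cv2.putText(img, text, org, font, scale, (255, 255, 255), thickness)
    return img


canvas = np.zeros((96, 384, 3), dtype=np.uint8)  # same size as the policy strip above
put_centered_text(canvas, "LEFT", center_x=48, center_y=48)
```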
@@ -325,8 +326,7 @@ with gr.Blocks(title="CS581 Demo") as demo:
         out_reward = gr.components.Textbox(label="Last Reward")
 
         out_image_policy = gr.components.Image(
-
-            # shape=(16, 128),
+            value=np.ones((16, 128)),
             label="Action Sampled vs Policy Distribution for Current State",
             type="numpy",
             image_mode="RGB",
@@ -347,7 +347,10 @@ with gr.Blocks(title="CS581 Demo") as demo:
         input_render_fps.change(change_render_fps, inputs=[input_render_fps])
 
         out_image_frame = gr.components.Image(
-            label="Environment",
+            label="Environment",
+            type="numpy",
+            image_mode="RGB",
+            value=np.ones((frame_env_h, frame_env_w, 3)),
         )
 
         with gr.Row():
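Both Image components now start out with a `value=np.ones(...)` placeholder, so the layout shows a correctly sized blank frame before the first episode is rendered. A minimal standalone example of the same pattern (only the component is shown; the rest of the app is omitted):

```python
import gradio as gr
import numpy as np

frame_env_h, frame_env_w = 256, 512

with gr.Blocks(title="CS581 Demo") as demo:
    # Blank placeholder until the first rendered environment frame arrives
    out_image_frame = gr.components.Image(
        label="Environment",
        type="numpy",
        image_mode="RGB",
        value=np.ones((frame_env_h, frame_env_w, 3)),
    )

demo.launch()
```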