Spaces:
Sleeping
Sleeping
Andrei Cozma
committed on
Commit
·
1ac9ba4
1
Parent(s):
18d81a3
Updates
Browse files
demo.py
CHANGED
@@ -12,8 +12,8 @@ default_render_fps = 5
|
|
12 |
default_epsilon = 0.0
|
13 |
default_paused = True
|
14 |
|
15 |
-
frame_env_h, frame_env_w =
|
16 |
-
|
17 |
|
18 |
# For the dropdown list of policies
|
19 |
policies_folder = "policies"
|
@@ -163,38 +163,38 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
|
|
163 |
state, action, reward = episode_hist[-1]
|
164 |
curr_policy = agent.Pi[state]
|
165 |
|
166 |
-
frame_env = cv2.resize(
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
)
|
174 |
-
|
175 |
-
if frame_env.shape[1] < frame_env_w:
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
frame_policy_h =
|
191 |
-
frame_policy = np.zeros((frame_policy_h,
|
192 |
for i, p in enumerate(curr_policy):
|
193 |
frame_policy[
|
194 |
:,
|
195 |
i
|
196 |
-
* (
|
197 |
-
* (
|
198 |
] = p
|
199 |
|
200 |
frame_policy = scipy.ndimage.gaussian_filter(frame_policy, sigma=1.0)
|
@@ -208,7 +208,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
|
|
208 |
frame_policy,
|
209 |
str(action),
|
210 |
(
|
211 |
-
int((action + 0.5) *
|
212 |
frame_policy_h // 2 - 5,
|
213 |
),
|
214 |
cv2.FONT_HERSHEY_SIMPLEX,
|
@@ -226,7 +226,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
|
|
226 |
action_name,
|
227 |
(
|
228 |
int(
|
229 |
-
(action + 0.5) *
|
230 |
- 5 * len(action_name)
|
231 |
),
|
232 |
frame_policy_h // 2 + 25,
|
@@ -274,7 +274,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
|
|
274 |
agent_type,
|
275 |
env_name,
|
276 |
np.ones((frame_env_h, frame_env_w, 3)),
|
277 |
-
np.ones((frame_policy_h,
|
278 |
ep_str(episode + 1),
|
279 |
ep_str(episodes_solved),
|
280 |
step_str(step),
|
@@ -345,9 +345,8 @@ with gr.Blocks(title="CS581 Demo") as demo:
|
|
345 |
label="Action Sampled vs Policy Distribution for Current State",
|
346 |
type="numpy",
|
347 |
image_mode="RGB",
|
348 |
-
value=np.ones((16, frame_policy_w)),
|
349 |
-
shape=(16, frame_policy_w),
|
350 |
)
|
|
|
351 |
|
352 |
with gr.Row():
|
353 |
input_epsilon = gr.components.Slider(
|
@@ -367,9 +366,8 @@ with gr.Blocks(title="CS581 Demo") as demo:
|
|
367 |
label="Environment",
|
368 |
type="numpy",
|
369 |
image_mode="RGB",
|
370 |
-
value=np.ones((frame_env_h, frame_env_w, 3)),
|
371 |
-
shape=(frame_env_h, frame_env_w),
|
372 |
)
|
|
|
373 |
|
374 |
with gr.Row():
|
375 |
btn_pause = gr.components.Button(
|
|
|
12 |
default_epsilon = 0.0
|
13 |
default_paused = True
|
14 |
|
15 |
+
frame_env_h, frame_env_w = 512, 768
|
16 |
+
frame_policy_res = 384
|
17 |
|
18 |
# For the dropdown list of policies
|
19 |
policies_folder = "policies"
|
|
|
163 |
state, action, reward = episode_hist[-1]
|
164 |
curr_policy = agent.Pi[state]
|
165 |
|
166 |
+
# frame_env = cv2.resize(
|
167 |
+
# frame_env,
|
168 |
+
# (
|
169 |
+
# int(frame_env.shape[1] / frame_env.shape[0] * frame_env_h),
|
170 |
+
# frame_env_h,
|
171 |
+
# ),
|
172 |
+
# interpolation=cv2.INTER_AREA,
|
173 |
+
# )
|
174 |
+
|
175 |
+
# if frame_env.shape[1] < frame_env_w:
|
176 |
+
# rgb_array_new = np.pad(
|
177 |
+
# frame_env,
|
178 |
+
# (
|
179 |
+
# (0, 0),
|
180 |
+
# (
|
181 |
+
# (frame_env_w - frame_env.shape[1]) // 2,
|
182 |
+
# (frame_env_w - frame_env.shape[1]) // 2,
|
183 |
+
# ),
|
184 |
+
# (0, 0),
|
185 |
+
# ),
|
186 |
+
# "constant",
|
187 |
+
# )
|
188 |
+
# frame_env = np.uint8(rgb_array_new)
|
189 |
+
|
190 |
+
frame_policy_h = frame_policy_res // len(curr_policy)
|
191 |
+
frame_policy = np.zeros((frame_policy_h, frame_policy_res))
|
192 |
for i, p in enumerate(curr_policy):
|
193 |
frame_policy[
|
194 |
:,
|
195 |
i
|
196 |
+
* (frame_policy_res // len(curr_policy)) : (i + 1)
|
197 |
+
* (frame_policy_res // len(curr_policy)),
|
198 |
] = p
|
199 |
|
200 |
frame_policy = scipy.ndimage.gaussian_filter(frame_policy, sigma=1.0)
|
|
|
208 |
frame_policy,
|
209 |
str(action),
|
210 |
(
|
211 |
+
int((action + 0.5) * frame_policy_res // len(curr_policy) - 8),
|
212 |
frame_policy_h // 2 - 5,
|
213 |
),
|
214 |
cv2.FONT_HERSHEY_SIMPLEX,
|
|
|
226 |
action_name,
|
227 |
(
|
228 |
int(
|
229 |
+
(action + 0.5) * frame_policy_res // len(curr_policy)
|
230 |
- 5 * len(action_name)
|
231 |
),
|
232 |
frame_policy_h // 2 + 25,
|
|
|
274 |
agent_type,
|
275 |
env_name,
|
276 |
np.ones((frame_env_h, frame_env_w, 3)),
|
277 |
+
np.ones((frame_policy_h, frame_policy_res)),
|
278 |
ep_str(episode + 1),
|
279 |
ep_str(episodes_solved),
|
280 |
step_str(step),
|
|
|
345 |
label="Action Sampled vs Policy Distribution for Current State",
|
346 |
type="numpy",
|
347 |
image_mode="RGB",
|
|
|
|
|
348 |
)
|
349 |
+
out_image_policy.style(height=200)
|
350 |
|
351 |
with gr.Row():
|
352 |
input_epsilon = gr.components.Slider(
|
|
|
366 |
label="Environment",
|
367 |
type="numpy",
|
368 |
image_mode="RGB",
|
|
|
|
|
369 |
)
|
370 |
+
out_image_frame.style(height=frame_env_h)
|
371 |
|
372 |
with gr.Row():
|
373 |
btn_pause = gr.components.Button(
|