Spaces:
Sleeping
Sleeping
Andrei Cozma
commited on
Commit
路
69d9811
1
Parent(s):
4567d2b
Updates
Browse files
demo.py
CHANGED
@@ -5,6 +5,7 @@ import numpy as np
|
|
5 |
import gradio as gr
|
6 |
from MonteCarloAgent import MonteCarloAgent
|
7 |
import scipy.ndimage
|
|
|
8 |
|
9 |
# For the dropdown list of policies
|
10 |
policies_folder = "policies"
|
@@ -90,7 +91,8 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
|
|
90 |
state, action, reward = episode_hist[-1]
|
91 |
curr_policy = agent.Pi[state]
|
92 |
|
93 |
-
viz_w
|
|
|
94 |
policy_viz = np.zeros((viz_h, viz_w))
|
95 |
for i, p in enumerate(curr_policy):
|
96 |
policy_viz[
|
@@ -100,6 +102,22 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
|
|
100 |
* (viz_w // len(curr_policy)),
|
101 |
] = p
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1)
|
104 |
policy_viz = np.clip(
|
105 |
policy_viz * (1 - live_epsilon) + live_epsilon / len(curr_policy), 0, 1
|
@@ -152,9 +170,7 @@ with gr.Blocks(title="CS581 Demo") as demo:
|
|
152 |
label="Max steps per episode",
|
153 |
)
|
154 |
|
155 |
-
btn_run = gr.components.Button(
|
156 |
-
"鈻讹笍 Start", interactive=True if all_policies else False
|
157 |
-
)
|
158 |
|
159 |
gr.components.HTML("<h2>Live Statistics & Policy Visualization:</h2>")
|
160 |
with gr.Row():
|
@@ -196,7 +212,6 @@ with gr.Blocks(title="CS581 Demo") as demo:
|
|
196 |
)
|
197 |
|
198 |
with gr.Row():
|
199 |
-
# Pause/resume button
|
200 |
btn_pause = gr.components.Button("鈴革笍 Pause", interactive=True)
|
201 |
btn_pause.click(
|
202 |
fn=change_paused,
|
|
|
5 |
import gradio as gr
|
6 |
from MonteCarloAgent import MonteCarloAgent
|
7 |
import scipy.ndimage
|
8 |
+
import cv2
|
9 |
|
10 |
# For the dropdown list of policies
|
11 |
policies_folder = "policies"
|
|
|
91 |
state, action, reward = episode_hist[-1]
|
92 |
curr_policy = agent.Pi[state]
|
93 |
|
94 |
+
viz_w = 512
|
95 |
+
viz_h = viz_w // len(curr_policy)
|
96 |
policy_viz = np.zeros((viz_h, viz_w))
|
97 |
for i, p in enumerate(curr_policy):
|
98 |
policy_viz[
|
|
|
102 |
* (viz_w // len(curr_policy)),
|
103 |
] = p
|
104 |
|
105 |
+
policy_viz = np.stack([policy_viz] * 3, axis=-1)
|
106 |
+
text_offset = 15
|
107 |
+
cv2.putText(
|
108 |
+
policy_viz,
|
109 |
+
str(action),
|
110 |
+
(
|
111 |
+
int((action + 0.5) * viz_w // len(curr_policy) - text_offset),
|
112 |
+
viz_h // 2 + text_offset,
|
113 |
+
),
|
114 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
115 |
+
1.5,
|
116 |
+
(255, 255, 255),
|
117 |
+
1,
|
118 |
+
cv2.LINE_AA,
|
119 |
+
)
|
120 |
+
|
121 |
policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1)
|
122 |
policy_viz = np.clip(
|
123 |
policy_viz * (1 - live_epsilon) + live_epsilon / len(curr_policy), 0, 1
|
|
|
170 |
label="Max steps per episode",
|
171 |
)
|
172 |
|
173 |
+
btn_run = gr.components.Button("鈻讹笍 Start", interactive=bool(all_policies))
|
|
|
|
|
174 |
|
175 |
gr.components.HTML("<h2>Live Statistics & Policy Visualization:</h2>")
|
176 |
with gr.Row():
|
|
|
212 |
)
|
213 |
|
214 |
with gr.Row():
|
|
|
215 |
btn_pause = gr.components.Button("鈴革笍 Pause", interactive=True)
|
216 |
btn_pause.click(
|
217 |
fn=change_paused,
|