Andrei Cozma committed · Commit 45dcb54 · Parent(s): e9e7977

Updates
README.md CHANGED

@@ -5,8 +5,8 @@ colorFrom: yellow
 colorTo: orange
 sdk: gradio
 app_file: demo.py
+fullWidth: true
 pinned: true
-
 ---
 
 # CS581 Project - Reinforcement Learning: From Dynamic Programming to Monte-Carlo
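`fullWidth: true` is the Hugging Face Spaces front-matter option that lets the embedded Gradio app stretch to the full width of the page; the same hunk also drops a stray blank line before the closing `---`. Pieced together from the hunk, the resulting front matter should read roughly as below (the `title` line above `colorFrom` sits outside the diff and is not shown):

```yaml
colorFrom: yellow
colorTo: orange
sdk: gradio
app_file: demo.py
fullWidth: true
pinned: true
---
```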
demo.py CHANGED

@@ -8,8 +8,11 @@ import cv2
 
 default_n_test_episodes = 10
 default_max_steps = 500
+default_render_fps = 5
+default_epsilon = 0.0
+default_paused = True
 
-frame_env_h, frame_env_w = 256,
+frame_env_h, frame_env_w = 256, 768
 frame_policy_w = 384
 
 # For the dropdown list of policies
@@ -50,16 +53,25 @@ pause_val_map = {
 pause_val_map_inv = {v: k for k, v in pause_val_map.items()}
 
 # Global variables to allow changing it on the fly
-
-
-
+is_running = False
+live_render_fps = default_render_fps
+live_epsilon = default_epsilon
+live_paused = default_paused
 live_steps_forward = None
 should_reset = False
 
 
-
-
-
+def reset():
+    global is_running, live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
+    if is_running:
+        should_reset = True
+    live_paused = default_paused
+    live_render_fps = default_render_fps
+    live_epsilon = default_epsilon
+    live_steps_forward = None
+    return gr.update(value=pause_val_map_inv[not live_paused]), gr.update(
+        interactive=live_paused
+    )
 
 
 def change_render_fps(x):
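The new `reset()` handler cannot stop a `run` generator that is already streaming, so it only raises the `should_reset` flag (when `is_running` is set) and restores the live defaults; the generator polls the flag and exits on its own. A minimal, self-contained sketch of that cooperative-cancellation handshake (hypothetical names, not the app's code):

```python
import threading
import time

should_reset = False  # set by a UI callback, polled by the worker loop

def worker():
    global should_reset
    step = 0
    while True:
        time.sleep(0.01)  # stand-in for rendering one frame
        step += 1
        if should_reset:  # cooperative cancellation point
            should_reset = False
            print(f"reset acknowledged at step {step}")
            return

t = threading.Thread(target=worker)
t.start()
time.sleep(0.2)
should_reset = True  # request cancellation from the "UI" side
t.join()
```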
@@ -78,8 +90,9 @@ def change_paused(x):
     print("Changing paused:", x)
     global live_paused
    live_paused = pause_val_map[x]
-
-
+    return gr.update(value=pause_val_map_inv[not live_paused]), gr.update(
+        interactive=live_paused
+    )
 
 
 def onclick_btn_forward():
@@ -91,7 +104,8 @@ def onclick_btn_forward():
 
 
 def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
-    global live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
+    global is_running, live_render_fps, live_epsilon, live_paused, live_steps_forward, should_reset
+    is_running = True
     live_render_fps = render_fps
     live_epsilon = epsilon
     live_steps_forward = None
@@ -116,7 +130,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
     agent.load_policy(policy_path)
     env_action_map = action_map.get(env_name)
 
-    solved,
+    solved, frame_env, frame_policy = None, None, None
     episode, step, state, action, reward, last_reward = (
         None,
         None,
@@ -136,9 +150,9 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
             return f"{step + 1}"
 
     for episode in range(n_test_episodes):
-        time.sleep(0.
+        time.sleep(0.25)
 
-        for step, (episode_hist, solved,
+        for step, (episode_hist, solved, frame_env) in enumerate(
             agent.generate_episode(
                 max_steps=max_steps, render=True, epsilon_override=live_epsilon
             )
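The rewritten loop header unpacks each `(episode_hist, solved, frame_env)` tuple that `agent.generate_episode` yields and numbers the frames via `enumerate`, so `step` tracks the frame index. The same pattern in isolation (a toy generator standing in for the agent):

```python
def gen():
    # Stand-in for agent.generate_episode(...): yields one tuple per frame.
    for frame in ("f0", "f1", "f2"):
        yield ([("s", "a", 1.0)], False, frame)

for step, (episode_hist, solved, frame_env) in enumerate(gen()):
    print(step, frame_env)  # 0 f0 / 1 f1 / 2 f2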
@@ -149,58 +163,58 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
             state, action, reward = episode_hist[-1]
             curr_policy = agent.Pi[state]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            frame_env = cv2.resize(
+                frame_env,
+                (
+                    int(frame_env.shape[1] / frame_env.shape[0] * frame_env_h),
+                    frame_env_h,
+                ),
+                interpolation=cv2.INTER_AREA,
+            )
+
+            if frame_env.shape[1] < frame_env_w:
+                rgb_array_new = np.pad(
+                    frame_env,
+                    (
+                        (0, 0),
+                        (
+                            (frame_env_w - frame_env.shape[1]) // 2,
+                            (frame_env_w - frame_env.shape[1]) // 2,
+                        ),
+                        (0, 0),
+                    ),
+                    "constant",
+                )
+                frame_env = np.uint8(rgb_array_new)
+
+            frame_policy_h = frame_policy_w // len(curr_policy)
+            frame_policy = np.zeros((frame_policy_h, frame_policy_w))
             for i, p in enumerate(curr_policy):
-
+                frame_policy[
                     :,
                     i
                     * (frame_policy_w // len(curr_policy)) : (i + 1)
                     * (frame_policy_w // len(curr_policy)),
                 ] = p
 
-
-
-
+            frame_policy = scipy.ndimage.gaussian_filter(frame_policy, sigma=1.0)
+            frame_policy = np.clip(
+                frame_policy * (1.0 - live_epsilon) + live_epsilon / len(curr_policy),
                 0.0,
                 1.0,
             )
 
             cv2.putText(
-
+                frame_policy,
                 str(action),
                 (
                     int((action + 0.5) * frame_policy_w // len(curr_policy) - 8),
-
+                    frame_policy_h // 2 - 5,
                 ),
                 cv2.FONT_HERSHEY_SIMPLEX,
+                0.8,
                 1.0,
-                1
-                2,
+                1,
                 cv2.LINE_AA,
             )
 
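The added `np.clip(frame_policy * (1.0 - live_epsilon) + live_epsilon / len(curr_policy), 0.0, 1.0)` blends the policy toward a uniform distribution, which is exactly the epsilon-greedy action distribution p' = p·(1 − ε) + ε/n, so the visualization reacts to the live epsilon slider. A small standalone check of the arithmetic:

```python
import numpy as np

curr_policy = np.array([0.0, 1.0, 0.0, 0.0])  # greedy on action 1
eps = 0.2
blended = curr_policy * (1.0 - eps) + eps / len(curr_policy)
print(blended)        # [0.05 0.85 0.05 0.05]
print(blended.sum())  # 1.0 -- still a probability distribution
```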
@@ -208,19 +222,19 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
             action_name = env_action_map.get(action, "")
 
             cv2.putText(
-
+                frame_policy,
                 action_name,
                 (
                     int(
                         (action + 0.5) * frame_policy_w // len(curr_policy)
                         - 5 * len(action_name)
                     ),
-
+                    frame_policy_h // 2 + 25,
                 ),
                 cv2.FONT_HERSHEY_SIMPLEX,
-                0.
+                0.5,
                 1.0,
-
+                1,
                 cv2.LINE_AA,
             )
 
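For reading both `cv2.putText` hunks: the positional arguments are `(img, text, org, fontFace, fontScale, color, thickness, lineType)`. The additions pass `frame_policy` as the target image, lower the font scale to 0.8 / 0.5, keep color `1.0` (white, since `frame_policy` is a single-channel float image), and apparently thin the stroke from 2 to 1. An isolated call with the arguments annotated:

```python
import cv2
import numpy as np

frame = np.zeros((64, 256), dtype=np.float32)  # single-channel float canvas
cv2.putText(
    frame,                     # img
    "LEFT",                    # text
    (10, 40),                  # org: bottom-left corner of the text baseline
    cv2.FONT_HERSHEY_SIMPLEX,  # fontFace
    0.5,                       # fontScale
    1.0,                       # color (white on a 0..1 float image)
    1,                         # thickness
    cv2.LINE_AA,               # lineType
)
```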
@@ -228,7 +242,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
                 f"Episode: {ep_str(episode + 1)} - step: {step_str(step)} - state: {state} - action: {action} - reward: {reward} (epsilon: {live_epsilon:.2f}) (frame time: {1 / live_render_fps:.2f}s)"
             )
 
-            yield agent_type, env_name,
+            yield agent_type, env_name, frame_env, frame_policy, ep_str(
                 episode + 1
             ), ep_str(episodes_solved), step_str(
                 step
@@ -245,37 +259,39 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
                 time.sleep(1 / live_render_fps)
 
                 while live_paused and live_steps_forward is None:
-                    yield agent_type, env_name,
+                    yield agent_type, env_name, frame_env, frame_policy, ep_str(
                         episode + 1
                     ), ep_str(episodes_solved), step_str(
                         step
                     ), state, action, last_reward, "Paused..."
                     time.sleep(1 / live_render_fps)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    if should_reset is True:
+                        break
+
+            if should_reset is True:
+                should_reset = False
+                yield (
+                    agent_type,
+                    env_name,
+                    np.ones((frame_env_h, frame_env_w, 3)),
+                    np.ones((frame_policy_h, frame_policy_w)),
+                    ep_str(episode + 1),
+                    ep_str(episodes_solved),
+                    step_str(step),
+                    state,
+                    action,
+                    last_reward,
+                    "Reset...",
+                )
+                return
 
         if solved:
             episodes_solved += 1
 
-        time.sleep(0.
+        time.sleep(0.25)
 
-
+    is_running = False
+    yield agent_type, env_name, frame_env, frame_policy, ep_str(episode + 1), ep_str(
         episodes_solved
     ), step_str(step), state, action, reward, "Done!"
 
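`run` is a generator, so each `yield` streams a full tuple of values to the wired output components; the new `"Reset..."` branch emits blank frames and returns early, and the final yield now clears `is_running` first. Gradio supports generator callbacks for exactly this kind of streaming; a minimal sketch of the pattern (toy layout, not the app's actual UI):

```python
import time
import gradio as gr

def count_up(n):
    # Generator callback: every yield pushes a fresh value to the output.
    for i in range(int(n)):
        time.sleep(0.5)
        yield f"step {i + 1} / {int(n)}"

with gr.Blocks() as demo:
    n = gr.Number(value=5, label="steps")
    out = gr.Textbox(label="progress")
    gr.Button("Run").click(fn=count_up, inputs=n, outputs=out)

# demo.launch()
```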
@@ -376,7 +392,7 @@ with gr.Blocks(title="CS581 Demo") as demo:
                 label="Status Message",
             )
 
-
+            input_policy.change(fn=reset, outputs=[btn_pause, btn_forward])
 
             btn_run.click(
                 fn=run,
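Since `reset()` returns two `gr.update(...)` values, listing `[btn_pause, btn_forward]` as `outputs` applies them positionally: switching to a different policy file flips the pause button's label back to its default and toggles the step-forward button so it is interactive only while paused, mirroring what `change_paused` already does.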