Andrei Cozma committed · Commit 668f525 · 1 Parent(s): 49c1e2a
Updates
Shared.py
CHANGED
@@ -2,10 +2,11 @@ import os
 import numpy as np
 import gymnasium as gym
 
+
 class Shared:
-
     def __init__(
-        self
+        self,
+        /,
         env="CliffWalking-v0",
         gamma=0.99,
         epsilon=0.1,
@@ -20,8 +21,9 @@ class Shared:
         self.run_name = run_name
         self.env_name = env
         self.epsilon, self.gamma = epsilon, gamma
+        self.epsilon_override = None
 
-        self.env_kwargs = {k:v for k,v in kwargs.items() if k in [
+        self.env_kwargs = {k: v for k, v in kwargs.items() if k in ["render_mode"]}
         if self.env_name == "FrozenLake-v1":
             # Can use defaults by defining map_name (4x4 or 8x8) or custom map by defining desc
             # self.env_kwargs["map_name"] = "8x8"
@@ -46,23 +48,24 @@ class Shared:
         print(f"- n_states: {self.n_states}")
         print(f"- n_actions: {self.n_actions}")
 
-    def choose_action(self, state,
+    def choose_action(self, state, greedy=False, **kwargs):
         # Sample an action from the policy.
         # The epsilon_override argument allows forcing the use of a new epsilon value than the one previously used during training.
         # The ability to override was mostly added for testing purposes and for the demo.
         greedy_action = np.argmax(self.Pi[state])
 
-        if greedy or epsilon_override == 0:
+        if greedy or self.epsilon_override == 0.0:
            return greedy_action
 
-        if epsilon_override is None:
+        if self.epsilon_override is None:
            return np.random.choice(self.n_actions, p=self.Pi[state])
 
+        print("epsilon_override", self.epsilon_override)
         return np.random.choice(
             [greedy_action, np.random.randint(self.n_actions)],
-            p=[1 - epsilon_override, epsilon_override],
+            p=[1.0 - self.epsilon_override, self.epsilon_override],
         )
-
+
     def generate_episode(self, max_steps=500, render=False, **kwargs):
         state, _ = self.env.reset()
         episode_hist, solved, rgb_array = (
@@ -118,9 +121,9 @@ class Shared:
 
     def run_episode(self, max_steps=500, render=False, **kwargs):
         # Run the generator until the end
-        episode_hist, solved, rgb_array = list(
-            max_steps, render, **kwargs
-        )
+        episode_hist, solved, rgb_array = list(
+            self.generate_episode(max_steps, render, **kwargs)
+        )[-1]
         return episode_hist, solved, rgb_array
 
     def test(self, n_test_episodes=100, verbose=True, greedy=True, **kwargs):
@@ -143,7 +146,7 @@ class Shared:
             f"Agent reached the goal in {num_successes}/{n_test_episodes} episodes ({success_rate * 100:.2f}%)"
         )
         return success_rate
-
+
     def save_policy(self, fname="policy.npy", save_dir=None):
         if save_dir is not None:
             os.makedirs(save_dir, exist_ok=True)
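The substantive change in Shared.py is that the exploration override moves from a choose_action() argument to an epsilon_override attribute on the agent, initialized to None and consulted on every action; since the attribute is re-read on each call, it can be changed while an episode generator is still running. A minimal sketch of the resulting sampling logic, assuming a row-stochastic policy table Pi of shape (n_states, n_actions) (the class and variable names here are illustrative, not the repo's):

import numpy as np

class AgentSketch:
    def __init__(self, Pi):
        self.Pi = np.asarray(Pi)            # Pi[s] is a probability distribution over actions
        self.n_actions = self.Pi.shape[1]
        self.epsilon_override = None        # None: sample from Pi; float in [0, 1]: epsilon-greedy

    def choose_action(self, state, greedy=False):
        greedy_action = np.argmax(self.Pi[state])
        if greedy or self.epsilon_override == 0.0:
            return greedy_action            # deterministic, fully greedy
        if self.epsilon_override is None:
            # No override: sample straight from the trained policy distribution.
            return np.random.choice(self.n_actions, p=self.Pi[state])
        # Override: greedy action with prob. 1 - eps, uniform random action with prob. eps.
        return np.random.choice(
            [greedy_action, np.random.randint(self.n_actions)],
            p=[1.0 - self.epsilon_override, self.epsilon_override],
        )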
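run_episode() now drains the generate_episode generator with list(...)[-1], keeping only the final yielded (episode_hist, solved, rgb_array) tuple. If memory were a concern, the same "last yield" could be taken without materializing every intermediate tuple; a sketch of that alternative (not what the repo does):

from collections import deque

def last_yield(gen):
    # Consume the generator, retaining only its final yielded value.
    return deque(gen, maxlen=1).pop()

# e.g. episode_hist, solved, rgb_array = last_yield(agent.generate_episode(500, False))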
demo.py
CHANGED
@@ -81,12 +81,24 @@ def change_render_fps(state, x):
     return state
 
 
+def change_render_fps_update(state, x):
+    print("Changing render fps:", x)
+    state.live_render_fps = x
+    return state, gr.update(value=x)
+
+
 def change_epsilon(state, x):
     print("Changing greediness:", x)
     state.live_epsilon = x
     return state
 
 
+def change_epsilon_update(state, x):
+    print("Changing greediness:", x)
+    state.live_epsilon = x
+    return state, gr.update(value=x)
+
+
 def change_paused(state, x):
     print("Changing paused:", x)
     state.live_paused = pause_val_map[x]
@@ -159,9 +171,9 @@ def run(
         agent.generate_episode(
             max_steps=max_steps,
             render=True,
-            epsilon_override=localstate.live_epsilon,
         )
     ):
+        agent.epsilon_override = localstate.live_epsilon
         _, _, last_reward = (
             episode_hist[-2] if len(episode_hist) > 1 else (None, None, None)
         )
@@ -207,7 +219,7 @@
             str(action),
             (
                 label_loc_w - label_width // 2,
-
+                frame_policy_h // 3 + label_height // 2,
             ),
             frame_policy_label_font,
             action_text_scale,
@@ -230,9 +242,7 @@
             action_name,
             (
                 int(label_loc_w - label_width / 2),
-                frame_policy_h
-                - (frame_policy_h - label_loc_h) // 2
-                + label_height // 2,
+                frame_policy_h - frame_policy_h // 3 + label_height // 2,
             ),
             frame_policy_label_font,
             action_text_label_scale,
@@ -363,7 +373,14 @@ with gr.Blocks(title="CS581 Demo") as demo:
         label="Epsilon (0 = greedy, 1 = random)",
     )
     input_epsilon.change(
-        change_epsilon,
+        change_epsilon,
+        inputs=[localstate, input_epsilon],
+        outputs=[localstate],
+    )
+    input_epsilon.release(
+        change_epsilon_update,
+        inputs=[localstate, input_epsilon],
+        outputs=[localstate, input_epsilon],
     )
 
     input_render_fps = gr.components.Slider(
@@ -378,6 +395,11 @@ with gr.Blocks(title="CS581 Demo") as demo:
         inputs=[localstate, input_render_fps],
         outputs=[localstate],
    )
+    input_render_fps.release(
+        change_render_fps_update,
+        inputs=[localstate, input_render_fps],
+        outputs=[localstate, input_render_fps],
+    )
 
     out_image_frame = gr.components.Image(
         label="Environment",
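In demo.py, the epsilon override is now applied inside the per-frame loop (agent.epsilon_override = localstate.live_epsilon) instead of being passed once to generate_episode(), so slider movements take effect mid-episode. Each slider also gets two listeners: .change() keeps the live session state in sync while the handle is being dragged, and .release() fires once at the end of the drag and additionally echoes the final value back to the component via gr.update. A minimal self-contained sketch of that wiring (the LiveState class here is an assumed stand-in for the demo's session state):

import gradio as gr

class LiveState:
    live_epsilon = 0.1  # assumed stand-in for the demo's per-session state

def change_epsilon(state, x):
    print("Changing greediness:", x)
    state.live_epsilon = x                  # fires continuously during a drag
    return state

def change_epsilon_update(state, x):
    state.live_epsilon = x                  # fires once, on slider release
    return state, gr.update(value=x)        # also push the value back to the slider

with gr.Blocks() as demo:
    localstate = gr.State(LiveState())
    input_epsilon = gr.components.Slider(
        0.0, 1.0, value=0.1, label="Epsilon (0 = greedy, 1 = random)"
    )
    input_epsilon.change(
        change_epsilon, inputs=[localstate, input_epsilon], outputs=[localstate]
    )
    input_epsilon.release(
        change_epsilon_update,
        inputs=[localstate, input_epsilon],
        outputs=[localstate, input_epsilon],
    )

if __name__ == "__main__":
    demo.launch()

Splitting the handlers this way keeps the cheap state write on the high-frequency change event, while the component round-trip happens only once per interaction.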