Andrei Cozma committed · Commit 4567d2b
1 Parent(s): 8ceccef
Updates
MonteCarloAgent.py
CHANGED
@@ -43,36 +43,18 @@ class MonteCarloAgent:
         print(self.Pi)
         print("=" * 80)

-    def choose_action(self, state):
-        # Sample an action from the policy
-
-
-
-
-
-
-
-
-
-
-        # action = self.choose_action(state)
-        # # Take the action and observe the reward and next state
-        # next_state, reward, done, truncated, _ = self.env.step(action)
-        # # Keeping track of the trajectory
-        # episode_hist.append((state, action, reward))
-        # state = next_state
-
-        # # This is where the agent got to the goal.
-        # # In the case in which agent jumped off the cliff, it is simply respawned at the start position without termination.
-        # if done:
-        #     solved = True
-        #     break
-        # if truncated:
-        #     break
-
-        # rgb_array = self.env.render() if render else None
-
-        # return episode_hist, solved, rgb_array
+    def choose_action(self, state, override_epsilon=False, **kwargs):
+        # Sample an action from the policy.
+        # The override_epsilon argument allows forcing the use of a possibly new self.epsilon value than the one used during training.
+        # The ability to override was mostly added for testing purposes and for the demo.
+
+        if override_epsilon is False:
+            return np.random.choice(self.n_actions, p=self.Pi[state])
+
+        return np.random.choice(
+            [np.argmax(self.Pi[state]), np.random.randint(self.n_actions)],
+            p=[1 - self.epsilon, self.epsilon],
+        )

     def generate_episode(self, max_steps=500, render=False, **kwargs):
         state, _ = self.env.reset()
@@ -82,7 +64,7 @@ class MonteCarloAgent:
         for _ in range(max_steps):
             rgb_array = self.env.render() if render else None
             # Sample an action from the policy
-            action = self.choose_action(state)
+            action = self.choose_action(state, **kwargs)
             # Take the action and observe the reward and next state
             next_state, reward, done, truncated, _ = self.env.step(action)
             # Keeping track of the trajectory
@@ -319,7 +301,7 @@ def main():
     parser.add_argument(
         "--epsilon",
         type=float,
-        default=0.
+        default=0.1,
         help="The value for the epsilon-greedy policy to use. (default: 0.1)",
     )

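Note on the new choose_action: with override_epsilon left at False the agent keeps sampling from its stored epsilon-soft policy row Pi[state]; when the demo overrides it, the action is instead the greedy action with probability 1 - epsilon and a uniformly random action otherwise. A minimal standalone sketch of that sampling rule (the action count, epsilon, and policy row below are illustrative values, not taken from the repository):

import numpy as np

n_actions = 4                                       # hypothetical action count
epsilon = 0.1                                       # hypothetical live epsilon
Pi_state = np.array([0.925, 0.025, 0.025, 0.025])   # hypothetical epsilon-soft policy row

# Default path: sample an action index directly from the stored distribution.
action_default = np.random.choice(n_actions, p=Pi_state)

# Override path: greedy action with probability 1 - epsilon, random action otherwise.
action_override = np.random.choice(
    [np.argmax(Pi_state), np.random.randint(n_actions)],
    p=[1 - epsilon, epsilon],
)

print(action_default, action_override)
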
demo.py
CHANGED
@@ -1,13 +1,20 @@
 import os
 import time
+from matplotlib import interactive
 import numpy as np
 import gradio as gr
 from MonteCarloAgent import MonteCarloAgent
-
+import scipy.ndimage

 # For the dropdown list of policies
 policies_folder = "policies"
-
+try:
+    all_policies = [
+        file for file in os.listdir(policies_folder) if file.endswith(".npy")
+    ]
+except FileNotFoundError:
+    print("ERROR: No policies folder found!")
+    all_policies = []

 # All supported agents
 agent_map = {
@@ -15,45 +22,50 @@ agent_map = {
     # TODO: Add DP Agent
 }

-# Global variables
-
-
+# Global variables to allow changing it on the fly
+live_render_fps = 10
+live_epsilon = 0.0
+live_paused = False


-def
-    print("
-
-
-    policy_path = os.path.join(policies_folder, policy_fname)
-    props = policy_fname.split("_")
-    agent_type, env_name = props[0], props[1]
+def change_render_fps(x):
+    print("Changing render fps:", x)
+    global live_render_fps
+    live_render_fps = x

-    agent = agent_map[agent_type](env_name, render_mode="rgb_array")
-    agent.load_policy(policy_path)

-
+def change_epsilon(x):
+    print("Changing greediness:", x)
+    global live_epsilon
+    live_epsilon = x


-def
-    print("Changing
-    global
-
+def change_paused(x):
+    print("Changing paused:", x)
+    global live_paused
+    live_paused = x
+    # change the text to resume
+    return gr.update(value="▶️ Resume" if x else "⏸️ Pause")


-def run(n_test_episodes, max_steps, render_fps):
-    global
-
+def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
+    global live_render_fps, live_epsilon
+    live_render_fps = render_fps
+    live_epsilon = epsilon
     print("Running...")
     print(f"- n_test_episodes: {n_test_episodes}")
     print(f"- max_steps: {max_steps}")
-    print(f"- render_fps: {
+    print(f"- render_fps: {live_render_fps}")

-
-
-
-
+    policy_path = os.path.join(policies_folder, policy_fname)
+    props = policy_fname.split("_")
+    agent_type, env_name = props[0], props[1]
+
+    agent = agent_map[agent_type](env_name, render_mode="rgb_array")
+    agent.load_policy(policy_path)

-    rgb_array =
+    rgb_array = None
+    policy_viz = None
     episode, step = 0, 0
     state, action, reward = 0, 0, 0
     episodes_solved = 0
@@ -66,97 +78,162 @@ def run(n_test_episodes, max_steps, render_fps):

     for episode in range(n_test_episodes):
         for step, (episode_hist, solved, rgb_array) in enumerate(
-            agent.generate_episode(
+            agent.generate_episode(
+                max_steps=max_steps, render=True, override_epsilon=True
+            )
         ):
+            while live_paused:
+                time.sleep(0.1)
+
             if solved:
                 episodes_solved += 1
             state, action, reward = episode_hist[-1]
+            curr_policy = agent.Pi[state]
+
+            viz_w, viz_h = 128, 16
+            policy_viz = np.zeros((viz_h, viz_w))
+            for i, p in enumerate(curr_policy):
+                policy_viz[
+                    :,
+                    i
+                    * (viz_w // len(curr_policy)) : (i + 1)
+                    * (viz_w // len(curr_policy)),
+                ] = p
+
+            policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1)
+            policy_viz = np.clip(
+                policy_viz * (1 - live_epsilon) + live_epsilon / len(curr_policy), 0, 1
+            )

             print(
-                f"Episode: {ep_str(episode)} - step: {step_str} - state: {state} - action: {action} - reward: {reward} (frame time: {1 / render_fps:.2f}s)"
+                f"Episode: {ep_str(episode)} - step: {step_str(step)} - state: {state} - action: {action} - reward: {reward} (frame time: {1 / render_fps:.2f}s)"
             )

-            time.sleep(1 /
-
+            time.sleep(1 / live_render_fps)
+            # Live-update the agent's epsilon value for demonstration purposes
+            agent.epsilon = live_epsilon
+            yield agent_type, env_name, rgb_array, policy_viz, ep_str(episode), ep_str(
+                episodes_solved
+            ), step_str(step), state, action, reward, "Running..."

-    yield
+    yield agent_type, env_name, rgb_array, policy_viz, ep_str(episode), ep_str(
+        episodes_solved
+    ), step_str(step), state, action, reward, "Done!"


-with gr.Blocks() as demo:
-
+with gr.Blocks(title="CS581 Demo") as demo:
+    gr.components.HTML(
+        "<h1>Reinforcement Learning: From Dynamic Programming to Monte-Carlo (Demo)</h1>"
+    )
+    gr.components.HTML("<h3>Authors: Andrei Cozma and Landon Harris</h3>")

+    gr.components.HTML("<h2>Select Configuration:</h2>")
     with gr.Row():
-
-
-
-
+        input_policy = gr.components.Dropdown(
+            label="Policy Checkpoint",
+            choices=all_policies,
+            value=all_policies[0] if all_policies else "No policies found :(",
+        )

-
-
-        out_agent = gr.components.Textbox(label="Agent")
-
-        btn_load = gr.components.Button("📁 Load")
-        btn_load.click(
-            fn=load_policy,
-            inputs=[input_policy],
-            outputs=[out_environment, out_agent],
-        )
+        out_environment = gr.components.Textbox(label="Resolved Environment")
+        out_agent = gr.components.Textbox(label="Resolved Agent")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            value=5,
-            label="Render FPS",
-        )
-        input_render_fps.change(change_render_fps, inputs=[input_render_fps])
-
-    btn_run = gr.components.Button("▶️ Run")
+    with gr.Row():
+        input_n_test_episodes = gr.components.Slider(
+            minimum=1,
+            maximum=500,
+            value=500,
+            label="Number of episodes",
+        )
+        input_max_steps = gr.components.Slider(
+            minimum=1,
+            maximum=500,
+            value=500,
+            label="Max steps per episode",
+        )
+
+    btn_run = gr.components.Button(
+        "▶️ Start", interactive=True if all_policies else False
+    )

-
+    gr.components.HTML("<h2>Live Statistics & Policy Visualization:</h2>")
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                out_episode = gr.components.Textbox(label="Current Episode")
+                out_step = gr.components.Textbox(label="Current Step")
+                out_eps_solved = gr.components.Textbox(label="Episodes Solved")

+            with gr.Row():
+                out_state = gr.components.Textbox(label="Current State")
+                out_action = gr.components.Textbox(label="Chosen Action")
+                out_reward = gr.components.Textbox(label="Reward Received")
+
+            out_image_policy = gr.components.Image(
+                value=np.ones((16, 128)),
+                label="policy[state]",
+                type="numpy",
+                image_mode="RGB",
+            )
+
+    gr.components.HTML("<h2>Live Customization:</h2>")
     with gr.Row():
-
-
-
-
-
-
+        input_epsilon = gr.components.Slider(
+            minimum=0,
+            maximum=1,
+            value=live_epsilon,
+            label="Epsilon (0 = greedy, 1 = random)",
+        )
+        input_epsilon.change(change_epsilon, inputs=[input_epsilon])
+
+        input_render_fps = gr.components.Slider(
+            minimum=1, maximum=60, value=live_render_fps, label="Simulation speed (fps)"
+        )
+        input_render_fps.change(change_render_fps, inputs=[input_render_fps])
+
+        out_image_frame = gr.components.Image(
+            label="Environment", type="numpy", image_mode="RGB"
+        )

-
+    with gr.Row():
+        # Pause/resume button
+        btn_pause = gr.components.Button("⏸️ Pause", interactive=True)
+        btn_pause.click(
+            fn=change_paused,
+            inputs=[btn_pause],
+            outputs=[btn_pause],
+        )
+
+        out_msg = gr.components.Textbox(
+            value=""
+            if all_policies
+            else "<h2>🚫 ERROR: No policies found! Please train an agent first or add a policy to the policies folder.<h2>",
+            label="Status Message",
+        )

     btn_run.click(
         fn=run,
         inputs=[
+            input_policy,
             input_n_test_episodes,
             input_max_steps,
             input_render_fps,
+            input_epsilon,
         ],
         outputs=[
-
+            out_agent,
+            out_environment,
+            out_image_frame,
+            out_image_policy,
             out_episode,
+            out_eps_solved,
             out_step,
             out_state,
             out_action,
             out_reward,
-            out_eps_solved,
             out_msg,
         ],
     )

-
-demo.queue(concurrency_count=2)
+demo.queue(concurrency_count=3)
 demo.launch()
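The new run() generator also builds the policy[state] strip shown in the demo: each action's probability fills one vertical band of a 16x128 image, the band edges are blurred, and the image is blended toward a uniform distribution according to the live epsilon slider. A condensed sketch of that computation outside Gradio (the policy row and epsilon are illustrative values, not taken from the repository):

import numpy as np
import scipy.ndimage

curr_policy = np.array([0.05, 0.85, 0.05, 0.05])  # hypothetical policy row for one state
live_epsilon = 0.1                                # hypothetical slider value

viz_w, viz_h = 128, 16
policy_viz = np.zeros((viz_h, viz_w))
band = viz_w // len(curr_policy)
for i, p in enumerate(curr_policy):
    # One vertical band per action, filled with that action's probability.
    policy_viz[:, i * band : (i + 1) * band] = p

# Soften the band edges, then mix toward the uniform distribution according to
# the live epsilon and clip to [0, 1] so the array displays cleanly as an image.
policy_viz = scipy.ndimage.gaussian_filter(policy_viz, sigma=1)
policy_viz = np.clip(
    policy_viz * (1 - live_epsilon) + live_epsilon / len(curr_policy), 0, 1
)
print(policy_viz.shape, float(policy_viz.min()), float(policy_viz.max()))
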
policies/MonteCarloAgent_CliffWalking-v0_e2000_s500_g0.99_e0.1.npy
ADDED
Binary file (1.66 kB)

policies/MonteCarloAgent_CliffWalking-v0_e2000_s500_g0.99_e0.5.npy
DELETED
Binary file (1.66 kB)
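The checkpoint filenames double as metadata: demo.py only parses the first two underscore-separated fields (agent class and Gym environment id) when loading a policy, while the remaining fields presumably record the training settings (episodes, max steps, gamma, epsilon) and are not consumed by the loader. A small sketch of that parsing, using the newly added checkpoint name:

# Filename parsing as done in demo.py's run(); trailing fields are ignored here.
policy_fname = "MonteCarloAgent_CliffWalking-v0_e2000_s500_g0.99_e0.1.npy"
props = policy_fname.split("_")
agent_type, env_name = props[0], props[1]
print(agent_type, env_name)  # -> MonteCarloAgent CliffWalking-v0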