Commit e24c7c0 · Andrei Cozma · Updates
Parent(s): ed9cf21
README.md
CHANGED
@@ -8,68 +8,76 @@ Evolution of Reinforcement Learning methods from pure Dynamic Programming-based
 
 - Python 3
 - Gymnasium: <https://pypi.org/project/gymnasium/>
 - WandB: <https://pypi.org/project/wandb/>
+- Gradio: <https://pypi.org/project/gradio/>
 
+## Interactive Demo
+
+TODO
 
+## Dynamic-Programming Agent
 
+TODO
 
+### Usage
 
 ```bash
+TODO
 ```
 
+## Monte-Carlo Agent
 
+The agent starts with a randomly initialized epsilon-greedy policy and uses either the first-visit or every-visit Monte-Carlo update method to learn the optimal policy.
 
+Primarily tested on the [Cliff Walking](https://gymnasium.farama.org/environments/toy_text/cliff_walking/) toy environment.
 
 ```bash
+# Training: Policy will be saved as a `.npy` file.
+python3 MonteCarloAgent.py --train
 
+# Testing: Use the `--test` flag with the path to the policy file.
 python3 MonteCarloAgent.py --test policy_mc_CliffWalking-v0_e2000_s500_g0.99_e0.1.npy --render_mode human
 ```
 
+### Usage
 
-```
-usage: MonteCarloAgent.py [-h] [--train] [--test TEST] [--n_train_episodes N_TRAIN_EPISODES] [--n_test_episodes N_TEST_EPISODES] [--test_every TEST_EVERY] [--max_steps MAX_STEPS] [--
-                          [--render_mode RENDER_MODE] [--wandb_project WANDB_PROJECT] [--wandb_group WANDB_GROUP]
+```bash
+usage: MonteCarloAgent.py [-h] [--train] [--test TEST] [--n_train_episodes N_TRAIN_EPISODES] [--n_test_episodes N_TEST_EPISODES] [--test_every TEST_EVERY] [--max_steps MAX_STEPS] [--update_type {first_visit,every_visit}]
+                          [--save_dir SAVE_DIR] [--no_save] [--gamma GAMMA] [--epsilon EPSILON] [--env ENV] [--render_mode RENDER_MODE] [--wandb_project WANDB_PROJECT] [--wandb_group WANDB_GROUP]
+                          [--wandb_job_type WANDB_JOB_TYPE] [--wandb_run_name_suffix WANDB_RUN_NAME_SUFFIX]
 
 options:
   -h, --help            show this help message and exit
   --train               Use this flag to train the agent.
   --test TEST           Use this flag to test the agent. Provide the path to the policy file.
   --n_train_episodes N_TRAIN_EPISODES
-                        The number of episodes to train for.
+                        The number of episodes to train for. (default: 2000)
   --n_test_episodes N_TEST_EPISODES
-                        The number of episodes to test for.
+                        The number of episodes to test for. (default: 100)
   --test_every TEST_EVERY
-                        During training, test the agent every n episodes.
+                        During training, test the agent every n episodes. (default: 100)
   --max_steps MAX_STEPS
-                        The maximum number of steps per episode before the episode is forced to end.
+                        The maximum number of steps per episode before the episode is forced to end. (default: 500)
+  --update_type {first_visit,every_visit}
+                        The type of update to use. (default: first_visit)
+  --save_dir SAVE_DIR   The directory to save the policy to. (default: policies)
+  --no_save             Use this flag to disable saving the policy.
+  --gamma GAMMA         The value for the discount factor to use. (default: 0.99)
+  --epsilon EPSILON     The value for the epsilon-greedy policy to use. (default: 0.1)
+  --env ENV             The Gymnasium environment to use. (default: CliffWalking-v0)
   --render_mode RENDER_MODE
+                        Render mode passed to the gym.make() function. Use 'human' to render the environment. (default: None)
   --wandb_project WANDB_PROJECT
-                        WandB project name for logging. If not provided, no logging is done.
+                        WandB project name for logging. If not provided, no logging is done. (default: None)
   --wandb_group WANDB_GROUP
                         WandB group name for logging. (default: monte-carlo)
   --wandb_job_type WANDB_JOB_TYPE
                         WandB job type for logging. (default: train)
+  --wandb_run_name_suffix WANDB_RUN_NAME_SUFFIX
+                        WandB run name suffix for logging. (default: None)
 ```
 
 ## Presentation Guide
 
 1. Title Slide: list the title of your talk along with your name
 
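For context on the first-visit vs. every-visit distinction the updated README describes, here is a minimal sketch of the tabular Monte-Carlo update, assuming `Q` and `counts` are `(n_states, n_actions)` arrays and `episode_hist` is a list of `(state, action, reward)` tuples as in demo.py; this is an illustration, not the repository's actual `MonteCarloAgent` code.

```python
import numpy as np

def mc_update(Q, counts, episode_hist, gamma=0.99, update_type="first_visit"):
    """Update tabular action-values Q in place from one finished episode."""
    G = 0.0
    # Walk the episode backwards, accumulating the discounted return G.
    for t in reversed(range(len(episode_hist))):
        state, action, reward = episode_hist[t]
        G = gamma * G + reward
        # First-visit: only update at the earliest occurrence of (state, action).
        is_first = (state, action) not in [(s, a) for s, a, _ in episode_hist[:t]]
        if update_type == "every_visit" or is_first:
            counts[state, action] += 1
            # Incremental average of all returns observed for this pair.
            Q[state, action] += (G - Q[state, action]) / counts[state, action]

# Example on CliffWalking-v0's 48 states x 4 actions:
Q, counts = np.zeros((48, 4)), np.zeros((48, 4))
mc_update(Q, counts, [(36, 0, -1.0), (24, 1, -1.0), (25, 1, -1.0)])
```

The epsilon-greedy policy is then rebuilt from `Q` after each episode: act greedily with probability `1 - epsilon` (the `--epsilon 0.1` default above), uniformly at random otherwise.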
demo.py
CHANGED
@@ -1,6 +1,5 @@
 import os
 import time
-from matplotlib import interactive
 import numpy as np
 import gradio as gr
 from MonteCarloAgent import MonteCarloAgent
@@ -35,9 +34,10 @@ action_map = {
 }
 
 # Global variables to allow changing it on the fly
-live_render_fps =
+live_render_fps = 5
 live_epsilon = 0.0
 live_paused = False
+live_steps_forward = None
 
 
 def change_render_fps(x):
@@ -54,23 +54,44 @@ def change_epsilon(x):
 
 def change_paused(x):
     print("Changing paused:", x)
+    val_map = {
+        "▶️ Resume": False,
+        "⏸️ Pause": True,
+    }
+    val_map_inv = {v: k for k, v in val_map.items()}
     global live_paused
-    live_paused = x
-
-    return gr.update(value=
+    live_paused = val_map[x]
+    next_val = val_map_inv[not live_paused]
+    return gr.update(value=next_val), gr.update(interactive=live_paused)
+
+
+def onclick_btn_forward():
+    print("Step forward")
+    global live_steps_forward
+    if live_steps_forward is None:
+        live_steps_forward = 0
+    live_steps_forward += 1
 
 
 def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
-    global live_render_fps, live_epsilon
+    global live_render_fps, live_epsilon, live_paused, live_steps_forward
     live_render_fps = render_fps
     live_epsilon = epsilon
+    print("=" * 80)
     print("Running...")
+    print(f"- policy_fname: {policy_fname}")
     print(f"- n_test_episodes: {n_test_episodes}")
     print(f"- max_steps: {max_steps}")
     print(f"- render_fps: {live_render_fps}")
+    print(f"- epsilon: {live_epsilon}")
 
     policy_path = os.path.join(policies_folder, policy_fname)
     props = policy_fname.split("_")
+
+    if len(props) < 2:
+        yield None, None, None, None, None, None, None, None, None, None, "🚫 Please select a valid policy file."
+        return
+
     agent_type, env_name = props[0], props[1]
 
     agent = agent_map[agent_type](env_name, render_mode="rgb_array")
@@ -82,7 +103,9 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
     episodes_solved = 0
 
     def ep_str(episode):
-        return
+        return (
+            f"{episode} / {n_test_episodes} ({(episode) / n_test_episodes * 100:.2f}%)"
+        )
 
     def step_str(step):
         return f"{step + 1}"
@@ -93,8 +116,13 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
             max_steps=max_steps, render=True, override_epsilon=True
         )
     ):
+        if live_steps_forward is not None:
+            if live_steps_forward > 0:
+                live_steps_forward -= 1
+
+            if live_steps_forward == 0:
+                live_steps_forward = None
+                live_paused = True
 
         state, action, reward = episode_hist[-1]
         curr_policy = agent.Pi[state]
@@ -165,6 +193,14 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
 
         time.sleep(1 / live_render_fps)
 
+        while live_paused and live_steps_forward is None:
+            yield agent_type, env_name, rgb_array, policy_viz, ep_str(
+                episode + 1
+            ), ep_str(episodes_solved), step_str(
+                step
+            ), state, action, reward, "Paused..."
+            time.sleep(1 / live_render_fps)
+
         if solved:
             episodes_solved += 1
 
@@ -247,16 +283,22 @@ with gr.Blocks(title="CS581 Demo") as demo:
 
     with gr.Row():
         btn_pause = gr.components.Button("⏸️ Pause", interactive=True)
+        btn_forward = gr.components.Button("⏩ Step", interactive=False)
+
     btn_pause.click(
         fn=change_paused,
        inputs=[btn_pause],
-        outputs=[btn_pause],
+        outputs=[btn_pause, btn_forward],
+    )
+
+    btn_forward.click(
+        fn=onclick_btn_forward,
     )
 
     out_msg = gr.components.Textbox(
         value=""
         if all_policies
-        else "
+        else "ERROR: No policies found! Please train an agent first or add a policy to the policies folder.",
         label="Status Message",
     )
 
@@ -284,5 +326,5 @@ with gr.Blocks(title="CS581 Demo") as demo:
         ],
     )
 
-demo.queue(concurrency_count=
+demo.queue(concurrency_count=2)
 demo.launch()
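The pause/step-forward machinery added above reduces to one pattern: `run()` is a generator whose yields Gradio streams to the UI (hence the queued app), and the buttons flip module-level flags that the generator polls between frames. Below is a stripped-down sketch of that pattern, assuming a Gradio 3.x install where `queue(concurrency_count=...)` is available; the widget layout is illustrative, not the actual demo.

```python
import time
import gradio as gr

live_paused = False        # flipped by the Pause/Resume button
live_steps_forward = None  # number of single steps queued while paused

def change_paused(label):
    # The button's current label carries the state, as in demo.py.
    global live_paused
    live_paused = label == "⏸️ Pause"
    return gr.update(value="▶️ Resume" if live_paused else "⏸️ Pause")

def step_forward():
    global live_steps_forward
    live_steps_forward = (live_steps_forward or 0) + 1

def run():
    global live_paused, live_steps_forward
    for step in range(50):
        # Consume one queued single-step, then fall back to paused.
        if live_steps_forward is not None:
            live_steps_forward -= 1
            if live_steps_forward == 0:
                live_steps_forward = None
                live_paused = True
        while live_paused and live_steps_forward is None:
            time.sleep(0.1)  # busy-wait until resumed or stepped
        yield f"step {step}"  # streamed live because the app is queued
        time.sleep(0.2)

with gr.Blocks() as demo:
    out = gr.Textbox(label="Step")
    with gr.Row():
        btn_run = gr.Button("Run")
        btn_pause = gr.Button("⏸️ Pause")
        btn_step = gr.Button("⏩ Step")
    btn_run.click(fn=run, outputs=[out])
    btn_pause.click(fn=change_paused, inputs=[btn_pause], outputs=[btn_pause])
    btn_step.click(fn=step_forward)

# concurrency_count=2 mirrors the commit: without a second queue worker, the
# pause click would wait behind the still-running generator.
demo.queue(concurrency_count=2)
demo.launch()
```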