Spaces:
Sleeping
Sleeping
Andrei Cozma
committed on
Commit
·
e173b06
1
Parent(s):
3266489
Updates
Browse files
- agents.py +34 -4
- demo.py +6 -24
- policies/DPAgent_CliffWalking-v0_gamma:0.99_epsilon:0.4_e2500_s200.npy +0 -0
- policies/MCAgent_FrozenLake-v1_gamma:0.99_epsilon:0.2_size:8_seed:33951_e2500_s200_first_visit.npy +0 -0
- policies/MCAgent_FrozenLake-v1_gamma:0.99_epsilon:0.4_size:8_seed:16970_e2500_s200_first_visit.npy +0 -0
- run.py +3 -2
agents.py
CHANGED
@@ -1,14 +1,44 @@
|
|
1 |
# All supported agents
|
|
|
2 |
from MCAgent import MCAgent
|
3 |
from DPAgent import DPAgent
|
|
|
4 |
|
5 |
AGENTS_MAP = {"MCAgent": MCAgent, "DPAgent": DPAgent}
|
6 |
|
7 |
|
8 |
-
def load_agent(
|
9 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
raise ValueError(
|
11 |
-
f"ERROR: Agent '{
|
12 |
)
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
1 |
# All supported agents
|
2 |
+
import os
|
3 |
from MCAgent import MCAgent
|
4 |
from DPAgent import DPAgent
|
5 |
+
import warnings
|
6 |
|
7 |
# Registry used by load_agent: maps an agent name (passed directly, or taken
# from the leading "AgentName" part of a policy filename) to the class that
# is instantiated for it.
AGENTS_MAP = {"MCAgent": MCAgent, "DPAgent": DPAgent}
|
8 |
|
9 |
|
10 |
+
def _parse_policy_filename(policy_path):
    """Split a policy file name into ``(agent_key, env_key, agent_args)``.

    The basename must look like
    ``AgentName_EnvName_key:value_key:value....npy``. The file extension is
    removed *before* splitting so that it never leaks into the environment
    name or into the value of the last ``key:value`` property.

    Args:
        policy_path: Path (or bare file name) of a saved policy file.

    Returns:
        A tuple ``(agent_key, env_key, agent_args)`` where ``agent_args`` maps
        each ``key`` to its ``value`` exactly as written in the file name
        (values are left as strings).

    Raises:
        ValueError: If the name has fewer than the two mandatory
            ``AgentName_EnvName`` components.
    """
    stem, _ext = os.path.splitext(os.path.basename(policy_path))
    props = stem.split("_")
    if len(props) < 2:
        raise ValueError(
            "ERROR: Could not parse agent properties. Must be of the format 'AgentName_EnvName_key:value_key:value...'."
        )

    agent_key, env_key = props[0], props[1]
    agent_args = {}
    for prop in props[2:]:
        pieces = prop.split(":")
        if len(pieces) == 2:
            agent_args[pieces[0]] = pieces[1]
        else:
            # Name components such as "e2500" or "s200" carry no key and are
            # not constructor arguments; report and ignore them.
            warnings.warn(
                f"Skipping property {prop} as it does not have the format 'key:value'.",
                UserWarning,
            )
    return agent_key, env_key, agent_args


def load_agent(agent_key, **kwargs):
    """Instantiate an agent by name, or restore one from a saved policy file.

    Args:
        agent_key: Either a key of ``AGENTS_MAP`` or a path ending in
            ``.npy``. For a ``.npy`` path the agent name, the environment
            name and any ``key:value`` properties are parsed from the file
            name, and the policy stored in the file is loaded into the new
            agent.
        **kwargs: Keyword arguments forwarded to the agent constructor. When
            loading from a policy file, properties parsed from the file name
            (and ``env``) override entries with the same key.
            NOTE(review): parsed values are passed on as strings; presumably
            the agent constructors convert them -- confirm.

    Returns:
        The constructed agent (with its policy loaded when a ``.npy`` path
        was given).

    Raises:
        ValueError: If the policy file name cannot be parsed, or the
            resulting agent name is not a key of ``AGENTS_MAP``.
    """
    agent_policy_file = agent_key if agent_key.endswith(".npy") else None

    if agent_policy_file is not None:
        agent_key, env_key, agent_args = _parse_policy_filename(agent_policy_file)
        agent_args["env"] = env_key
        kwargs.update(agent_args)
        print("agent_args:", kwargs)

    if agent_key not in AGENTS_MAP:
        raise ValueError(
            f"ERROR: Agent '{agent_key}' not valid. Must be one of: {AGENTS_MAP.keys()}"
        )

    agent = AGENTS_MAP[agent_key](**kwargs)
    if agent_policy_file is not None:
        agent.load_policy(agent_policy_file)

    return agent
|
demo.py
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
import os
|
2 |
import time
|
3 |
-
import warnings
|
4 |
import numpy as np
|
5 |
import gradio as gr
|
6 |
|
7 |
import scipy.ndimage
|
8 |
import cv2
|
9 |
|
10 |
-
from agents import
|
11 |
|
12 |
default_n_test_episodes = 10
|
13 |
default_max_steps = 500
|
@@ -137,33 +136,16 @@ def run(
|
|
137 |
print(f"- epsilon: {localstate.live_steps_forward}")
|
138 |
|
139 |
policy_path = os.path.join(policies_folder, policy_fname)
|
140 |
-
props = policy_fname.split("_")
|
141 |
|
142 |
try:
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
if len(props_split) == 2:
|
148 |
-
agent_args[props_split[0]] = props_split[1]
|
149 |
-
else:
|
150 |
-
warnings.warn(
|
151 |
-
f"Skipping property {prop} as it does not have the format 'key:value'.",
|
152 |
-
UserWarning,
|
153 |
-
)
|
154 |
-
except IndexError:
|
155 |
yield localstate, None, None, None, None, None, None, None, None, None, None, "🚫 Please select a valid policy file."
|
156 |
return
|
157 |
|
158 |
-
|
159 |
-
{
|
160 |
-
"env": env_key,
|
161 |
-
"render_mode": "rgb_array",
|
162 |
-
}
|
163 |
-
)
|
164 |
-
print("agent_args:", agent_args)
|
165 |
-
agent = AGENTS_MAP[agent_key](**agent_args)
|
166 |
-
agent.load_policy(policy_path)
|
167 |
env_action_map = action_map.get(env_key)
|
168 |
|
169 |
solved, frame_env, frame_policy = None, None, None
|
|
|
1 |
import os
|
2 |
import time
|
|
|
3 |
import numpy as np
|
4 |
import gradio as gr
|
5 |
|
6 |
import scipy.ndimage
|
7 |
import cv2
|
8 |
|
9 |
+
from agents import load_agent
|
10 |
|
11 |
default_n_test_episodes = 10
|
12 |
default_max_steps = 500
|
|
|
136 |
print(f"- epsilon: {localstate.live_steps_forward}")
|
137 |
|
138 |
policy_path = os.path.join(policies_folder, policy_fname)
|
|
|
139 |
|
140 |
try:
|
141 |
+
agent = load_agent(
|
142 |
+
policy_path, return_agent_env_keys=True, render_mode="rgb_array"
|
143 |
+
)
|
144 |
+
except ValueError:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
yield localstate, None, None, None, None, None, None, None, None, None, None, "🚫 Please select a valid policy file."
|
146 |
return
|
147 |
|
148 |
+
agent_key, env_key = agent.__class__.__name__, agent.env_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
env_action_map = action_map.get(env_key)
|
150 |
|
151 |
solved, frame_env, frame_policy = None, None, None
|
policies/DPAgent_CliffWalking-v0_gamma:0.99_epsilon:0.4_e2500_s200.npy
CHANGED
Binary files a/policies/DPAgent_CliffWalking-v0_gamma:0.99_epsilon:0.4_e2500_s200.npy and b/policies/DPAgent_CliffWalking-v0_gamma:0.99_epsilon:0.4_e2500_s200.npy differ
|
|
policies/MCAgent_FrozenLake-v1_gamma:0.99_epsilon:0.2_size:8_seed:33951_e2500_s200_first_visit.npy
DELETED
Binary file (2.18 kB)
|
|
policies/MCAgent_FrozenLake-v1_gamma:0.99_epsilon:0.4_size:8_seed:16970_e2500_s200_first_visit.npy
DELETED
Binary file (2.18 kB)
|
|
run.py
CHANGED
@@ -138,7 +138,9 @@ def main():
|
|
138 |
args = parser.parse_args()
|
139 |
print(vars(args))
|
140 |
|
141 |
-
agent = load_agent(
|
|
|
|
|
142 |
|
143 |
agent.run_name += f"_e{args.n_train_episodes}_s{args.max_steps}"
|
144 |
if args.wandb_run_name_suffix is not None:
|
@@ -169,7 +171,6 @@ def main():
|
|
169 |
if not args.no_save:
|
170 |
agent.save_policy(save_dir=args.save_dir)
|
171 |
elif args.test is not None:
|
172 |
-
agent.load_policy(args.test)
|
173 |
agent.test(
|
174 |
n_test_episodes=args.n_test_episodes,
|
175 |
max_steps=args.max_steps,
|
|
|
138 |
args = parser.parse_args()
|
139 |
print(vars(args))
|
140 |
|
141 |
+
agent = load_agent(
|
142 |
+
args.agent if args.test is None else args.test, **dict(args._get_kwargs())
|
143 |
+
)
|
144 |
|
145 |
agent.run_name += f"_e{args.n_train_episodes}_s{args.max_steps}"
|
146 |
if args.wandb_run_name_suffix is not None:
|
|
|
171 |
if not args.no_save:
|
172 |
agent.save_policy(save_dir=args.save_dir)
|
173 |
elif args.test is not None:
|
|
|
174 |
agent.test(
|
175 |
n_test_episodes=args.n_test_episodes,
|
176 |
max_steps=args.max_steps,
|