Andrei Cozma commited on
Commit
e173b06
·
1 Parent(s): 3266489
agents.py CHANGED
@@ -1,14 +1,44 @@
1
  # All supported agents
 
2
  from MCAgent import MCAgent
3
  from DPAgent import DPAgent
 
4
 
5
  AGENTS_MAP = {"MCAgent": MCAgent, "DPAgent": DPAgent}
6
 
7
 
8
- def load_agent(agent_name, **kwargs):
9
- if agent_name not in AGENTS_MAP:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  raise ValueError(
11
- f"ERROR: Agent '{agent_name}' not valid. Must be one of: {AGENTS_MAP.keys()}"
12
  )
13
 
14
- return AGENTS_MAP[agent_name](**kwargs)
 
 
 
 
 
1
  # All supported agents
2
+ import os
3
  from MCAgent import MCAgent
4
  from DPAgent import DPAgent
5
+ import warnings
6
 
7
  AGENTS_MAP = {"MCAgent": MCAgent, "DPAgent": DPAgent}
8
 
9
 
10
+ def load_agent(agent_key, **kwargs):
11
+ agent_policy_file = agent_key if agent_key.endswith(".npy") else None
12
+ if agent_policy_file is not None:
13
+ props = os.path.basename(agent_key).split("_")
14
+ try:
15
+ agent_key, env_key = props[0], props[1]
16
+ agent_args = {}
17
+ for prop in props[2:]:
18
+ props_split = prop.split(":")
19
+ if len(props_split) == 2:
20
+ agent_args[props_split[0]] = props_split[1]
21
+ else:
22
+ warnings.warn(
23
+ f"Skipping property {prop} as it does not have the format 'key:value'.",
24
+ UserWarning,
25
+ )
26
+
27
+ agent_args["env"] = env_key
28
+ kwargs.update(agent_args)
29
+ print("agent_args:", kwargs)
30
+ except IndexError:
31
+ raise ValueError(
32
+ f"ERROR: Could not parse agent properties. Must be of the format 'AgentName_EnvName_key:value_key:value...'."
33
+ )
34
+
35
+ if agent_key not in AGENTS_MAP:
36
  raise ValueError(
37
+ f"ERROR: Agent '{agent_key}' not valid. Must be one of: {AGENTS_MAP.keys()}"
38
  )
39
 
40
+ agent = AGENTS_MAP[agent_key](**kwargs)
41
+ if agent_policy_file is not None:
42
+ agent.load_policy(agent_policy_file)
43
+
44
+ return agent
demo.py CHANGED
@@ -1,13 +1,12 @@
1
  import os
2
  import time
3
- import warnings
4
  import numpy as np
5
  import gradio as gr
6
 
7
  import scipy.ndimage
8
  import cv2
9
 
10
- from agents import AGENTS_MAP
11
 
12
  default_n_test_episodes = 10
13
  default_max_steps = 500
@@ -137,33 +136,16 @@ def run(
137
  print(f"- epsilon: {localstate.live_steps_forward}")
138
 
139
  policy_path = os.path.join(policies_folder, policy_fname)
140
- props = policy_fname.split("_")
141
 
142
  try:
143
- agent_key, env_key = props[0], props[1]
144
- agent_args = {}
145
- for prop in props[2:]:
146
- props_split = prop.split(":")
147
- if len(props_split) == 2:
148
- agent_args[props_split[0]] = props_split[1]
149
- else:
150
- warnings.warn(
151
- f"Skipping property {prop} as it does not have the format 'key:value'.",
152
- UserWarning,
153
- )
154
- except IndexError:
155
  yield localstate, None, None, None, None, None, None, None, None, None, None, "🚫 Please select a valid policy file."
156
  return
157
 
158
- agent_args.update(
159
- {
160
- "env": env_key,
161
- "render_mode": "rgb_array",
162
- }
163
- )
164
- print("agent_args:", agent_args)
165
- agent = AGENTS_MAP[agent_key](**agent_args)
166
- agent.load_policy(policy_path)
167
  env_action_map = action_map.get(env_key)
168
 
169
  solved, frame_env, frame_policy = None, None, None
 
1
  import os
2
  import time
 
3
  import numpy as np
4
  import gradio as gr
5
 
6
  import scipy.ndimage
7
  import cv2
8
 
9
+ from agents import load_agent
10
 
11
  default_n_test_episodes = 10
12
  default_max_steps = 500
 
136
  print(f"- epsilon: {localstate.live_steps_forward}")
137
 
138
  policy_path = os.path.join(policies_folder, policy_fname)
 
139
 
140
  try:
141
+ agent = load_agent(
142
+ policy_path, return_agent_env_keys=True, render_mode="rgb_array"
143
+ )
144
+ except ValueError:
 
 
 
 
 
 
 
 
145
  yield localstate, None, None, None, None, None, None, None, None, None, None, "🚫 Please select a valid policy file."
146
  return
147
 
148
+ agent_key, env_key = agent.__class__.__name__, agent.env_name
 
 
 
 
 
 
 
 
149
  env_action_map = action_map.get(env_key)
150
 
151
  solved, frame_env, frame_policy = None, None, None
policies/DPAgent_CliffWalking-v0_gamma:0.99_epsilon:0.4_e2500_s200.npy CHANGED
Binary files a/policies/DPAgent_CliffWalking-v0_gamma:0.99_epsilon:0.4_e2500_s200.npy and b/policies/DPAgent_CliffWalking-v0_gamma:0.99_epsilon:0.4_e2500_s200.npy differ
 
policies/MCAgent_FrozenLake-v1_gamma:0.99_epsilon:0.2_size:8_seed:33951_e2500_s200_first_visit.npy DELETED
Binary file (2.18 kB)
 
policies/MCAgent_FrozenLake-v1_gamma:0.99_epsilon:0.4_size:8_seed:16970_e2500_s200_first_visit.npy DELETED
Binary file (2.18 kB)
 
run.py CHANGED
@@ -138,7 +138,9 @@ def main():
138
  args = parser.parse_args()
139
  print(vars(args))
140
 
141
- agent = load_agent(args.agent, **dict(args._get_kwargs()))
 
 
142
 
143
  agent.run_name += f"_e{args.n_train_episodes}_s{args.max_steps}"
144
  if args.wandb_run_name_suffix is not None:
@@ -169,7 +171,6 @@ def main():
169
  if not args.no_save:
170
  agent.save_policy(save_dir=args.save_dir)
171
  elif args.test is not None:
172
- agent.load_policy(args.test)
173
  agent.test(
174
  n_test_episodes=args.n_test_episodes,
175
  max_steps=args.max_steps,
 
138
  args = parser.parse_args()
139
  print(vars(args))
140
 
141
+ agent = load_agent(
142
+ args.agent if args.test is None else args.test, **dict(args._get_kwargs())
143
+ )
144
 
145
  agent.run_name += f"_e{args.n_train_episodes}_s{args.max_steps}"
146
  if args.wandb_run_name_suffix is not None:
 
171
  if not args.no_save:
172
  agent.save_policy(save_dir=args.save_dir)
173
  elif args.test is not None:
 
174
  agent.test(
175
  n_test_episodes=args.n_test_episodes,
176
  max_steps=args.max_steps,