Andrei Cozma committed
Commit 30bb976
Parent(s): 120dc90
Updates
Browse files
- MCAgent.py +10 -12
- demo.py +3 -1
- policies/{MCAgent_CliffWalking-v0_gamma:1.0_epsilon:0.4_e1500_s200_first_visit.npy → MCAgent_CliffWalking-v0_gamma:1.0_epsilon:0.4_type:onpolicy_e1500_s200.npy} +0 -0
- policies/{MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:35280_e1500_s200_first_visit.npy → MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:35280_type:onpolicy_e1500_s200.npy} +0 -0
- policies/{MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:61252_e1500_s200_first_visit.npy → MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:61252_type:onpolicy_e1500_s200.npy} +0 -0
- policies/{MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:96883_e1500_s200_first_visit.npy → MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:96883_type:onpolicy_e1500_s200.npy} +0 -0
- policies/{MCAgent_Taxi-v3_gamma:1.0_epsilon:0.75_e15000_s200_first_visit.npy → MCAgent_Taxi-v3_gamma:1.0_epsilon:0.75_type:onpolicy_e15000_s200.npy} +0 -0
- run.py +4 -4
MCAgent.py
CHANGED
@@ -5,12 +5,10 @@ from AgentBase import AgentBase
 
 
 class MCAgent(AgentBase):
-    def __init__(
-        self, /, update_type="on-policy", **kwargs  # "on-policy" or "off-policy
-    ):
+    def __init__(self, /, type="onpolicy", **kwargs):  # "on-policy" or "off-policy
         super().__init__(run_name=self.__class__.__name__, **kwargs)
-        self.
-        self.run_name
+        self.type = type
+        self.run_name += f"_type:{self.type}"
         self.initialize()
 
     def initialize(self):
@@ -23,13 +21,13 @@ class MCAgent(AgentBase):
         # self.Q = np.random.rand(self.n_states, self.n_actions)
         # self.Q = np.random.normal(0, 1, size=(self.n_states, self.n_actions))
 
-        if self.
+        if self.type.startswith("onpolicy"):
             # For On-Policy update type:
             # R keeps track of all the returns that have been observed for each state-action pair to update Q
             self.R = [[[] for _ in range(self.n_actions)] for _ in range(self.n_states)]
             # An arbitrary e-greedy policy:
             self.Pi = self.create_soft_policy()
-        elif self.
+        elif self.type.startswith("offpolicy"):
             # For Off-Policy update type:
             self.C = np.zeros((self.n_states, self.n_actions))
             # Target policy is greedy with respect to the current Q (ties broken consistently)
@@ -39,7 +37,7 @@ class MCAgent(AgentBase):
             self.Pi_behaviour = self.create_soft_policy(coverage_policy=self.Pi)
         else:
             raise ValueError(
-                f"
+                f"Parameter 'type' must be either 'onpolicy' or 'offpolicy', but got '{self.type}'"
             )
         print("=" * 80)
         print("Initial policy:")
@@ -67,7 +65,7 @@ class MCAgent(AgentBase):
         )
         return Pi
 
-    def
+    def update_onpolicy(self, episode_hist):
         G = 0.0
         # For each step of the episode, in reverse order
         for t in range(len(episode_hist) - 1, -1, -1):
@@ -106,7 +104,7 @@ class MCAgent(AgentBase):
         #     1 - self.epsilon + self.epsilon / self.n_actions
         # )
 
-    def
+    def update_offpolicy(self, episode_hist):
         G, W = 0.0, 1.0
         for t in range(len(episode_hist) - 1, -1, -1):
             state, action, reward = episode_hist[t]
@@ -154,7 +152,7 @@ class MCAgent(AgentBase):
             "avg_ep_len": avg_ep_len,
         }
 
-        update_func = getattr(self, f"update_{self.
+        update_func = getattr(self, f"update_{self.type}")
 
         tqrange = tqdm(range(n_train_episodes))
         tqrange.set_description("Training")
@@ -163,7 +161,7 @@ class MCAgent(AgentBase):
             self.wandb_log_img(episode=None)
 
         for e in tqrange:
-            policy = self.Pi_behaviour if self.
+            policy = self.Pi_behaviour if self.type == "off_policy" else self.Pi
             episode_hist, solved, _ = self.run_episode(policy=policy, **kwargs)
             rewards = [x[2] for x in episode_hist]
             total_reward, avg_reward = sum(rewards), np.mean(rewards)
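For reference, the two renamed methods correspond to the standard Monte-Carlo control updates: update_onpolicy averages observed returns under an ε-soft policy, and update_offpolicy uses weighted importance sampling with a greedy target policy Pi and a soft behaviour policy Pi_behaviour. Below is a minimal illustrative sketch of both updates, not the exact bodies from the commit. It assumes the attributes created in initialize() (Q, R, C, Pi, Pi_behaviour, gamma, epsilon, n_actions), that policies are stored as per-state action-probability arrays, and that episode_hist is a list of (state, action, reward) tuples. The on-policy sketch uses the simpler every-visit form; a first-visit variant (the old "_first_visit" naming) would additionally skip state-action pairs that also occur earlier in the episode. The sketch checks the "onpolicy"/"offpolicy" strings accepted by initialize(); note that the training loop in this commit compares self.type against "off_policy", a different spelling, when choosing the behaviour policy.

import numpy as np

# Sketches of the two update methods, written as they would appear inside MCAgent.
def update_onpolicy(self, episode_hist):
    # Every-visit on-policy MC control: average observed returns, then epsilon-greedy improvement.
    G = 0.0
    for t in range(len(episode_hist) - 1, -1, -1):
        state, action, reward = episode_hist[t]
        G = self.gamma * G + reward                      # discounted return from step t
        self.R[state][action].append(G)                  # record the observed return
        self.Q[state, action] = np.mean(self.R[state][action])
        # Re-derive an epsilon-greedy policy for this state from the updated Q
        best = int(np.argmax(self.Q[state]))
        self.Pi[state] = np.full(self.n_actions, self.epsilon / self.n_actions)
        self.Pi[state][best] += 1.0 - self.epsilon

def update_offpolicy(self, episode_hist):
    # Off-policy MC control with weighted importance sampling (greedy target, soft behaviour policy).
    G, W = 0.0, 1.0
    for t in range(len(episode_hist) - 1, -1, -1):
        state, action, reward = episode_hist[t]
        G = self.gamma * G + reward
        self.C[state, action] += W                       # cumulative importance weights
        self.Q[state, action] += (W / self.C[state, action]) * (G - self.Q[state, action])
        # Target policy stays greedy with respect to the updated Q
        best = int(np.argmax(self.Q[state]))
        self.Pi[state] = np.eye(self.n_actions)[best]
        if action != best:
            break                                        # target probability is 0, so the weight would vanish
        W /= self.Pi_behaviour[state][action]            # target probability is 1 for the greedy action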
demo.py
CHANGED
@@ -154,7 +154,8 @@ def run(
         agent = load_agent(
             policy_path, return_agent_env_keys=True, render_mode="rgb_array"
         )
-    except ValueError:
+    except ValueError as e:
+        print(f"🚫 Error: {e}")
         yield localstate, None, None, None, None, None, None, None, None, None, None, "🚫 Please select a valid policy file."
         return
 
@@ -185,6 +186,7 @@
 
     for step, (episode_hist, solved, frame_env) in enumerate(
         agent.generate_episode(
+            policy=agent.Pi,
             max_steps=max_steps,
             render=True,
         )
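The second demo.py hunk makes the rollout policy explicit. Given the MCAgent changes above, agent.Pi is the ε-soft policy for agents trained with type "onpolicy" and the greedy target policy for "offpolicy" agents, so the demo always plays back the learned target behaviour rather than an implicit default. A condensed sketch of the call as the hunks imply it, assuming generate_episode yields (episode_hist, solved, frame_env) each step when render=True (the surrounding Gradio state handling is omitted):

# Roll out the agent's learned policy for the demo; names follow the hunks above.
for step, (episode_hist, solved, frame_env) in enumerate(
    agent.generate_episode(policy=agent.Pi, max_steps=max_steps, render=True)
):
    # frame_env is the rendered RGB frame; episode_hist accumulates (state, action, reward) tuples
    ...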
policies/{MCAgent_CliffWalking-v0_gamma:1.0_epsilon:0.4_e1500_s200_first_visit.npy → MCAgent_CliffWalking-v0_gamma:1.0_epsilon:0.4_type:onpolicy_e1500_s200.npy}
RENAMED
File without changes

policies/{MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:35280_e1500_s200_first_visit.npy → MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:35280_type:onpolicy_e1500_s200.npy}
RENAMED
File without changes

policies/{MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:61252_e1500_s200_first_visit.npy → MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:61252_type:onpolicy_e1500_s200.npy}
RENAMED
File without changes

policies/{MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:96883_e1500_s200_first_visit.npy → MCAgent_FrozenLake-v1_gamma:1.0_epsilon:0.4_size:8_seed:96883_type:onpolicy_e1500_s200.npy}
RENAMED
File without changes

policies/{MCAgent_Taxi-v3_gamma:1.0_epsilon:0.75_e15000_s200_first_visit.npy → MCAgent_Taxi-v3_gamma:1.0_epsilon:0.75_type:onpolicy_e15000_s200.npy}
RENAMED
File without changes
run.py
CHANGED
@@ -66,11 +66,11 @@ def main():
     )
 
     parser.add_argument(
-        "--
+        "--type",
         type=str,
-        choices=["
-        default="
-        help="The type of update to use. Only supported by Monte-Carlo agent. (default:
+        choices=["onpolicy", "offpolicy"],
+        default="onpolicy",
+        help="The type of update to use. Only supported by Monte-Carlo agent. (default: onpolicy)",
     )
 
     ### Environment parameters
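The renamed flag lines up with the constructor change in MCAgent.py: the parsed value is expected to be forwarded to the agent, stored as self.type, and appended to run_name, which is where the new "_type:onpolicy" fragment in the renamed policy files comes from. A hypothetical sketch of that wiring, assuming run.py forwards the parsed value as a keyword argument (the actual forwarding is outside this hunk, and the remaining environment/agent kwargs are omitted):

import argparse
from MCAgent import MCAgent

parser = argparse.ArgumentParser()
parser.add_argument(
    "--type",
    type=str,
    choices=["onpolicy", "offpolicy"],
    default="onpolicy",
    help="The type of update to use. Only supported by Monte-Carlo agent. (default: onpolicy)",
)
args = parser.parse_args()

# Hypothetical forwarding: MCAgent.__init__ stores the flag as self.type and extends
# run_name with f"_type:{self.type}", which later shows up in the saved policy filenames.
agent = MCAgent(type=args.type)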