Andrei Cozma committed
Commit · cf8b7c4
1 Parent(s): 6ee82fe
Updates
DPAgent.py
CHANGED
@@ -8,14 +8,16 @@ import warnings
 
 
 class DPAgent(Shared):
-    def __init__(self
-        super().__init__(**kwargs)
-        self.theta = kwargs.get(
+    def __init__(self, /, **kwargs):
+        super().__init__(run_name=self.__class__.__name__, **kwargs)
+        self.theta = kwargs.get("theta", 1e-10)
         print(self.theta)
         self.V = np.zeros(self.env.observation_space.n)
         self.Pi = np.zeros(self.env.observation_space.n, self.env.action_space.n)
         if self.gamma >= 1.0:
-            warnings.warn(
+            warnings.warn(
+                "DP will never converge with a gamma value =1.0. Try 0.99?", UserWarning
+            )
 
     def policy(self, state):
         return self.Pi[state]
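A minimal usage sketch for the updated constructor, assuming the Shared base class resolves the env string into self.env and stores gamma, as the rest of the diff implies; the explicit theta value below is illustrative and not part of the commit.

from DPAgent import DPAgent

# gamma stays below 1.0 so the discounted backups contract and the new
# UserWarning is not raised; theta overrides the 1e-10 default that
# __init__ now reads via kwargs.get("theta", 1e-10).
agent = DPAgent(env="FrozenLake-v1", gamma=0.99, theta=1e-8)
print(agent.theta)  # 1e-08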
@@ -31,9 +33,13 @@ class DPAgent(Shared):
         Q = np.zeros(self.env.action_space.n)
         for action in range(self.env.action_space.n):
             expected_value = 0
-            for probability, next_state, reward, done in self.env.P[state][
+            for probability, next_state, reward, done in self.env.P[state][
+                action
+            ]:
                 # if state == self.env.observation_space.n-1: reward = 1
-                expected_value += probability * (
+                expected_value += probability * (
+                    reward + self.gamma * self.V[next_state]
+                )
             Q[action] = expected_value
         action, value = np.argmax(Q), np.max(Q)
 
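For reference, a standalone sketch of the one-step lookahead this hunk now spells out: for each action, Q[a] accumulates probability * (reward + gamma * V[next_state]) over the entries of the tabular transition model. It assumes Gymnasium's FrozenLake-v1 and reads the table through env.unwrapped.P (the diff uses self.env.P directly); the all-zero value table is only for illustration.

import gymnasium as gym
import numpy as np

env = gym.make("FrozenLake-v1")
gamma = 0.99
V = np.zeros(env.observation_space.n)  # placeholder value table

state = 0
Q = np.zeros(env.action_space.n)
for action in range(env.action_space.n):
    expected_value = 0.0
    # env.unwrapped.P[state][action] is a list of (probability, next_state, reward, done) tuples
    for probability, next_state, reward, done in env.unwrapped.P[state][action]:
        expected_value += probability * (reward + gamma * V[next_state])
    Q[action] = expected_value

best_action, best_value = np.argmax(Q), np.max(Q)
print(best_action, best_value)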
@@ -54,12 +60,14 @@ class DPAgent(Shared):
                 expected_value = 0
                 for probability, next_state, reward, done in self.env.P[s][a]:
                     # if state == self.env.observation_space.n-1: reward = 1
-                    expected_value += probability * (
-
+                    expected_value += probability * (
+                        reward + self.gamma * self.V[next_state]
+                    )
+                self.Pi[s, a] = expected_value
         idxs = np.argmax(self.Pi, axis=1)
         print(idxs)
-        self.Pi = np.zeros((self.env.observation_space.n,self.env.action_space.n))
-        self.Pi[np.arange(self.env.observation_space.n),idxs] = 1
+        self.Pi = np.zeros((self.env.observation_space.n, self.env.action_space.n))
+        self.Pi[np.arange(self.env.observation_space.n), idxs] = 1
         # print(self.Pi)
         # return self.V, self.Pi
 
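A small numpy-only sketch of the greedy policy extraction at the end of this hunk: take the argmax action per state from the filled-in table, then rebuild Pi as a one-hot matrix so each row places probability 1 on the greedy action. The 4x4 FrozenLake sizes and the random stand-in table are assumptions used only for illustration.

import numpy as np

n_states, n_actions = 16, 4               # FrozenLake-v1 4x4 sizes, illustrative only
Pi = np.random.rand(n_states, n_actions)  # stand-in for the expected-value table filled above

idxs = np.argmax(Pi, axis=1)              # greedy action index per state
Pi = np.zeros((n_states, n_actions))      # reset, then mark the greedy action
Pi[np.arange(n_states), idxs] = 1

assert np.allclose(Pi.sum(axis=1), 1.0)   # every row is a valid one-hot distribution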
@@ -68,17 +76,22 @@ if __name__ == "__main__":
     # env = gym.make('FrozenLake-v1', render_mode='human')
     dp = DPAgent(env="FrozenLake-v1", gamma=0.99)
     dp.train()
-    dp.save_policy(
-    env = gym.make(
-        "
-        "
-
-
-
-
-
-
-
+    dp.save_policy("dp_policy.npy")
+    env = gym.make(
+        "FrozenLake-v1",
+        render_mode="human",
+        is_slippery=False,
+        desc=[
+            "SFFFFFFF",
+            "FFFFFFFH",
+            "FFFHFFFF",
+            "FFFFFHFF",
+            "FFFHFFFF",
+            "FHHFFFHF",
+            "FHFFHFHF",
+            "FFFHFFFG",
+        ],
+    )
 
     state, _ = env.reset()
     done = False
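The hunk stops right after env.reset() and done = False, so the rollout itself is not part of the diff; the lines below are only a plausible continuation of the __main__ block (dp, env, and np already in scope) showing how the one-hot policy could drive the rendered environment under Gymnasium's five-value step API.

state, _ = env.reset()
done = False
while not done:
    action = int(np.argmax(dp.policy(state)))            # greedy action from the one-hot policy row
    state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated                        # Gymnasium returns separate termination flags
env.close()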
policies/DPAgent_CliffWalking-v0_e2500_s200_g0.99_e0.4_first_visit.npy
DELETED (binary file, 1.66 kB)

policies/DPAgent_FrozenLake-v1_e2500_s200_g0.99_e0.4_first_visit.npy
DELETED (binary file, 2.18 kB)