# NOTE(review): removed non-Python web-scrape artifacts that preceded the code
# (page status lines, "File size: 3,530 Bytes", a commit hash, and a
# line-number gutter) — they made the file unparseable.
# nStepTDAgent.py
import sys
from tabulate import tabulate
import numpy as np
# One recorded trajectory: a [state, action, reward] triple per time step.
# The final entry [state, None, None] marks the terminal state.
episode = [
    ["s1", "E", 0],
    ["s2", "E", 1],
    ["s3", "N", 2],
    ["s3", "N", 3],
    ["s3", "S", 4],
    ["s6", "S", 5],
    ["s9", None, None],
]

# Map state names s1..s9 to Q-table rows 0..8, and action names to
# Q-table columns N=0, E=1, S=2, W=3.
index_map = {f"s{i}": i - 1 for i in range(1, 10)}
index_map.update({"N": 0, "E": 1, "S": 2, "W": 3})
class nStepTDAgent():
    """n-step TD agent: learns action values Q(s, a) from a recorded
    episode using the n-step TD (Sarsa-style) update rule."""

    def __init__(self, alpha, gamma, num_state, num_actions):
        """
        Constructor

        Args:
            alpha: The step size (learning rate) of the TD update
            gamma: The discount factor
            num_state: The number of states
            num_actions: The number of actions
        """
        self.alpha = alpha
        self.gamma = gamma
        self.num_state = num_state
        self.num_actions = num_actions
        # Action-value table Q[s, a], initialised to all zeros.
        self.Q = np.zeros((self.num_state, self.num_actions))

    def run_episode(self, n, episode):
        """
        Update the action value function using the n-step TD update.

        Args:
            n: The number of look-ahead steps of the n-step return.
            episode: Sequence of [state, action, reward] triples in time
                order; the terminal entry is [state, None, None].
        """
        # rew[i] holds the reward at time i; the dummy 0 at index 0 makes
        # the list 1-based so rew[i] lines up with R_i in the formulas.
        rew = [0, ]
        bigT = sys.maxsize  # episode length T, unknown until terminal entry seen
        print("T: ", bigT)
        # BUG FIX: the original iterated `enumerate(episode.reverse())`;
        # list.reverse() reverses in place and returns None, so enumerate(None)
        # raised TypeError. The n-step algorithm walks the episode FORWARD,
        # and must keep stepping past the last episode index until
        # Tt == T - 1 so the final n-1 state-action pairs are also updated.
        t = 0
        while True:
            print("=" * 80)
            print("Step: ", t)
            if t < bigT:
                s_t, a_t, r_t1 = episode[t]
                print(f"    s_t: {s_t}, a_t: {a_t}, r_t1: {r_t1}")
                rew.append(r_t1)
                # A None reward marks the entry after the terminal transition.
                _, _, r_t2 = episode[t + 1]
                if r_t2 is None:
                    bigT = t + 1
                    print("TERMINAL => T: ", bigT)
            Tt = t - n + 1  # the time whose state-action estimate is updated now
            print(f"    Tt: {Tt}")
            if Tt >= 0:
                print(f'    ==============')
                # n-step return:
                #   G = sum_{i=Tt+1}^{min(Tt+n, T)} gamma^(i - Tt - 1) * R_i
                bigG = 0
                for i in range(Tt + 1, min(Tt + n, bigT) + 1):
                    print(f"        i: {i}")
                    r_i = rew[i]
                    print(f"        r_t{i}: {r_i}")
                    print(f"        {bigG} += {self.gamma}^{i - Tt - 1} * {r_i}")
                    bigG += self.gamma**(i - Tt - 1) * r_i
                    print(f"        G: {bigG}")
                    print(f'        --------------')
                # Bootstrap with Q(s_{Tt+n}, a_{Tt+n}) when the n-step window
                # ends before the terminal state.
                if Tt + n < bigT:
                    s_Tn, a_Tn = episode[Tt + n][0], episode[Tt + n][1]
                    print(f"    s_Tn: {s_Tn}, a_Tn: {a_Tn}")
                    s_Tn, a_Tn = index_map[s_Tn], index_map[a_Tn]
                    print(f"    {bigG} += {self.gamma}^{n} * {self.Q[s_Tn, a_Tn]}")
                    bigG += (self.gamma**n) * self.Q[s_Tn, a_Tn]
                    print(f"    G: {bigG}")
                    print(f'    ==============')
                # TD update: move Q(s_Tt, a_Tt) towards the n-step return G.
                s_Tt, a_Tt = episode[Tt][0], episode[Tt][1]
                print(f"    => Update Q[{s_Tt}, {a_Tt}]")
                s_Tti, a_Tti = index_map[s_Tt], index_map[a_Tt]
                print(f"    Q[{s_Tt}, {a_Tt}] = {self.Q[s_Tti, a_Tti]}")
                self.Q[s_Tti, a_Tti] += self.alpha * (bigG - self.Q[s_Tti, a_Tti])
                print(f"    Q[{s_Tt}, {a_Tt}] = {self.Q[s_Tti, a_Tti]}")
                print(f"Q:")
                print(tabulate(self.Q, tablefmt="fancy_grid"))
            if Tt == bigT - 1:
                break
            t += 1
def main_r():
    """Script entry point: build a 9-state / 4-action agent and run the
    recorded episode through the 3-step TD update."""
    print("# nStepTDAgent.py")
    step_size, discount = 0.1, 0.9
    agent = nStepTDAgent(step_size, discount, 9, 4)
    print(agent.Q)
    agent.run_episode(3, episode)


if __name__ == "__main__":
    main_r()