File size: 3,530 Bytes
f902143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# nStepTDAgent.py

import sys

from tabulate import tabulate
import numpy as np

# Recorded sample trajectory, one [state, action, reward] row per time step.
# The last row carries None for both action and reward, which is how the
# agent recognises the terminal state.
episode = [
    [state, action, reward]
    for state, action, reward in (
        ("s1", "E", 0),
        ("s2", "E", 1),
        ("s3", "N", 2),
        ("s3", "N", 3),
        ("s3", "S", 4),
        ("s6", "S", 5),
        ("s9", None, None),
    )
]

# Translates state labels ("s1".."s9") and action labels ("N", "E", "S", "W")
# into the row / column indices used to address the agent's Q table.
index_map = {
    **{f"s{number}": number - 1 for number in range(1, 10)},
    **{label: column for column, label in enumerate(("N", "E", "S", "W"))},
}

class nStepTDAgent():
    """Tabular agent that learns action values with n-step TD updates.

    The action-value table ``Q`` has one row per state and one column per
    action and is initialised to zero.
    """

    def __init__(self, alpha, gamma, num_state, num_actions):
        """
        Constructor
        Args:
            alpha: The step size applied to each TD error
            gamma: The discount factor
            num_state: The number of states
            num_actions: The number of actions
        """
        self.alpha = alpha
        self.gamma = gamma
        self.num_state = num_state
        self.num_actions = num_actions

        self.Q = np.zeros((self.num_state, self.num_actions))

    def run_episode(self, n, episode):
        """
        Update the action value function using the n-step TD update.

        The episode is replayed forward in time. Once ``n`` rewards have been
        observed, each step updates ``Q`` for the state-action pair visited
        ``n - 1`` steps earlier, using the discounted sum of up to ``n``
        rewards plus, when the window ends before the terminal step, the
        discounted current estimate ``Q`` at the end of the window. A trace of
        every intermediate quantity is printed.

        Args:
            n: Number of steps in the return; must be at least 1.
            episode: Sequence of ``[state, action, reward]`` rows, where
                ``reward`` is the reward received after taking ``action`` in
                ``state``. The terminal row has ``None`` as its reward.
                State and action labels are translated through the
                module-level ``index_map``. The sequence is not modified.

        Raises:
            ValueError: If ``n`` is smaller than 1, or if the episode runs
                out of rows before a terminal row is found.
        """
        if n < 1:
            raise ValueError("n must be a positive integer")
        if not episode or episode[0][2] is None:
            # No transition precedes the terminal state: nothing to learn.
            return

        # rew[i] holds the reward received at time i; index 0 is a
        # placeholder so that indices line up with the time step.
        rew = [0]

        # T stays "infinite" until the terminal row is discovered.
        bigT = sys.maxsize
        print("T: ", bigT)

        # The time index has to keep advancing after the terminal step so
        # that the last n - 1 state-action pairs are updated too; the
        # episode rows alone are therefore not a suitable loop range.
        t = 0
        while True:
            print("=" * 80)
            print("Step: ", t)
            if t < bigT:
                if t + 1 >= len(episode):
                    raise ValueError(
                        "episode must end with a terminal entry whose "
                        "reward is None"
                    )
                s_t, a_t, r_t1 = episode[t]
                print(f" s_t: {s_t}, a_t: {a_t}, r_t1: {r_t1}")
                rew.append(r_t1)

                # A None reward on the next row marks it as terminal.
                if episode[t + 1][2] is None:
                    bigT = t + 1
                    print("TERMINAL => T: ", bigT)

            # Time step whose estimate is updated on this iteration.
            Tt = t - n + 1
            print(f" Tt: {Tt}")
            if Tt >= 0:
                print(" ==============")
                # Discounted sum of the rewards inside the n-step window,
                # truncated at the end of the episode.
                bigG = 0
                for i in range(Tt + 1, min(Tt + n, bigT) + 1):
                    print(f" i: {i}")
                    r_i = rew[i]
                    print(f" r_t{i}: {r_i}")
                    print(f"      {bigG} += {self.gamma}^{i - Tt - 1} * {r_i}")
                    bigG += self.gamma**(i - Tt - 1) * r_i
                print(f" G: {bigG}")
                print(" --------------")
                if Tt + n < bigT:
                    # The window ends before the terminal step: bootstrap
                    # from the current estimate at the end of the window.
                    s_Tn, a_Tn = episode[Tt + n][0], episode[Tt + n][1]

                    print(f"   s_Tn: {s_Tn}, a_Tn: {a_Tn}")
                    s_Tn, a_Tn = index_map[s_Tn], index_map[a_Tn]
                    print(f"      {bigG} += {self.gamma}^{n} * {self.Q[s_Tn, a_Tn]}")
                    bigG += (self.gamma**n) * self.Q[s_Tn, a_Tn]
                print(f" G: {bigG}")
                print(" ==============")

                s_Tt, a_Tt = episode[Tt][0], episode[Tt][1]
                print(f" => Update Q[{s_Tt}, {a_Tt}]")
                s_Tti, a_Tti = index_map[s_Tt], index_map[a_Tt]
                print(f" Q[{s_Tt}, {a_Tt}] = {self.Q[s_Tti, a_Tti]}")
                self.Q[s_Tti, a_Tti] += self.alpha * (bigG - self.Q[s_Tti, a_Tti])
                print(f" Q[{s_Tt}, {a_Tt}] = {self.Q[s_Tti, a_Tti]}")
            print("Q:")
            print(tabulate(self.Q, tablefmt="fancy_grid"))
            if Tt == bigT - 1:
                # The last non-terminal step has just been updated.
                break
            t += 1
                    
                




def main_r():
    """Replay the sample episode through a fresh agent using 3-step returns.

    Prints a banner and the initial (all-zero) Q table, then lets the agent
    print its own trace while it processes the module-level ``episode``.
    """
    print("# nStepTDAgent.py")
    learner = nStepTDAgent(alpha=0.1, gamma=0.9, num_state=9, num_actions=4)
    print(learner.Q)
    learner.run_episode(n=3, episode=episode)



# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main_r()