from pathlib import Path

import torch
import matplotlib
import matplotlib.pyplot as plt
import gym
from tqdm import trange

# from agent import MyAgent, MyDQN, MetricLogger   # agent/network for the retro airstriker env
from cartpole import MyAgent, MetricLogger
# import retro                                      # only needed for the retro/airstriker env
# from wrappers import make_env                     # only needed for the retro/airstriker env

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()


# env = make_env()
env = gym.make('CartPole-v1')
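# NOTE: this script assumes the classic Gym API (env.reset() returns only the
# observation and env.step() returns a 4-tuple), i.e. gym versions before 0.26.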

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}")
print()

save_dir = Path("checkpoints/cartpole/latest")
save_dir.mkdir(parents=True, exist_ok=True)


checkpoint = None 
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

# For cartpole
n_actions = env.action_space.n
state = env.reset()
n_observations = len(state)
max_memory_size = 100_000
agent = MyAgent(
    state_dim=n_observations, 
    action_dim=n_actions, 
    save_dir=save_dir, 
    checkpoint=checkpoint, 
    reset_exploration_rate=True,
    max_memory_size=max_memory_size
)
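# reset_exploration_rate=True is assumed to restart epsilon from its initial
# value even when resuming from a checkpoint (semantics inferred from the name).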

# For airstriker
# agent = MyAgent(state_dim=(1, 84, 84), action_dim=env.action_space.n, save_dir=save_dir, checkpoint=checkpoint, reset_exploration_rate=True)


logger = MetricLogger(save_dir)
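# MetricLogger is assumed to write per-step reward/loss/Q statistics and
# per-episode summaries under save_dir.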



def fill_memory(agent: MyAgent):
    """Pre-fill the replay buffer with max_memory_size transitions before training."""
    print("Filling up memory....")
    state = env.reset()
    done = False
    for _ in trange(max_memory_size):
        if done:
            # Start a new episode once the previous one has ended
            state = env.reset()
        action = agent.act(state)
        next_state, reward, done, info = env.step(action)
        agent.cache(state, next_state, action, reward, done)
        state = next_state

def train(agent: MyAgent):
    episodes = 10_000_000  # effectively "run until interrupted"
    for e in range(episodes):

        state = env.reset()
        # Play the game!
        while True:
            # Run agent on the state
            action = agent.act(state)
            
            # Agent performs action
            next_state, reward, done, info = env.step(action)
            
            # Remember
            agent.cache(state, next_state, action, reward, done)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state
            
            # # Check if end of game (for airstriker)
            # if done or info["gameover"] == 1:
            #     break
            # Check if end of game (for cartpole)
            if done:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)


fill_memory(agent)
train(agent)