import os
import torch
import matplotlib
import matplotlib.pyplot as plt

from pathlib import Path
from tqdm import trange
from agent import DQNAgent, DDQNAgent, MetricLogger
from wrappers import make_env


# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()


env = make_env()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")


checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/airstriker-dqn-new"
save_dir = Path(path)

os.makedirs(path, exist_ok=True)

# Vanilla DQN
print("Training Vanilla DQN Agent!")
agent = DQNAgent(
    state_dim=(1, 84, 84),               # single 84x84 grayscale frame
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=128,
    checkpoint=checkpoint,
    exploration_rate_decay=0.995,        # multiplicative epsilon decay
    exploration_rate_min=0.05,           # epsilon floor
    training_frequency=1,                # how often a learning step runs
    target_network_sync_frequency=500,   # steps between target-net syncs
    max_memory_size=50000,               # replay buffer capacity
    learning_rate=0.0005,
)
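
# Rough schedule check (assuming exploration_rate is multiplied by
# exploration_rate_decay once per act() call): starting from 1.0, epsilon
# hits the 0.05 floor after about log(0.05) / log(0.995) ~= 598 steps.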

# Double DQN
# print("Training DDQN Agent!")
# agent = DDQNAgent(
#     state_dim=(1, 84, 84),
#     action_dim=env.action_space.n,
#     save_dir=save_dir,
#     checkpoint=checkpoint,
#     reset_exploration_rate=True,
#     max_memory_size=50000,
# )
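
# For reference, the two variants differ only in how the bootstrap action is
# chosen when building the TD target:
#   Vanilla DQN:  y = r + gamma * max_a Q_target(s', a)
#   Double DQN:   a* = argmax_a Q_online(s', a);  y = r + gamma * Q_target(s', a*)
# Decoupling action selection (online net) from evaluation (target net)
# reduces the max operator's overestimation bias.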

logger = MetricLogger(save_dir)

def fill_memory(agent: DQNAgent, num_episodes=1000):
    """Warm up the replay buffer by playing full episodes before training."""
    print("Filling up memory....")
    for _ in trange(num_episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.cache(state, next_state, action, reward, done)
            state = next_state
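
# Note: warming up with agent.act() samples from the current (high-epsilon)
# policy and, depending on the DQNAgent implementation, may already decay
# epsilon during the fill. A common alternative is to fill the buffer with
# uniform random actions instead, e.g.:
#
#     action = env.action_space.sample()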


def train(agent: DQNAgent):
    episodes = 10_000_000  # effectively train until interrupted
    for e in range(episodes):

        state = env.reset()
        # Play the game!
        while True:

            # Run agent on the state
            action = agent.act(state)

            # Agent performs action
            next_state, reward, done, info = env.step(action)

            # Remember
            agent.cache(state, next_state, action, reward, done)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state

            # Check if end of game
            if done or info["gameover"] == 1:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)

fill_memory(agent)
train(agent)
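
# After training, a minimal greedy-evaluation sketch (an assumption: pinning
# agent.exploration_rate at its 0.05 floor makes agent.act() near-greedy;
# the episode loop mirrors train() above):
#
# agent.exploration_rate = 0.05
# state = env.reset()
# done = False
# total_reward = 0.0
# while not done:
#     action = agent.act(state)
#     state, reward, done, info = env.step(action)
#     total_reward += reward
#     if info["gameover"] == 1:
#         break
# print(f"Evaluation return: {total_reward}")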