File size: 1,416 Bytes
e085e3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import datetime
from pathlib import Path
from itertools import count
from agent import DQNAgent,  MetricLogger
from wrappers import make_env, make_starpilot


# Replay script: run a DQN agent built from `checkpoint` on the procgen
# StarPilot environment for `episodes` rendered episodes. No learning step is
# invoked in the loop below; transitions are only passed to `agent.cache`.
env = make_starpilot()

# Everything after the environment exists runs under try/finally so the env
# (and any render window it opened) is released even if the agent fails to
# construct or an episode raises.
try:
    env.reset()

    # Per-run output directory named after the launch timestamp; passed to the
    # agent as `save_dir`.
    save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    save_dir.mkdir(parents=True)

    # NOTE(review): presumably DQNAgent restores its network weights from this
    # path when `checkpoint` is given — confirm in agent.py.
    checkpoint = Path('checkpoints/procgen-starpilot-dqn/airstriker_net_3.chkpt')

    agent = DQNAgent(
        # Assumed to match the observation shape produced by make_starpilot()
        # (a single 64x64 channel) — TODO confirm against wrappers.py.
        state_dim=(1, 64, 64),
        action_dim=env.action_space.n,
        save_dir=save_dir,
        batch_size=256,
        checkpoint=checkpoint,
        reset_exploration_rate=True,
        exploration_rate_decay=0.999999,
        training_frequency=10,
        target_network_sync_frequency=200,
        max_memory_size=3000,
        learning_rate=0.001,
        save_frequency=2000
    )
    # Pin exploration to its floor for replay. NOTE(review): presumably
    # epsilon-greedy, so this minimizes random actions — confirm in
    # DQNAgent.act.
    agent.exploration_rate = agent.exploration_rate_min

    # logger = MetricLogger(save_dir)

    episodes = 100

    for e in range(episodes):

        state = env.reset()

        while True:

            env.render()

            action = agent.act(state)

            next_state, reward, done, _info = env.step(action)

            # Hand the transition to the agent (presumably its replay memory).
            agent.cache(state, next_state, action, reward, done)

            # logger.log_step(reward, None, None)

            state = next_state

            if done:
                break

        # logger.log_episode()

        # if e % 20 == 0:
        #     logger.record(
        #         episode=e,
        #         epsilon=agent.exploration_rate,
        #         step=agent.curr_step
        #     )
finally:
    env.close()