|
import datetime |
|
from pathlib import Path |
|
from itertools import count |
|
from agent import DQNAgent, MetricLogger |
|
from wrappers import make_env, make_starpilot |
|
|
|
|
|
# --- Replay setup ------------------------------------------------------------
# Build the environment through the project's StarPilot factory and reset it
# once up front; the episode loop further down resets it again per episode.
env = make_starpilot()
env.reset()

# Every run gets its own timestamped directory under ./checkpoints.
# mkdir() is called without exist_ok, so an already-existing directory raises
# FileExistsError.
run_stamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir = Path("checkpoints", run_stamp)
save_dir.mkdir(parents=True)

# Checkpoint file handed to the agent below.
# NOTE(review): presumably DQNAgent restores its network weights from this
# file -- confirm in agent.py.
checkpoint = Path("checkpoints/procgen-starpilot-dqn/airstriker_net_3.chkpt")

# Keyword arguments for DQNAgent. Their exact semantics are defined in
# agent.py (not visible from this script); the values are passed through as-is.
agent_config = {
    "state_dim": (1, 64, 64),
    "action_dim": env.action_space.n,
    "save_dir": save_dir,
    "batch_size": 256,
    "checkpoint": checkpoint,
    "reset_exploration_rate": True,
    "exploration_rate_decay": 0.999999,
    "training_frequency": 10,
    "target_network_sync_frequency": 200,
    "max_memory_size": 3000,
    "learning_rate": 0.001,
    "save_frequency": 2000,
}
agent = DQNAgent(**agent_config)

# Override the exploration rate with the agent's configured minimum.
# NOTE(review): presumably this makes agent.act() choose near-greedy actions
# during replay -- confirm against DQNAgent.act.
agent.exploration_rate = agent.exploration_rate_min
|
|
|
|
|
|
|
# Number of episodes to play back.
episodes = 100

# Play the episodes, rendering every frame. The try/finally guarantees the
# environment is closed even if playback is interrupted (e.g. Ctrl+C) or a
# step raises, instead of leaking the environment and its render window.
# NOTE(review): assumes `env` follows the Gym Env API, which provides close().
try:
    for _episode in range(episodes):
        # Start a fresh episode; reset() returns the initial observation.
        state = env.reset()

        while True:
            env.render()

            # Ask the agent for an action given the current observation.
            action = agent.act(state)

            # Advance the environment one step; the info value is not used.
            next_state, reward, done, _info = env.step(action)

            # Hand the transition to the agent. No learning call is made
            # anywhere in this loop.
            agent.cache(state, next_state, action, reward, done)

            state = next_state

            if done:
                break
finally:
    env.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|