import os
import torch
import matplotlib
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import trange
from agent import DQNAgent, DDQNAgent, MetricLogger
from wrappers import make_env
# Set up matplotlib (enable interactive mode so plots update live in notebooks)
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()
env = make_env()
use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")
# To resume training, set `checkpoint` to a previously saved model file:
checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
save_dir = Path("checkpoints/airstriker-ddqn")
os.makedirs(save_dir, exist_ok=True)
# Vanilla DQN (disabled; uncomment this block and comment out the DDQN block
# below to train the baseline agent instead)
# print("Training Vanilla DQN Agent!")
# agent = DQNAgent(
#     state_dim=(1, 84, 84),
#     action_dim=env.action_space.n,
#     save_dir=save_dir,
#     batch_size=128,
#     checkpoint=checkpoint,
#     exploration_rate_decay=0.995,
#     exploration_rate_min=0.05,
#     training_frequency=1,
#     target_network_sync_frequency=500,
#     max_memory_size=50000,
#     learning_rate=0.0005,
# )
# Double DQN
print("Training DDQN Agent!")
agent = DDQNAgent(
    state_dim=(1, 84, 84),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=128,
    checkpoint=checkpoint,
    exploration_rate_decay=0.995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=500,
    max_memory_size=50000,
    learning_rate=0.0005,
)
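
# Background: Double DQN differs from vanilla DQN only in how the bootstrap
# target is built. A minimal sketch, assuming the standard formulation
# (`online_net`, `target_net`, and `gamma` are illustrative stand-ins, not
# agent.py internals):
#
#     best_action = online_net(next_state).argmax(dim=1, keepdim=True)
#     td_target = reward + gamma * target_net(next_state).gather(1, best_action)
#
# Vanilla DQN would use target_net(next_state).max(dim=1) instead; selecting
# the action with the online network but evaluating it with the target network
# reduces the max-operator's overestimation bias.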
logger = MetricLogger(save_dir)
def fill_memory(agent: DQNAgent, num_episodes=1000):
    """Pre-fill the replay buffer so learn() has transitions to sample from."""
    print("Filling up memory....")
    for _ in trange(num_episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.cache(state, next_state, action, reward, done)
            state = next_state
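
# Note: warm-up runs through agent.act(), so if agent.py decays epsilon inside
# act(), part of the exploration schedule is consumed before training begins;
# drawing actions from env.action_space.sample() is a common alternative for a
# purely random warm-up.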
def train(agent: DQNAgent):
    episodes = 10_000_000  # effectively run until interrupted
    for e in range(episodes):
        state = env.reset()

        # Play the game!
        while True:
            # Run agent on the state
            action = agent.act(state)

            # Agent performs action
            next_state, reward, done, info = env.step(action)

            # Remember
            agent.cache(state, next_state, action, reward, done)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state

            # Check if end of game
            if done or info["gameover"] == 1:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
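
# Rough schedule check (assuming agent.py multiplies epsilon by
# exploration_rate_decay once per act() call and starts it at 1.0): epsilon
# reaches the 0.05 floor after about log(0.05)/log(0.995) ~ 600 decay steps,
# so the bulk of training runs at the minimum exploration rate.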
fill_memory(agent)
train(agent)