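"""Train a DQN agent on CartPole-v1.

The script pre-fills the agent's replay memory with warm-up transitions,
then trains indefinitely, logging per-step and per-episode metrics with
MetricLogger. Commented-out lines show the equivalent setup for the
Airstriker-Genesis retro environment.
"""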
from pathlib import Path

import gym
import matplotlib
import matplotlib.pyplot as plt
import torch
from tqdm import trange

# import retro  # only needed for the Airstriker-Genesis setup
# from agent import MyAgent, MyDQN, MetricLogger
from cartpole import MyAgent, MetricLogger
from wrappers import make_env
# Set up matplotlib (interactive plotting when running inside IPython/Jupyter)
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()
# env = make_env()  # Airstriker environment
env = gym.make('CartPole-v1')  # assumes the pre-0.26 Gym API: reset() -> obs, step() -> (obs, reward, done, info)

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}")
print()
path = "checkpoints/cartpole/latest"
save_dir = Path(path)
save_dir.mkdir(parents=True, exist_ok=True)

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
# For cartpole
n_actions = env.action_space.n
state = env.reset()
n_observations = len(state)  # CartPole observations are flat 4-dimensional vectors

max_memory_size = 100_000
agent = MyAgent(
    state_dim=n_observations,
    action_dim=n_actions,
    save_dir=save_dir,
    checkpoint=checkpoint,
    reset_exploration_rate=True,
    max_memory_size=max_memory_size,
)
# For airstriker
# agent = MyAgent(state_dim=(1, 84, 84), action_dim=env.action_space.n, save_dir=save_dir, checkpoint=checkpoint, reset_exploration_rate=True)
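# Assumed MyAgent interface, inferred from how it is used in this script
# (see cartpole.py for the actual definitions):
#   act(state)                                     -> epsilon-greedy action index
#   cache(state, next_state, action, reward, done) -> store one transition
#   learn()                                        -> (mean Q estimate, loss) for one update
#   exploration_rate, curr_step                    -> attributes read for logging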
logger = MetricLogger(save_dir)
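# Pre-filling the replay buffer is a standard DQN warm-up: it guarantees that
# agent.learn() can sample full random batches from the very first update.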
def fill_memory(agent: MyAgent):
    """Pre-fill the replay memory with max_memory_size transitions so that
    learning starts from a full buffer."""
    print("Filling up memory....")
    state = env.reset()
    for _ in trange(max_memory_size):
        action = agent.act(state)
        next_state, reward, done, info = env.step(action)
        agent.cache(state, next_state, action, reward, done)
        # Start a fresh episode whenever the current one ends
        state = env.reset() if done else next_state
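# Note: with the pre-0.26 Gym API, `done` is also set when the TimeLimit
# wrapper truncates an episode (500 steps for CartPole-v1), so time-outs are
# cached just like true terminal states here.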
def train(agent: MyAgent):
    episodes = 10_000_000
    for e in range(episodes):
        state = env.reset()

        # Play the game!
        while True:
            # Run agent on the state
            action = agent.act(state)

            # Agent performs action
            next_state, reward, done, info = env.step(action)

            # Remember
            agent.cache(state, next_state, action, reward, done)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state

            # # Check if end of game (for airstriker)
            # if done or info["gameover"] == 1:
            #     break

            # Check if end of game (for cartpole)
            if done:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
if __name__ == "__main__":
    fill_memory(agent)
    train(agent)