from pathlib import Path

import gym
import matplotlib
import matplotlib.pyplot as plt
import torch
from tqdm import trange

from cartpole import MyAgent, MetricLogger

is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()
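
# NOTE: this script targets the classic Gym step/reset API (gym < 0.26), where
# env.reset() returns just the observation and env.step() returns
# (obs, reward, done, info). Newer Gym/Gymnasium releases return (obs, info)
# and a 5-tuple with separate terminated/truncated flags, so they would need
# small adjustments below.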

env = gym.make('CartPole-v1')

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}")
print()

save_dir = Path("checkpoints/cartpole/latest")
save_dir.mkdir(parents=True, exist_ok=True)

checkpoint = None
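# To resume an earlier run instead, point `checkpoint` at a saved file, e.g.:
#   checkpoint = save_dir / "cartpole_net.chkpt"  # hypothetical filename; the
# actual name depends on how MyAgent writes its checkpoints.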

n_actions = env.action_space.n   # CartPole-v1: 2 discrete actions (push left/right)
state = env.reset()
n_observations = len(state)      # CartPole-v1: 4-dimensional observation

max_memory_size = 100000
agent = MyAgent(
    state_dim=n_observations,
    action_dim=n_actions,
    save_dir=save_dir,
    checkpoint=checkpoint,
    reset_exploration_rate=True,
    max_memory_size=max_memory_size,
)
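
# MyAgent (defined in cartpole.py, not shown here) presumably wraps a DQN-style
# learner: act() picks actions (with exploration), cache() stores transitions in
# a replay buffer capped at max_memory_size, and learn() samples a batch and
# returns a mean Q-value and loss for logging.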

logger = MetricLogger(save_dir)
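# MetricLogger presumably aggregates the per-step reward/loss/Q values into
# per-episode statistics and writes logs/plots under save_dir.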


def fill_memory(agent: MyAgent):
    # Warm up the replay buffer: keep generating episodes with the agent's
    # current (exploratory) policy until max_memory_size transitions are cached.
    print("Filling up memory....")
    cached = 0
    with trange(max_memory_size) as pbar:
        while cached < max_memory_size:
            state = env.reset()
            done = False
            while not done and cached < max_memory_size:
                action = agent.act(state)
                next_state, reward, done, info = env.step(action)
                agent.cache(state, next_state, action, reward, done)
                state = next_state
                cached += 1
                pbar.update(1)
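

# Training loop: for each episode, act, cache the transition, take one learning
# step per environment step, and log reward/loss/Q. MyAgent presumably acts
# epsilon-greedily (it exposes an exploration_rate for logging).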
def train(agent: MyAgent):
    episodes = 10_000_000  # effectively "run until stopped"
    for e in range(episodes):
        state = env.reset()

        while True:
            action = agent.act(state)
            next_state, reward, done, info = env.step(action)
            agent.cache(state, next_state, action, reward, done)

            q, loss = agent.learn()
            logger.log_step(reward, loss, q)

            state = next_state
            if done:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)


fill_memory(agent)
train(agent)
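

# Optional: a minimal sketch for watching the trained agent play one episode,
# assuming the classic Gym render API. Note agent.act() may still explore,
# depending on how MyAgent handles its exploration rate.
def watch(agent: MyAgent):
    state = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        env.render()
        action = agent.act(state)
        state, reward, done, info = env.step(action)
        total_reward += reward
    print(f"Episode reward: {total_reward}")

# watch(agent)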