File size: 1,078 Bytes
e085e3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import os
import torch
from pathlib import Path
from agent import DuelingDDQNAgent, DuelingDDQNAgentWithStepDecay,MetricLogger
from wrappers import make_lunar
import os
from train import train, fill_memory
from params import hyperparams
# Script entry: train a Dueling DDQN agent (step-decay variant) on the
# environment returned by make_lunar(), logging via MetricLogger.
env = make_lunar()
use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

# Set to a checkpoint Path to resume from saved weights instead of starting
# fresh, e.g.:
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
checkpoint = None

# Output directory handed to both the logger and the agent as `save_dir`
# (presumably where checkpoints/metrics are written — confirm in agent.py).
path = "checkpoints/lunar-lander-dueling-ddqn-rc"
save_dir = Path(path)
# A single mkdir(exist_ok=True) replaces the exists()/makedirs() pair, which
# was a check-then-create race that could raise FileExistsError.
save_dir.mkdir(parents=True, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training Dueling DDQN Agent with step decay!")
agent = DuelingDDQNAgentWithStepDecay(
    # NOTE(review): hard-coded observation size; presumably matches the
    # observation vector produced by make_lunar() — confirm, or derive it
    # from env.observation_space.
    state_dim=8,
    action_dim=env.action_space.n,
    save_dir=save_dir,
    checkpoint=checkpoint,
    **hyperparams
)

# Alternative: plain Dueling DDQN agent without step decay.
# print("Training Dueling DDQN Agent!")
# agent = DuelingDDQNAgent(
#     state_dim=8,
#     action_dim=env.action_space.n,
#     save_dir=save_dir,
#     checkpoint=checkpoint,
#     **hyperparams
# )

# Optional: call train.fill_memory before training (its exact effect is
# defined in train.py).
# fill_memory(agent, env, 5000)

train(agent, env, logger)
|