# File size: 9,861 Bytes | 9b19c29
import argparse
import datetime
import os
import pprint
import sys
import numpy as np
import torch
from atari_network import Rainbow
from atari_wrapper import make_atari_env
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import C51Policy, RainbowPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.0)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0000625)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--num-atoms", type=int, default=51)
parser.add_argument("--v-min", type=float, default=-10.0)
parser.add_argument("--v-max", type=float, default=10.0)
parser.add_argument("--noisy-std", type=float, default=0.1)
parser.add_argument("--no-dueling", action="store_true", default=False)
parser.add_argument("--no-noisy", action="store_true", default=False)
parser.add_argument("--no-priority", action="store_true", default=False)
parser.add_argument("--alpha", type=float, default=0.5)
parser.add_argument("--beta", type=float, default=0.4)
parser.add_argument("--beta-final", type=float, default=1.0)
parser.add_argument("--beta-anneal-step", type=int, default=5000000)
parser.add_argument("--no-weight-norm", action="store_true", default=False)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.0)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False,
action="store_true",
help="watch the play of pre-trained policy only",
)
parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args()
def test_rainbow(args: argparse.Namespace | None = None) -> None:
    """Train a Rainbow agent on an Atari task, or only watch an existing one.

    The function builds the environments, the Rainbow network and policy, the
    (optionally prioritized) replay buffer, the collectors and the logger,
    then runs an ``OffpolicyTrainer`` and finally evaluates the policy via
    ``watch``. With ``args.watch`` set, it only runs ``watch`` and then exits
    the process with status 0.

    Args:
        args: Options as produced by ``get_args``. When omitted, the command
            line is parsed at call time. The namespace is extended in place
            with ``state_shape``, ``action_shape`` and ``algo_name``.
    """
    # Resolve the default lazily. A ``get_args()`` default expression is evaluated once,
    # when the ``def`` statement runs, so importing this module would already parse the
    # command line of whatever process imports it.
    if args is None:
        args = get_args()

    env, train_envs, test_envs = make_atari_env(
        args.task,
        args.seed,
        args.training_num,
        args.test_num,
        scale=args.scale_obs,
        frame_stack=args.frames_stack,
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # define model
    net = Rainbow(
        *args.state_shape,
        args.action_shape,
        args.num_atoms,
        args.noisy_std,
        args.device,
        is_dueling=not args.no_dueling,
        is_noisy=not args.no_noisy,
    )
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)

    # define policy
    policy: C51Policy = RainbowPolicy(
        model=net,
        optim=optim,
        discount_factor=args.gamma,
        action_space=env.action_space,
        num_atoms=args.num_atoms,
        v_min=args.v_min,
        v_max=args.v_max,
        estimation_step=args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # replay buffer: `save_only_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
    if args.no_priority:
        buffer = VectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            ignore_obs_next=True,
            save_only_last_obs=True,
            stack_num=args.frames_stack,
        )
    else:
        buffer = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            ignore_obs_next=True,
            save_only_last_obs=True,
            stack_num=args.frames_stack,
            alpha=args.alpha,
            beta=args.beta,
            weight_norm=not args.no_weight_norm,
        )

    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log: <logdir>/<task>/rainbow/<seed>/<timestamp>
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "rainbow"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    logger_factory = LoggerFactoryDefault()
    if args.logger == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = args.wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=args.resume_id,
        config_dict=vars(args),
    )

    def save_best_fn(policy: BasePolicy) -> None:
        """Write the policy's state dict to ``policy.pth`` inside the log directory."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        """Return whether training may stop at the given mean reward."""
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        # no threshold in the env spec: fall back to a fixed target for Pong tasks
        if "Pong" in args.task:
            return mean_rewards >= 20
        return False

    def train_fn(epoch: int, env_step: int) -> None:
        """Set the exploration epsilon and, with prioritized replay, the buffer's beta."""
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})
        if not args.no_priority:
            # Linear interpolation from beta to beta_final over beta_anneal_step steps.
            # A non-positive horizon means "no annealing"; checking it first also avoids
            # the 0 / 0 division that `--beta-anneal-step 0` would trigger at env_step == 0.
            if args.beta_anneal_step > 0 and env_step <= args.beta_anneal_step:
                beta = args.beta - env_step / args.beta_anneal_step * (args.beta - args.beta_final)
            else:
                beta = args.beta_final
            buffer.set_beta(beta)
            if env_step % 1000 == 0:
                logger.write("train/env_step", env_step, {"train/beta": beta})

    def test_fn(epoch: int, env_step: int | None) -> None:
        """Switch the policy to the test-time epsilon."""
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch() -> None:
        """Run the policy on the test envs and print the collection statistics.

        With ``--save-buffer-name`` set, ``args.buffer_size`` steps are collected
        into a fresh prioritized buffer which is then saved as HDF5; otherwise
        ``args.test_num`` episodes are collected with the test collector.
        """
        print("Setup test envs ...")
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            # local buffer, deliberately separate from the training buffer above
            buffer = PrioritizedVectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack,
                alpha=args.alpha,
                beta=args.beta,
            )
            collector = Collector(policy, test_envs, buffer, exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num, render=args.render)
        result.pprint_asdict()

    if args.watch:
        watch()
        sys.exit(0)

    # test train_collector and start filling replay buffer
    train_collector.reset()
    train_collector.collect(n_step=args.batch_size * args.training_num)

    # trainer
    result = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        step_per_collect=args.step_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        update_per_step=args.update_per_step,
        test_in_train=False,
    ).run()

    pprint.pprint(result)
    watch()
if __name__ == "__main__":
    # Run as a script: parse the command line first, then hand the options to the experiment.
    cli_args = get_args()
    test_rainbow(cli_args)
|