File size: 10,369 Bytes
7bfbe05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 |
import gym
import numpy as np
import os
from dataclasses import asdict, astuple
from gym.vector.async_vector_env import AsyncVectorEnv
from gym.vector.sync_vector_env import SyncVectorEnv
from gym.wrappers.resize_observation import ResizeObservation
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.frame_stack import FrameStack
from stable_baselines3.common.atari_wrappers import (
MaxAndSkipEnv,
NoopResetEnv,
)
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Callable, Optional
from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.shared.policy.policy import VEC_NORMALIZE_FILENAME
from rl_algo_impls.wrappers.atari_wrappers import (
EpisodicLifeEnv,
FireOnLifeStarttEnv,
ClipRewardEnv,
)
from rl_algo_impls.wrappers.episode_record_video import EpisodeRecordVideo
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.initial_step_truncate_wrapper import (
InitialStepTruncateWrapper,
)
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.no_reward_timeout import NoRewardTimeout
from rl_algo_impls.wrappers.noop_env_seed import NoopEnvSeed
from rl_algo_impls.wrappers.normalize import NormalizeObservation, NormalizeReward
from rl_algo_impls.wrappers.sync_vector_env_render_compat import (
SyncVectorEnvRenderCompat,
)
from rl_algo_impls.wrappers.transpose_image_observation import TransposeImageObservation
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
from rl_algo_impls.wrappers.video_compat_wrapper import VideoCompatWrapper
def make_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build the vectorized environment described by ``hparams``.

    Dispatches on ``hparams.env_type``: ``"procgen"`` uses the native
    procgen/gym3 builder, while ``"sb3vec"`` and ``"gymvec"`` use the generic
    ``gym.make``-based builder. All other arguments are forwarded unchanged.

    Raises:
        ValueError: if ``hparams.env_type`` is not a supported value.
    """
    env_type = hparams.env_type
    if env_type == "procgen":
        builder = _make_procgen_env
    elif env_type in {"sb3vec", "gymvec"}:
        builder = _make_vec_env
    else:
        raise ValueError(f"env_type {env_type} not supported")
    return builder(
        config,
        hparams,
        training=training,
        render=render,
        normalize_load_path=normalize_load_path,
        tb_writer=tb_writer,
    )
def make_eval_env(
    config: Config,
    hparams: EnvHyperparams,
    override_n_envs: Optional[int] = None,
    **kwargs,
) -> VecEnv:
    """Build an evaluation (non-training) environment via ``make_env``.

    ``training`` is always forced to ``False``, overriding any value passed in
    ``kwargs``. When ``override_n_envs`` is given, a rebuilt copy of ``hparams``
    with that ``n_envs`` is used instead; a value of exactly 1 additionally
    switches ``vec_env_class`` to ``"sync"``. The caller's ``kwargs`` and
    ``hparams`` objects are not mutated.
    """
    forwarded = {**kwargs, "training": False}
    if override_n_envs is not None:
        overrides = {"n_envs": override_n_envs}
        if override_n_envs == 1:
            # One env only: use the in-process ("sync") vectorizer.
            overrides["vec_env_class"] = "sync"
        hparams = EnvHyperparams(**{**asdict(hparams), **overrides})
    return make_env(config, hparams, **forwarded)
def _make_vec_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a vectorized env made of wrapped ``gym.make`` instances.

    Each sub-environment gets episode statistics and video-compat wrappers.
    It may also get training video recording (env 0 only), a staggered
    initial-step truncation, env-family preprocessing (Atari, CarRacing,
    procgen), a no-reward timeout and per-index seeding.

    The sub-envs are vectorized with the SB3 (``"sb3vec"``) or gym
    (``"gymvec"``) implementation chosen by ``vec_env_class`` (``"sync"`` or
    ``"async"``). During training the result is wrapped with episode-stat
    logging, and optionally with observation/reward normalization.

    Args:
        config: Run configuration; supplies ``env_id``, ``video_prefix`` and
            the seed.
        hparams: Environment hyperparameters (unpacked positionally below).
        training: Whether the env is used for training. Gates video
            recording, initial truncation, stats logging and whether
            normalization statistics are updated.
        render: Whether to enable rendering for env families that support it.
        normalize_load_path: Directory containing saved ``VecNormalize``
            statistics (only consulted for ``"sb3vec"``).
        tb_writer: TensorBoard writer; must be provided when ``training``.

    Raises:
        ValueError: if ``env_type`` is neither ``"sb3vec"`` nor ``"gymvec"``.
        KeyError: if ``vec_env_class`` is neither ``"sync"`` nor ``"async"``.
        AssertionError: if ``training`` is set but ``tb_writer`` is missing.
    """
    (
        env_type,
        n_envs,
        frame_stack,
        make_kwargs,
        no_reward_timeout_steps,
        no_reward_fire_steps,
        vec_env_class,
        normalize,
        normalize_kwargs,
        rolling_length,
        train_record_video,
        video_step_interval,
        initial_steps_to_truncate,
        clip_atari_rewards,
    ) = astuple(hparams)

    if "BulletEnv" in config.env_id:
        # Imported only for its side effects (the name is never used);
        # presumably this registers the Bullet env ids before gym.spec below.
        import pybullet_envs  # noqa: F401

    spec = gym.spec(config.env_id)
    seed = config.seed(training=training)

    # Work on a copy; env-family-specific kwargs are added below.
    make_kwargs = make_kwargs.copy() if make_kwargs is not None else {}
    if "BulletEnv" in config.env_id and render:
        make_kwargs["render"] = True
    if "CarRacing" in config.env_id:
        make_kwargs["verbose"] = 0
    if "procgen" in config.env_id:
        if not render:
            make_kwargs["render_mode"] = "rgb_array"

    def make(idx: int) -> Callable[[], gym.Env]:
        """Return a thunk that builds and wraps sub-environment ``idx``."""

        def _make() -> gym.Env:
            env = gym.make(config.env_id, **make_kwargs)
            env = gym.wrappers.RecordEpisodeStatistics(env)
            env = VideoCompatWrapper(env)
            # Only the first sub-env records training video.
            if training and train_record_video and idx == 0:
                env = EpisodeRecordVideo(
                    env,
                    config.video_prefix,
                    step_increment=n_envs,
                    video_step_interval=int(video_step_interval),
                )
            if training and initial_steps_to_truncate:
                # Each sub-env gets a different step count (idx / n_envs of the
                # total), presumably to desynchronize episode boundaries.
                env = InitialStepTruncateWrapper(
                    env, idx * initial_steps_to_truncate // n_envs
                )
            if "AtariEnv" in spec.entry_point:  # type: ignore
                env = NoopResetEnv(env, noop_max=30)
                env = MaxAndSkipEnv(env, skip=4)
                env = EpisodicLifeEnv(env, training=training)
                action_meanings = env.unwrapped.get_action_meanings()
                if "FIRE" in action_meanings:  # type: ignore
                    env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
                if clip_atari_rewards:
                    env = ClipRewardEnv(env, training=training)
                env = ResizeObservation(env, (84, 84))
                env = GrayScaleObservation(env, keep_dim=False)
                env = FrameStack(env, frame_stack)
            elif "CarRacing" in config.env_id:
                env = ResizeObservation(env, (64, 64))
                env = GrayScaleObservation(env, keep_dim=False)
                env = FrameStack(env, frame_stack)
            elif "procgen" in config.env_id:
                # env = GrayScaleObservation(env, keep_dim=False)
                env = NoopEnvSeed(env)
                env = TransposeImageObservation(env)
                if frame_stack > 1:
                    env = FrameStack(env, frame_stack)

            if no_reward_timeout_steps:
                env = NoRewardTimeout(
                    env, no_reward_timeout_steps, n_fire_steps=no_reward_fire_steps
                )
            if seed is not None:
                # Offset by idx so every sub-env gets a distinct seed.
                env.seed(seed + idx)
                env.action_space.seed(seed + idx)
                env.observation_space.seed(seed + idx)
            return env

        return _make

    if env_type == "sb3vec":
        VecEnvClass = {"sync": DummyVecEnv, "async": SubprocVecEnv}[vec_env_class]
    elif env_type == "gymvec":
        VecEnvClass = {"sync": SyncVectorEnv, "async": AsyncVectorEnv}[vec_env_class]
    else:
        raise ValueError(f"env_type {env_type} unsupported")
    envs = VecEnvClass([make(i) for i in range(n_envs)])
    if env_type == "gymvec" and vec_env_class == "sync":
        envs = SyncVectorEnvRenderCompat(envs)

    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs, tb_writer, training=training, rolling_length=rolling_length
        )

    if normalize:
        normalize_kwargs = normalize_kwargs or {}
        if env_type == "sb3vec":
            if normalize_load_path:
                envs = VecNormalize.load(
                    os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME),
                    envs,  # type: ignore
                )
                # VecNormalize.load restores the ``training`` flag the stats
                # were saved with (normally True). Re-apply the caller's intent
                # so an evaluation env does not keep updating the loaded
                # running statistics.
                envs.training = training
            else:
                envs = VecNormalize(
                    envs,  # type: ignore
                    training=training,
                    **normalize_kwargs,
                )
            if not training:
                # Outside training, report unnormalized rewards.
                envs.norm_reward = False
        else:
            if normalize_kwargs.get("norm_obs", True):
                envs = NormalizeObservation(
                    envs, training=training, clip=normalize_kwargs.get("clip_obs", 10.0)
                )
            if training and normalize_kwargs.get("norm_reward", True):
                envs = NormalizeReward(
                    envs,
                    training=training,
                    clip=normalize_kwargs.get("clip_reward", 10.0),
                )
    return envs
def _make_procgen_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a procgen env from procgen's native batched gym3 environment.

    The gym3 env's ``"rgb"`` observation is extracted and optionally shown in a
    viewer. The env is then converted to a baselines-style vector env, image
    observations are transposed and episode statistics recorded. During
    training, episode stats are logged and, if enabled, rewards are normalized
    and clipped.

    Frame stacking, grayscale, no-reward timeouts and the other per-env options
    of ``_make_vec_env`` are not applied here. ``normalize_load_path`` is
    accepted for signature parity but unused.

    Args:
        config: Run configuration; supplies ``env_id`` and the seed.
        hparams: Environment hyperparameters (unpacked positionally below).
        training: Whether the env is used for training.
        render: Whether to attach a ``ViewerWrapper``.
        normalize_load_path: Unused.
        tb_writer: TensorBoard writer; must be provided when ``training``.

    Raises:
        AssertionError: if ``training`` is set but ``tb_writer`` is missing.
    """
    from gym3 import ViewerWrapper, ExtractDictObWrapper
    from procgen.env import ProcgenGym3Env, ToBaselinesVecEnv

    (
        _,  # env_type
        n_envs,
        _,  # frame_stack
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        normalize,
        normalize_kwargs,
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
    ) = astuple(hparams)

    seed = config.seed(training=training)

    # Copy before adding keys, matching _make_vec_env.
    make_kwargs = dict(make_kwargs or {})
    make_kwargs["render_mode"] = "rgb_array"
    if seed is not None:
        make_kwargs["rand_seed"] = seed

    envs = ProcgenGym3Env(n_envs, config.env_id, **make_kwargs)
    envs = ExtractDictObWrapper(envs, key="rgb")
    if render:
        envs = ViewerWrapper(envs, info_key="rgb")
    envs = ToBaselinesVecEnv(envs)
    envs = IsVectorEnv(envs)
    # TODO: Handle Grayscale and/or FrameStack
    envs = TransposeImageObservation(envs)
    envs = gym.wrappers.RecordEpisodeStatistics(envs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs, tb_writer, training=training, rolling_length=rolling_length
        )

    if normalize and training:
        normalize_kwargs = normalize_kwargs or {}
        # Honor norm_reward=False, as the gymvec path of _make_vec_env does.
        if normalize_kwargs.get("norm_reward", True):
            envs = gym.wrappers.NormalizeReward(envs)
            clip_reward = normalize_kwargs.get("clip_reward", 10.0)
            envs = gym.wrappers.TransformReward(
                envs, lambda r: np.clip(r, -clip_reward, clip_reward)
            )
    return envs  # type: ignore
|