"""Gym environments whose observations are the values of ``AtariARIWrapper.labels()``."""

import gym
import ale_py
import numpy as np
from atariari.benchmark.wrapper import AtariARIWrapper
from typing import Optional, Union


class RepresentedAtariEnv(gym.Wrapper):
    """Wrap an Atari env so observations are the AtariARI label values.

    The wrapped environment is ``AtariARIWrapper(gym.make(env_name))``. After
    every ``reset``/``step`` the mapping returned by ``self.env.labels()`` is
    flattened (in the mapping's iteration order) into a 1-D ``float32`` array
    that matches ``observation_space``.

    Attributes set by this class:
        env_name: the gym id passed to ``gym.make``.
        obs_label: keys view of the most recent label mapping.
        observation: the most recent label mapping (``None`` until the first
            ``reset``/``step`` call made through this wrapper).
        info: an empty dict; never updated by this class.
    """

    def __init__(self, env_name, render_mode=None):
        """Build the wrapped env and derive the observation space.

        Args:
            env_name: gym id of the underlying Atari environment.
            render_mode: forwarded to ``gym.make``.
        """
        super().__init__(
            AtariARIWrapper(gym.make(env_name, render_mode=render_mode))
        )
        self.metadata = self.env.metadata
        self.env_name = env_name
        self.observation = None
        self.info = {}
        self.action_space = self.env.action_space

        # Reset once so that labels() is populated; the number of labels
        # fixes the length of the observation vector.
        _ = self.env.reset()
        labels = self.env.labels()
        self.obs_label = labels.keys()
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(len(labels),),
            dtype=np.float32,
        )

    def _labels_to_observation(self):
        """Read the current labels, cache them, and return them as an array.

        Returns:
            1-D ``np.float32`` array of the label values, so the dtype agrees
            with the declared ``observation_space``.
        """
        labels = self.env.labels()
        self.obs_label = labels.keys()
        self.observation = labels
        return np.array(list(labels.values()), dtype=np.float32)

    def step(self, action):
        """Step the wrapped env and return the label vector as observation.

        The wrapped env's own observation is discarded; reward, terminated,
        truncated and info are passed through unchanged.
        """
        _, reward, env_done, env_truncated, info = self.env.step(action)
        return self._labels_to_observation(), reward, env_done, env_truncated, info

    def reset(self, seed=0, options=None):
        """Reset the wrapped env and return ``(label_vector, info)``.

        Args:
            seed: forwarded to the wrapped env's ``reset``.
                NOTE(review): the default is 0 rather than None, so a bare
                ``reset()`` re-seeds with 0 every time. Kept for backward
                compatibility; confirm this is intended.
            options: optional reset options. Forwarded only when given, so
                the call made to the wrapped env is unchanged for existing
                callers.
        """
        reset_kwargs = {"seed": seed}
        if options is not None:
            reset_kwargs["options"] = options
        _, info = self.env.reset(**reset_kwargs)
        return self._labels_to_observation(), info

    def get_info(self):
        """Return the most recent label mapping (``self.observation``)."""
        return self.observation

    def render(self, render_mode=None):
        """Render via the wrapped env; ``render_mode`` is accepted but unused."""
        return self.env.render()


class RepresentedMsPacman(RepresentedAtariEnv):
    """Label-observation env for ``MsPacmanNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="MsPacmanNoFrameskip-v4", render_mode=render_mode)


class RepresentedBowling(RepresentedAtariEnv):
    """Label-observation env for ``BowlingNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="BowlingNoFrameskip-v4", render_mode=render_mode)


class RepresentedBoxing(RepresentedAtariEnv):
    """Label-observation env for ``BoxingNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="BoxingNoFrameskip-v4", render_mode=render_mode)


class RepresentedBreakout(RepresentedAtariEnv):
    """Label-observation env for ``BreakoutNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="BreakoutNoFrameskip-v4", render_mode=render_mode)


class RepresentedDemonAttack(RepresentedAtariEnv):
    """Label-observation env for ``DemonAttackNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="DemonAttackNoFrameskip-v4", render_mode=render_mode)


class RepresentedFreeway(RepresentedAtariEnv):
    """Label-observation env for ``FreewayNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="FreewayNoFrameskip-v4", render_mode=render_mode)


class RepresentedFrostbite(RepresentedAtariEnv):
    """Label-observation env for ``FrostbiteNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="FrostbiteNoFrameskip-v4", render_mode=render_mode)


class RepresentedHero(RepresentedAtariEnv):
    """Label-observation env for ``HeroNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="HeroNoFrameskip-v4", render_mode=render_mode)


class RepresentedMontezumaRevenge(RepresentedAtariEnv):
    """Label-observation env for ``MontezumaRevengeNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(
            env_name="MontezumaRevengeNoFrameskip-v4", render_mode=render_mode
        )


class RepresentedPitfall(RepresentedAtariEnv):
    """Label-observation env for ``PitfallNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="PitfallNoFrameskip-v4", render_mode=render_mode)


class RepresentedPong(RepresentedAtariEnv):
    """Label-observation env for ``PongNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="PongNoFrameskip-v4", render_mode=render_mode)


class RepresentedPrivateEye(RepresentedAtariEnv):
    """Label-observation env for ``PrivateEyeNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="PrivateEyeNoFrameskip-v4", render_mode=render_mode)


class RepresentedQbert(RepresentedAtariEnv):
    """Label-observation env for ``QbertNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="QbertNoFrameskip-v4", render_mode=render_mode)


class RepresentedRiverraid(RepresentedAtariEnv):
    """Label-observation env for ``RiverraidNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="RiverraidNoFrameskip-v4", render_mode=render_mode)


class RepresentedSeaquest(RepresentedAtariEnv):
    """Label-observation env for ``SeaquestNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="SeaquestNoFrameskip-v4", render_mode=render_mode)


class RepresentedSpaceInvaders(RepresentedAtariEnv):
    """Label-observation env for ``SpaceInvadersNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(
            env_name="SpaceInvadersNoFrameskip-v4", render_mode=render_mode
        )


class RepresentedTennis(RepresentedAtariEnv):
    """Label-observation env for ``TennisNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="TennisNoFrameskip-v4", render_mode=render_mode)


class RepresentedVenture(RepresentedAtariEnv):
    """Label-observation env for ``VentureNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(env_name="VentureNoFrameskip-v4", render_mode=render_mode)


class RepresentedVideoPinball(RepresentedAtariEnv):
    """Label-observation env for ``VideoPinballNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        super().__init__(
            env_name="VideoPinballNoFrameskip-v4", render_mode=render_mode
        )


def env_factory(env_class):
    """Return a callable that builds ``env_class`` with a given render mode.

    Args:
        env_class: a class whose constructor accepts ``render_mode``.

    Returns:
        A function ``(render_mode=None) -> env_class`` instance, used below as
        the ``entry_point`` for ``gym.register``.
    """

    def _create_instance(render_mode=None):
        return env_class(render_mode=render_mode)

    return _create_instance


def register_environments():
    """Register every ``Represented*`` env with gym under ``<ClassName>-v0``.

    This is not invoked at import time (the module-level call below is
    commented out); callers must invoke it explicitly.
    """
    env_classes = {
        'RepresentedMsPacman-v0': RepresentedMsPacman,
        'RepresentedBowling-v0': RepresentedBowling,
        'RepresentedBoxing-v0': RepresentedBoxing,
        'RepresentedBreakout-v0': RepresentedBreakout,
        'RepresentedDemonAttack-v0': RepresentedDemonAttack,
        'RepresentedFreeway-v0': RepresentedFreeway,
        'RepresentedFrostbite-v0': RepresentedFrostbite,
        'RepresentedHero-v0': RepresentedHero,
        'RepresentedMontezumaRevenge-v0': RepresentedMontezumaRevenge,
        'RepresentedPitfall-v0': RepresentedPitfall,
        'RepresentedPong-v0': RepresentedPong,
        'RepresentedPrivateEye-v0': RepresentedPrivateEye,
        'RepresentedQbert-v0': RepresentedQbert,
        'RepresentedRiverraid-v0': RepresentedRiverraid,
        'RepresentedSeaquest-v0': RepresentedSeaquest,
        'RepresentedSpaceInvaders-v0': RepresentedSpaceInvaders,
        'RepresentedTennis-v0': RepresentedTennis,
        'RepresentedVenture-v0': RepresentedVenture,
        'RepresentedVideoPinball-v0': RepresentedVideoPinball,
    }

    for env_name, env_class in env_classes.items():
        gym.register(
            id=env_name,
            entry_point=env_factory(env_class),
        )


# register_environments()

# Disabled sanity check: compares each wrapper's action space with that of
# the plain gym environment of the same name.
#
# env_classes = {
#     'RepresentedMsPacman-v0': RepresentedMsPacman,
#     'RepresentedBowling-v0': RepresentedBowling,
#     'RepresentedBoxing-v0': RepresentedBoxing,
#     'RepresentedBreakout-v0': RepresentedBreakout,
#     'RepresentedDemonAttack-v0': RepresentedDemonAttack,
#     'RepresentedFreeway-v0': RepresentedFreeway,
#     'RepresentedFrostbite-v0': RepresentedFrostbite,
#     'RepresentedHero-v0': RepresentedHero,
#     'RepresentedMontezumaRevenge-v0': RepresentedMontezumaRevenge,
#     'RepresentedPitfall-v0': RepresentedPitfall,
#     'RepresentedPong-v0': RepresentedPong,
#     'RepresentedPrivateEye-v0': RepresentedPrivateEye,
#     'RepresentedQbert-v0': RepresentedQbert,
#     'RepresentedRiverraid-v0': RepresentedRiverraid,
#     'RepresentedSeaquest-v0': RepresentedSeaquest,
#     'RepresentedSpaceInvaders-v0': RepresentedSpaceInvaders,
#     'RepresentedTennis-v0': RepresentedTennis,
#     'RepresentedVenture-v0': RepresentedVenture,
#     'RepresentedVideoPinball-v0': RepresentedVideoPinball
# }
#
# for env, env_class in env_classes.items():
#     env_1 = env_class()
#     env_name = env_1.env_name
#     env_2 = gym.make(env_name)
#     print(env_name, env_1.action_space == env_2.action_space, env_1.action_space)