from typing import Optional, Callable

import gym
from gym.spaces import Box
import numpy as np
import dmc2gym

from ding.envs import BaseEnv, BaseEnvTimestep
from ding.envs import WarpFrameWrapper, ScaledFloatFrameWrapper, ClipRewardWrapper, ActionRepeatWrapper, \
    FrameStackWrapper
from ding.envs.common.common_function import affine_transform
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY

def dmc2gym_observation_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Callable:
    # Returns a closure so the observation space can be resolved lazily, once the
    # env config (from_pixels, height, width, channels_first) is known.
    def observation_space(from_pixels=True, height=84, width=84, channels_first=True) -> Box:
        if from_pixels:
            shape = [3, height, width] if channels_first else [height, width, 3]
            return Box(low=0, high=255, shape=shape, dtype=np.uint8)
        else:
            return Box(np.repeat(minimum, dim).astype(dtype), np.repeat(maximum, dim).astype(dtype), dtype=dtype)

    return observation_space


def dmc2gym_state_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Box:
    return Box(np.repeat(minimum, dim).astype(dtype), np.repeat(maximum, dim).astype(dtype), dtype=dtype)


def dmc2gym_action_space(dim, minimum=-1, maximum=1, dtype=np.float32) -> Box:
    return Box(np.repeat(minimum, dim).astype(dtype), np.repeat(maximum, dim).astype(dtype), dtype=dtype)


def dmc2gym_reward_space(minimum=0, maximum=1, dtype=np.float32) -> Callable:
    # The returned closure scales the per-step reward bound by frame_skip, since
    # each env step accumulates the reward of frame_skip simulator steps.
    def reward_space(frame_skip=1) -> Box:
        return Box(
            np.repeat(minimum * frame_skip, 1).astype(dtype),
            np.repeat(maximum * frame_skip, 1).astype(dtype),
            dtype=dtype
        )

    return reward_space
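
# A minimal usage sketch of the factory closures above (the shapes in the
# comments follow from the defaults in this file, not from external docs):
#
#   obs_space_fn = dmc2gym_observation_space(5)
#   pixel_space = obs_space_fn(from_pixels=True)    # Box(0, 255, (3, 84, 84), uint8)
#   state_space = obs_space_fn(from_pixels=False)   # Box(-inf, inf, (5,), float32)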
""" | |
default observation, state, action, reward space for dmc2gym env | |
""" | |
dmc2gym_env_info = { | |
"ball_in_cup": { | |
"catch": { | |
"observation_space": dmc2gym_observation_space(8), | |
"state_space": dmc2gym_state_space(8), | |
"action_space": dmc2gym_action_space(2), | |
"reward_space": dmc2gym_reward_space() | |
} | |
}, | |
"cartpole": { | |
"balance": { | |
"observation_space": dmc2gym_observation_space(5), | |
"state_space": dmc2gym_state_space(5), | |
"action_space": dmc2gym_action_space(1), | |
"reward_space": dmc2gym_reward_space() | |
}, | |
"swingup": { | |
"observation_space": dmc2gym_observation_space(5), | |
"state_space": dmc2gym_state_space(5), | |
"action_space": dmc2gym_action_space(1), | |
"reward_space": dmc2gym_reward_space() | |
} | |
}, | |
"cheetah": { | |
"run": { | |
"observation_space": dmc2gym_observation_space(17), | |
"state_space": dmc2gym_state_space(17), | |
"action_space": dmc2gym_action_space(6), | |
"reward_space": dmc2gym_reward_space() | |
} | |
}, | |
"finger": { | |
"spin": { | |
"observation_space": dmc2gym_observation_space(9), | |
"state_space": dmc2gym_state_space(9), | |
"action_space": dmc2gym_action_space(1), | |
"reward_space": dmc2gym_reward_space() | |
} | |
}, | |
"reacher": { | |
"easy": { | |
"observation_space": dmc2gym_observation_space(6), | |
"state_space": dmc2gym_state_space(6), | |
"action_space": dmc2gym_action_space(2), | |
"reward_space": dmc2gym_reward_space() | |
} | |
}, | |
"walker": { | |
"walk": { | |
"observation_space": dmc2gym_observation_space(24), | |
"state_space": dmc2gym_state_space(24), | |
"action_space": dmc2gym_action_space(6), | |
"reward_space": dmc2gym_reward_space() | |
} | |
} | |
} | |
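
# Sketch of how the table above is consumed (the "cartpole"/"balance" pair is one
# of the entries registered above): observation/reward entries are factories that
# still need env config, while state/action entries are already concrete Boxes.
#
#   info = dmc2gym_env_info["cartpole"]["balance"]
#   obs_space = info["observation_space"](from_pixels=False)   # Box(-inf, inf, (5,), float32)
#   act_space = info["action_space"]                           # Box(-1.0, 1.0, (1,), float32)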

@ENV_REGISTRY.register('dmc2gym')
class DMC2GymEnv(BaseEnv):

    def __init__(self, cfg: dict = {}) -> None:
        assert cfg.domain_name in dmc2gym_env_info, '{}/{}'.format(cfg.domain_name, dmc2gym_env_info.keys())
        assert cfg.task_name in dmc2gym_env_info[
            cfg.domain_name], '{}/{}'.format(cfg.task_name, dmc2gym_env_info[cfg.domain_name].keys())

        # default config for dmc2gym env, overridden by user-provided cfg below
        self._cfg = {
            "frame_skip": 4,
            "warp_frame": False,
            "scale": False,
            "clip_rewards": False,
            "action_repeat": 1,
            "frame_stack": 3,
            "from_pixels": True,
            "visualize_reward": False,
            "height": 84,
            "width": 84,
            "channels_first": True,
            "resize": 84,
        }
        self._cfg.update(cfg)

        self._init_flag = False
        self._replay_path = None

        self._observation_space = dmc2gym_env_info[cfg.domain_name][cfg.task_name]["observation_space"](
            from_pixels=self._cfg["from_pixels"],
            height=self._cfg["height"],
            width=self._cfg["width"],
            channels_first=self._cfg["channels_first"]
        )
        self._action_space = dmc2gym_env_info[cfg.domain_name][cfg.task_name]["action_space"]
        self._reward_space = dmc2gym_env_info[cfg.domain_name][cfg.task_name]["reward_space"](self._cfg["frame_skip"])

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            self._env = dmc2gym.make(
                domain_name=self._cfg["domain_name"],
                task_name=self._cfg["task_name"],
                seed=1,  # placeholder seed; the actual seed is applied below via self._env.seed
                visualize_reward=self._cfg["visualize_reward"],
                from_pixels=self._cfg["from_pixels"],
                height=self._cfg["height"],
                width=self._cfg["width"],
                frame_skip=self._cfg["frame_skip"],
                channels_first=self._cfg["channels_first"],
            )
            # optional env wrappers
            if self._cfg["warp_frame"]:
                self._env = WarpFrameWrapper(self._env, size=self._cfg["resize"])
            if self._cfg["scale"]:
                self._env = ScaledFloatFrameWrapper(self._env)
            if self._cfg["clip_rewards"]:
                self._env = ClipRewardWrapper(self._env)
            if self._cfg["action_repeat"]:
                self._env = ActionRepeatWrapper(self._env, self._cfg["action_repeat"])
            if self._cfg["frame_stack"] > 1:
                self._env = FrameStackWrapper(self._env, self._cfg["frame_stack"])
            # sync the obs/action space with the wrapped env
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            if self._replay_path is not None:
                if gym.version.VERSION > '0.22.0':
                    self._env.metadata.update({'render_modes': ["rgb_array"]})
                else:
                    self._env.metadata.update({'render.modes': ["rgb_array"]})
                self._env = gym.wrappers.RecordVideo(
                    self._env,
                    video_folder=self._replay_path,
                    episode_trigger=lambda episode_id: True,
                    name_prefix='rl-video-{}'.format(id(self))
                )
                self._env.start_video_recorder()
            self._init_flag = True
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._env.seed(self._seed + np_seed)
        elif hasattr(self, '_seed'):
            self._env.seed(self._seed)
        self._eval_episode_return = 0
        obs = self._env.reset()
        obs = to_ndarray(obs).astype(np.float32)
        return obs

    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action: np.ndarray) -> BaseEnvTimestep:
        action = action.astype('float32')
        # map the normalized action into the env's action range
        action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
        obs, rew, done, info = self._env.step(action)
        self._eval_episode_return += rew
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        obs = to_ndarray(obs).astype(np.float32)
        rew = to_ndarray([rew]).astype(np.float32)  # wrapped into an array with shape (1,)
        return BaseEnvTimestep(obs, rew, done, info)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path

    def random_action(self) -> np.ndarray:
        random_action = self.action_space.sample().astype(np.float32)
        return random_action

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "DI-engine DeepMind Control Suite to gym Env: " + self._cfg["domain_name"] + ":" + self._cfg["task_name"]