from typing import List, Optional, Union, Dict
from easydict import EasyDict
import gym
import gymnasium
import copy
import numpy as np
import treetensor.numpy as tnp
from ding.envs.common.common_function import affine_transform
from ding.envs.env_wrappers import create_env_wrapper
from ding.torch_utils import to_ndarray
from ding.utils import CloudPickleWrapper
from .base_env import BaseEnv, BaseEnvTimestep
from .default_wrapper import get_default_wrappers


class DingEnvWrapper(BaseEnv):
    """
    Overview:
        This is a wrapper for the BaseEnv class, used to provide a consistent environment interface.
    Interfaces:
        __init__, reset, step, close, seed, random_action, _wrap_env, __repr__, create_collector_env_cfg,
        create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
    """

    def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
        """
        Overview:
            Initialize the DingEnvWrapper. Either an environment instance or a config used to create the \
            environment instance should be passed in. In the former case, the ``env`` argument must be the \
            environment instance itself; this mode does not support the subprocess environment manager and is \
            therefore usually used for simple environments. In the latter case, the ``cfg`` argument must \
            contain ``env_id``.
        Arguments:
            - env (:obj:`gym.Env`): An environment instance to be wrapped.
            - cfg (:obj:`dict`): The configuration dictionary used to create an environment instance.
            - seed_api (:obj:`bool`): Whether to use the seed API. Defaults to True.
            - caller (:obj:`str`): A string representing the caller of this method, either ``collector`` or \
              ``evaluator``. Different callers may need different wrappers. Defaults to 'collector'.
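        Examples (illustrative sketch; ``CartPole-v0`` is only an assumed, commonly available env id):
            >>> # 1. Wrap an existing environment instance.
            >>> env = DingEnvWrapper(gym.make('CartPole-v0'))
            >>> # 2. Create the environment lazily from a config; it is built on the first reset().
            >>> env = DingEnvWrapper(cfg={'env_id': 'CartPole-v0'})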
""" | |
self._env = None | |
self._raw_env = env | |
self._cfg = cfg | |
self._seed_api = seed_api # some env may disable `env.seed` api | |
self._caller = caller | |
if self._cfg is None: | |
self._cfg = {} | |
self._cfg = EasyDict(self._cfg) | |
if 'act_scale' not in self._cfg: | |
self._cfg.act_scale = False | |
if 'rew_clip' not in self._cfg: | |
self._cfg.rew_clip = False | |
if 'env_wrapper' not in self._cfg: | |
self._cfg.env_wrapper = 'default' | |
if 'env_id' not in self._cfg: | |
self._cfg.env_id = None | |
if env is not None: | |
self._env = env | |
self._wrap_env(caller) | |
self._observation_space = self._env.observation_space | |
self._action_space = self._env.action_space | |
self._action_space.seed(0) # default seed | |
self._reward_space = gym.spaces.Box( | |
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 | |
) | |
self._init_flag = True | |
else: | |
assert 'env_id' in self._cfg | |
self._init_flag = False | |
self._observation_space = None | |
self._action_space = None | |
self._reward_space = None | |
        # Only if the user specifies replay_path will the video be saved, so its initial value is None.
        self._replay_path = None

    # override
    def reset(self) -> np.ndarray:
        """
        Overview:
            Resets the state of the environment. If the environment is not initialized, it will be created first.
        Returns:
            - obs (:obj:`np.ndarray`): The new observation after reset.
        """
        if not self._init_flag:
            self._env = gym.make(self._cfg.env_id)
            self._wrap_env(self._caller)
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            self._reward_space = gym.spaces.Box(
                low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
            )
            self._init_flag = True
        if self._replay_path is not None:
            self._env = gym.wrappers.RecordVideo(
                self._env,
                video_folder=self._replay_path,
                episode_trigger=lambda episode_id: True,
                name_prefix='rl-video-{}'.format(id(self))
            )
            self._replay_path = None
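        # Seeding note: when both `_seed` and `_dynamic_seed` are set (see `seed`), each reset derives a new
        # seed as `self._seed + 100 * np.random.randint(1, 1000)`, so that repeated resets (e.g. in parallel
        # collectors) produce different episodes; otherwise the fixed base seed is reused for reproducibility.
        # Legacy `gym` envs are seeded via `env.seed()` (if `seed_api` is enabled), while `gymnasium` envs
        # receive the seed through `env.reset(seed=...)`.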
        if isinstance(self._env, gym.Env):
            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
                np_seed = 100 * np.random.randint(1, 1000)
                if self._seed_api:
                    self._env.seed(self._seed + np_seed)
                self._action_space.seed(self._seed + np_seed)
            elif hasattr(self, '_seed'):
                if self._seed_api:
                    self._env.seed(self._seed)
                self._action_space.seed(self._seed)
            obs = self._env.reset()
        elif isinstance(self._env, gymnasium.Env):
            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
                np_seed = 100 * np.random.randint(1, 1000)
                self._action_space.seed(self._seed + np_seed)
                obs = self._env.reset(seed=self._seed + np_seed)
            elif hasattr(self, '_seed'):
                self._action_space.seed(self._seed)
                obs = self._env.reset(seed=self._seed)
            else:
                obs = self._env.reset()
        else:
            raise RuntimeError("Unsupported env type: {}".format(type(self._env)))
        if self.observation_space.dtype == np.float32:
            obs = to_ndarray(obs, dtype=np.float32)
        else:
            obs = to_ndarray(obs)
        return obs

    # override
    def close(self) -> None:
        """
        Overview:
            Clean up the environment by closing and deleting it.
            This method should be called when the environment is no longer needed.
            Failing to call this method can lead to memory leaks.
        """
        try:
            self._env.close()
            del self._env
        except:  # noqa
            pass

    # override
    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        """
        Overview:
            Set the seed for the environment.
        Arguments:
            - seed (:obj:`int`): The seed to set.
            - dynamic_seed (:obj:`bool`): Whether to use dynamic seed, default is True.
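        Note:
            When dynamic_seed is True, each call to ``reset`` adds a random offset to the base seed, so \
            successive episodes differ; set it to False (typically for evaluator environments) to reproduce \
            the same episode on every reset.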
""" | |
self._seed = seed | |
self._dynamic_seed = dynamic_seed | |
np.random.seed(self._seed) | |

    # override
    def step(self, action: Union[np.int64, np.ndarray]) -> BaseEnvTimestep:
        """
        Overview:
            Execute the given action in the environment, and return the timestep (observation, reward, done, info).
        Arguments:
            - action (:obj:`Union[np.int64, np.ndarray]`): The action to execute in the environment.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): The timestep after the action execution.
        """
        action = self._judge_action_type(action)
        if self._cfg.act_scale:
            action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
        obs, rew, done, info = self._env.step(action)
        if self._cfg.rew_clip:
            rew = max(-10, rew)
        rew = np.float32(rew)
        if self.observation_space.dtype == np.float32:
            obs = to_ndarray(obs, dtype=np.float32)
        else:
            obs = to_ndarray(obs)
        rew = to_ndarray([rew], np.float32)
        return BaseEnvTimestep(obs, rew, done, info)

    def _judge_action_type(self, action: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
        """
        Overview:
            Ensure the action taken by the agent is of the correct type.
            This method is used to standardize different action types to a common format.
        Arguments:
            - action (:obj:`Union[np.ndarray, dict]`): The action taken by the agent.
        Returns:
            - action (:obj:`Union[np.ndarray, dict]`): The formatted action.
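        Examples (illustrative, based on the conversions below; ``env`` is a hypothetical DingEnvWrapper instance):
            >>> env._judge_action_type(np.array(3))                    # 0-dim array -> python scalar
            3
            >>> env._judge_action_type(np.array([2], dtype=np.int64))  # shape (1,) int64 -> python int
            2
            >>> env._judge_action_type({'move': np.int64(1)})          # dict values are converted recursively
            {'move': 1}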
""" | |
if isinstance(action, int): | |
return action | |
elif isinstance(action, np.int64): | |
return int(action) | |
elif isinstance(action, np.ndarray): | |
if action.shape == (): | |
action = action.item() | |
elif action.shape == (1, ) and action.dtype == np.int64: | |
action = action.item() | |
return action | |
elif isinstance(action, dict): | |
for k, v in action.items(): | |
action[k] = self._judge_action_type(v) | |
return action | |
elif isinstance(action, tnp.ndarray): | |
return self._judge_action_type(action.json()) | |
else: | |
raise TypeError( | |
'`action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format( | |
type(action), action | |
) | |
) | |

    def random_action(self) -> np.ndarray:
        """
        Overview:
            Return a random action from the action space of the environment.
        Returns:
            - action (:obj:`np.ndarray`): The random action.
        """
        random_action = self.action_space.sample()
        if isinstance(random_action, np.ndarray):
            pass
        elif isinstance(random_action, int):
            random_action = to_ndarray([random_action], dtype=np.int64)
        elif isinstance(random_action, dict):
            random_action = to_ndarray(random_action)
        else:
            raise TypeError(
                '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but got {}: {}'.format(
                    type(random_action), random_action
                )
            )
        return random_action

    def _wrap_env(self, caller: str = 'collector') -> None:
        """
        Overview:
            Wrap the environment according to the configuration.
        Arguments:
            - caller (:obj:`str`): The caller of the environment, either ``collector`` or ``evaluator``. \
              Different callers may need different wrappers. Default is 'collector'.
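        Examples (illustrative sketch; ``MyCustomWrapper`` is a hypothetical ``gym.Wrapper`` subclass and \
        ``CartPole-v0`` an assumed env id):
            >>> # `env_wrapper` may be the string 'default' (resolved via get_default_wrappers) or a list
            >>> # whose items are either wrapper config dicts (handled by create_env_wrapper) or callables.
            >>> cfg = dict(env_id='CartPole-v0', env_wrapper=[lambda env: MyCustomWrapper(env)])
            >>> env = DingEnvWrapper(cfg=cfg)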
""" | |
# wrapper_cfgs: Union[str, List] | |
wrapper_cfgs = self._cfg.env_wrapper | |
if isinstance(wrapper_cfgs, str): | |
wrapper_cfgs = get_default_wrappers(wrapper_cfgs, self._cfg.env_id, caller) | |
# self._wrapper_cfgs: List[Union[Callable, Dict]] | |
self._wrapper_cfgs = wrapper_cfgs | |
for wrapper in self._wrapper_cfgs: | |
# wrapper: Union[Callable, Dict] | |
if isinstance(wrapper, Dict): | |
self._env = create_env_wrapper(self._env, wrapper) | |
else: # Callable, such as lambda anonymous function | |
self._env = wrapper(self._env) | |

    def __repr__(self) -> str:
        """
        Overview:
            Return the string representation of the instance.
        Returns:
            - str (:obj:`str`): The string representation of the instance.
        """
        return "DI-engine Env({}), generated by DingEnvWrapper".format(self._cfg.env_id)

    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Create a list of environment configurations for collectors based on the input configuration.
        Arguments:
            - cfg (:obj:`dict`): The input configuration dictionary.
        Returns:
            - env_cfgs (:obj:`List[dict]`): The list of environment configurations for collectors.
        """
        actor_env_num = cfg.pop('collector_env_num')
        cfg = copy.deepcopy(cfg)
        cfg.is_train = True
        return [cfg for _ in range(actor_env_num)]

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Create a list of environment configurations for evaluators based on the input configuration.
        Arguments:
            - cfg (:obj:`dict`): The input configuration dictionary.
        Returns:
            - env_cfgs (:obj:`List[dict]`): The list of environment configurations for evaluators.
        """
        evaluator_env_num = cfg.pop('evaluator_env_num')
        cfg = copy.deepcopy(cfg)
        cfg.is_train = False
        return [cfg for _ in range(evaluator_env_num)]

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        """
        Overview:
            Enable the save replay functionality. The replay will be saved at the specified path.
        Arguments:
            - replay_path (:obj:`Optional[str]`): The path to save the replay, default is None.
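        Examples (illustrative; the directory name is arbitrary):
            >>> env.enable_save_replay('./video')
            >>> env.reset()  # the RecordVideo wrapper is attached on the next reset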
""" | |
if replay_path is None: | |
replay_path = './video' | |
self._replay_path = replay_path | |

    @property
    def observation_space(self) -> gym.spaces.Space:
        """
        Overview:
            Return the observation space of the wrapped environment.
            The observation space represents the range and shape of possible observations
            that the environment can provide to the agent.
        Note:
            If the data type of the observation space is float64, it is converted to float32
            for better compatibility with most machine learning libraries.
        Returns:
            - observation_space (:obj:`gym.spaces.Space`): The observation space of the environment.
        """
        if self._observation_space.dtype == np.float64:
            self._observation_space.dtype = np.float32
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """
        Overview:
            Return the action space of the wrapped environment.
            The action space represents the range and shape of possible actions
            that the agent can take in the environment.
        Returns:
            - action_space (:obj:`gym.spaces.Space`): The action space of the environment.
        """
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        """
        Overview:
            Return the reward space of the wrapped environment.
            The reward space represents the range and shape of possible rewards
            that the agent can receive as a result of its actions.
        Returns:
            - reward_space (:obj:`gym.spaces.Space`): The reward space of the environment.
        """
        return self._reward_space

    def clone(self, caller: str = 'collector') -> BaseEnv:
        """
        Overview:
            Clone the current environment wrapper, creating a new environment with the same settings.
        Arguments:
            - caller (:obj:`str`): A string representing the caller of this method, either ``collector`` or \
              ``evaluator``. Different callers may need different wrappers. Default is 'collector'.
        Returns:
            - env (:obj:`DingEnvWrapper`): A new instance of the environment with the same settings.
        """
        try:
            spec = copy.deepcopy(self._raw_env.spec)
            raw_env = CloudPickleWrapper(self._raw_env)
            raw_env = copy.deepcopy(raw_env).data
            raw_env.__setattr__('spec', spec)
        except Exception:
            raw_env = self._raw_env
        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)
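

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module). It assumes a standard Gym
    # environment id such as 'CartPole-v0' is installed, that the default wrappers returned by
    # get_default_wrappers are compatible with it, and that the old (4-tuple) gym step API applies,
    # as in `step` above.
    demo_env = DingEnvWrapper(cfg={'env_id': 'CartPole-v0'})
    demo_env.seed(0, dynamic_seed=False)
    obs = demo_env.reset()
    timestep = demo_env.step(demo_env.random_action())
    print(obs.shape, timestep.rew, timestep.done)
    demo_env.close()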