Spaces:
Sleeping
Sleeping
from collections import defaultdict | |
import math | |
import queue | |
from time import sleep, time | |
import gym | |
from ding.framework import Supervisor | |
from typing import TYPE_CHECKING, Any, List, Union, Dict, Optional, Callable | |
from ding.framework.supervisor import ChildType, RecvPayload, SendPayload | |
from ding.utils import make_key_as_identifier | |
from ditk import logging | |
from ding.data import ShmBufferContainer | |
import enum | |
import treetensor.numpy as tnp | |
import numbers | |
if TYPE_CHECKING: | |
from gym.spaces import Space | |
class EnvState(enum.IntEnum): | |
""" | |
VOID -> RUN -> DONE | |
""" | |
VOID = 0 | |
INIT = 1 | |
RUN = 2 | |
RESET = 3 | |
DONE = 4 | |
ERROR = 5 | |
NEED_RESET = 6 | |
class EnvRetryType(str, enum.Enum): | |
RESET = "reset" | |
RENEW = "renew" | |
class EnvSupervisor(Supervisor): | |
""" | |
Manage multiple envs with supervisor. | |
New features (compared to env manager): | |
- Consistent interface in multi-process and multi-threaded mode. | |
- Add asynchronous features and recommend using asynchronous methods. | |
- Reset is performed after an error is encountered in the step method. | |
Breaking changes (compared to env manager): | |
- Without some states. | |
""" | |
def __init__( | |
self, | |
type_: ChildType = ChildType.PROCESS, | |
env_fn: List[Callable] = None, | |
retry_type: EnvRetryType = EnvRetryType.RESET, | |
max_try: Optional[int] = None, | |
max_retry: Optional[int] = None, | |
auto_reset: bool = True, | |
reset_timeout: Optional[int] = None, | |
step_timeout: Optional[int] = None, | |
retry_waiting_time: Optional[int] = None, | |
episode_num: int = float("inf"), | |
shared_memory: bool = True, | |
copy_on_get: bool = True, | |
**kwargs | |
) -> None: | |
""" | |
Overview: | |
Supervisor that manage a group of envs. | |
Arguments: | |
- type_ (:obj:`ChildType`): Type of child process. | |
- env_fn (:obj:`List[Callable]`): The function to create environment | |
- retry_type (:obj:`EnvRetryType`): Retry reset or renew env. | |
- max_try (:obj:`EasyDict`): Max try times for reset or step action. | |
- max_retry (:obj:`Optional[int]`): Alias of max_try. | |
- auto_reset (:obj:`bool`): Auto reset env if reach done. | |
- reset_timeout (:obj:`Optional[int]`): Timeout in seconds for reset. | |
- step_timeout (:obj:`Optional[int]`): Timeout in seconds for step. | |
- retry_waiting_time (:obj:`Optional[float]`): Wait time on each retry. | |
- shared_memory (:obj:`bool`): Use shared memory in multiprocessing. | |
- copy_on_get (:obj:`bool`): Use copy on get in multiprocessing. | |
""" | |
if kwargs: | |
logging.warning("Unknown parameters on env supervisor: {}".format(kwargs)) | |
super().__init__(type_=type_) | |
if type_ is not ChildType.PROCESS and (shared_memory or copy_on_get): | |
logging.warning("shared_memory and copy_on_get only works in process mode.") | |
self._shared_memory = type_ is ChildType.PROCESS and shared_memory | |
self._copy_on_get = type_ is ChildType.PROCESS and copy_on_get | |
self._env_fn = env_fn | |
self._create_env_ref() | |
self._obs_buffers = None | |
if env_fn: | |
if self._shared_memory: | |
obs_space = self._observation_space | |
if isinstance(obs_space, gym.spaces.Dict): | |
# For multi_agent case, such as multiagent_mujoco and petting_zoo mpe. | |
# Now only for the case that each agent in the team have the same obs structure | |
# and corresponding shape. | |
shape = {k: v.shape for k, v in obs_space.spaces.items()} | |
dtype = {k: v.dtype for k, v in obs_space.spaces.items()} | |
else: | |
shape = obs_space.shape | |
dtype = obs_space.dtype | |
self._obs_buffers = { | |
env_id: ShmBufferContainer(dtype, shape, copy_on_get=self._copy_on_get) | |
for env_id in range(len(self._env_fn)) | |
} | |
for env_init in env_fn: | |
self.register(env_init, shm_buffer=self._obs_buffers, shm_callback=self._shm_callback) | |
else: | |
for env_init in env_fn: | |
self.register(env_init) | |
self._retry_type = retry_type | |
self._auto_reset = auto_reset | |
if max_retry: | |
logging.warning("The `max_retry` is going to be deprecated, use `max_try` instead!") | |
self._max_try = max_try or max_retry or 1 | |
self._reset_timeout = reset_timeout | |
self._step_timeout = step_timeout | |
self._retry_waiting_time = retry_waiting_time | |
self._env_replay_path = None | |
self._episode_num = episode_num | |
self._init_states() | |
def _init_states(self): | |
self._env_seed = {} | |
self._env_dynamic_seed = None | |
self._env_replay_path = None | |
self._env_states = {} | |
self._reset_param = {} | |
self._ready_obs = {} | |
self._env_episode_count = {i: 0 for i in range(self.env_num)} | |
self._retry_times = defaultdict(lambda: 0) | |
self._last_called = defaultdict(lambda: {"step": math.inf, "reset": math.inf}) | |
def _shm_callback(self, payload: RecvPayload, obs_buffers: Any): | |
""" | |
Overview: | |
This method will be called in child worker, so we can put large data into shared memory | |
and replace the original payload data to none, then reduce the serialization/deserialization cost. | |
""" | |
if payload.method == "reset" and payload.data is not None: | |
obs_buffers[payload.proc_id].fill(payload.data) | |
payload.data = None | |
elif payload.method == "step" and payload.data is not None: | |
obs_buffers[payload.proc_id].fill(payload.data.obs) | |
payload.data._replace(obs=None) | |
def _create_env_ref(self): | |
# env_ref is used to acquire some common attributes of env, like obs_shape and act_shape | |
self._env_ref = self._env_fn[0]() | |
self._env_ref.reset() | |
self._observation_space = self._env_ref.observation_space | |
self._action_space = self._env_ref.action_space | |
self._reward_space = self._env_ref.reward_space | |
self._env_ref.close() | |
def step(self, actions: Union[Dict[int, List[Any]], List[Any]], block: bool = True) -> Optional[List[tnp.ndarray]]: | |
""" | |
Overview: | |
Execute env step according to input actions. And reset an env if done. | |
Arguments: | |
- actions (:obj:`List[tnp.ndarray]`): Actions came from outer caller like policy, \ | |
in structure of {env_id: actions}. | |
- block (:obj:`bool`): If block, return timesteps, else return none. | |
Returns: | |
- timesteps (:obj:`List[tnp.ndarray]`): Each timestep is a tnp.array with observation, reward, done, \ | |
info, env_id. | |
""" | |
assert not self.closed, "Env supervisor has closed." | |
if isinstance(actions, List): | |
actions = {i: p for i, p in enumerate(actions)} | |
assert actions, "Action is empty!" | |
send_payloads = [] | |
for env_id, act in actions.items(): | |
payload = SendPayload(proc_id=env_id, method="step", args=[act]) | |
send_payloads.append(payload) | |
self.send(payload) | |
if not block: | |
# Retrieve the data for these steps from the recv method | |
return | |
# Wait for all steps returns | |
recv_payloads = self.recv_all( | |
send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._step_timeout | |
) | |
return [payload.data for payload in recv_payloads] | |
def recv(self, ignore_err: bool = False) -> RecvPayload: | |
""" | |
Overview: | |
Wait for recv payload, this function will block the thread. | |
Arguments: | |
- ignore_err (:obj:`bool`): If ignore_err is true, payload with error object will be discarded.\ | |
This option will not catch the exception. | |
Returns: | |
- recv_payload (:obj:`RecvPayload`): Recv payload. | |
""" | |
self._detect_timeout() | |
try: | |
payload = super().recv(ignore_err=True, timeout=0.1) | |
payload = self._recv_callback(payload=payload) | |
if payload.err: | |
return self.recv(ignore_err=ignore_err) | |
else: | |
return payload | |
except queue.Empty: | |
return self.recv(ignore_err=ignore_err) | |
def _detect_timeout(self): | |
""" | |
Overview: | |
Try to restart all timeout environments if detected timeout. | |
""" | |
for env_id in self._last_called: | |
if self._step_timeout and time() - self._last_called[env_id]["step"] > self._step_timeout: | |
payload = RecvPayload( | |
proc_id=env_id, method="step", err=TimeoutError("Step timeout on env {}".format(env_id)) | |
) | |
self._recv_queue.put(payload) | |
continue | |
if self._reset_timeout and time() - self._last_called[env_id]["reset"] > self._reset_timeout: | |
payload = RecvPayload( | |
proc_id=env_id, method="reset", err=TimeoutError("Step timeout on env {}".format(env_id)) | |
) | |
self._recv_queue.put(payload) | |
continue | |
def env_num(self) -> int: | |
return len(self._children) | |
def observation_space(self) -> 'Space': | |
return self._observation_space | |
def action_space(self) -> 'Space': | |
return self._action_space | |
def reward_space(self) -> 'Space': | |
return self._reward_space | |
def ready_obs(self) -> tnp.array: | |
""" | |
Overview: | |
Get the ready (next) observation in ``tnp.array`` type, which is uniform for both async/sync scenarios. | |
Return: | |
- ready_obs (:obj:`tnp.array`): A stacked treenumpy-type observation data. | |
Example: | |
>>> obs = env_manager.ready_obs | |
>>> action = model(obs) # model input np obs and output np action | |
>>> timesteps = env_manager.step(action) | |
""" | |
active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN] | |
active_env.sort() | |
obs = [self._ready_obs.get(i) for i in active_env] | |
if len(obs) == 0: | |
return tnp.array([]) | |
return tnp.stack(obs) | |
def ready_obs_id(self) -> List[int]: | |
return [i for i, s in self.env_states.items() if s == EnvState.RUN] | |
def done(self) -> bool: | |
return all([s == EnvState.DONE for s in self.env_states.values()]) | |
def method_name_list(self) -> List[str]: | |
return ['reset', 'step', 'seed', 'close', 'enable_save_replay'] | |
def env_states(self) -> Dict[int, EnvState]: | |
return {env_id: self._env_states.get(env_id) or EnvState.VOID for env_id in range(self.env_num)} | |
def env_state_done(self, env_id: int) -> bool: | |
return self.env_states[env_id] == EnvState.DONE | |
def launch(self, reset_param: Optional[Dict] = None, block: bool = True) -> None: | |
""" | |
Overview: | |
Set up the environments and their parameters. | |
Arguments: | |
- reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the env_id, \ | |
value is the cooresponding reset parameters. | |
- block (:obj:`block`): Whether will block the process and wait for reset states. | |
""" | |
assert self.closed, "Please first close the env supervisor before launch it" | |
if reset_param is not None: | |
assert len(reset_param) == self.env_num | |
self.start_link() | |
self._send_seed(self._env_seed, self._env_dynamic_seed, block=block) | |
self.reset(reset_param, block=block) | |
self._enable_env_replay() | |
def reset(self, reset_param: Optional[Dict[int, List[Any]]] = None, block: bool = True) -> None: | |
""" | |
Overview: | |
Reset an environment. | |
Arguments: | |
- reset_param (:obj:`Optional[Dict[int, List[Any]]]`): Dict of reset parameters for each environment, \ | |
key is the env_id, value is the cooresponding reset parameters. | |
- block (:obj:`block`): Whether will block the process and wait for reset states. | |
""" | |
if not reset_param: | |
reset_param = {i: {} for i in range(self.env_num)} | |
elif isinstance(reset_param, List): | |
reset_param = {i: p for i, p in enumerate(reset_param)} | |
send_payloads = [] | |
for env_id, kw_param in reset_param.items(): | |
self._reset_param[env_id] = kw_param # For auto reset | |
send_payloads += self._reset(env_id, kw_param=kw_param) | |
if not block: | |
return | |
self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout) | |
def _recv_callback( | |
self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None | |
) -> RecvPayload: | |
""" | |
Overview: | |
The callback function for each received payload, within this method will modify the state of \ | |
each environment, replace objects in shared memory, and determine if a retry is needed due to an error. | |
Arguments: | |
- payload (:obj:`RecvPayload`): The received payload. | |
- remain_payloads (:obj:`Optional[Dict[str, SendPayload]]`): The callback may be called many times \ | |
until remain_payloads be cleared, you can append new payload into remain_payloads to call this \ | |
callback recursively. | |
""" | |
self._set_shared_obs(payload=payload) | |
self.change_state(payload=payload) | |
if payload.method == "reset": | |
return self._recv_reset_callback(payload=payload, remain_payloads=remain_payloads) | |
elif payload.method == "step": | |
return self._recv_step_callback(payload=payload, remain_payloads=remain_payloads) | |
return payload | |
def _set_shared_obs(self, payload: RecvPayload): | |
if self._obs_buffers is None: | |
return | |
if payload.method == "reset" and payload.err is None: | |
payload.data = self._obs_buffers[payload.proc_id].get() | |
elif payload.method == "step" and payload.err is None: | |
payload.data._replace(obs=self._obs_buffers[payload.proc_id].get()) | |
def _recv_reset_callback( | |
self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None | |
) -> RecvPayload: | |
assert payload.method == "reset", "Recv error callback({}) in reset callback!".format(payload.method) | |
if remain_payloads is None: | |
remain_payloads = {} | |
env_id = payload.proc_id | |
if payload.err: | |
self._retry_times[env_id] += 1 | |
if self._retry_times[env_id] > self._max_try - 1: | |
self.shutdown(5) | |
raise RuntimeError( | |
"Env {} reset has exceeded max_try({}), and the latest exception is: {}".format( | |
env_id, self._max_try, payload.err | |
) | |
) | |
if self._retry_waiting_time: | |
sleep(self._retry_waiting_time) | |
if self._retry_type == EnvRetryType.RENEW: | |
self._children[env_id].restart() | |
send_payloads = self._reset(env_id) | |
for p in send_payloads: | |
remain_payloads[p.req_id] = p | |
else: | |
self._retry_times[env_id] = 0 | |
self._ready_obs[env_id] = payload.data | |
return payload | |
def _recv_step_callback( | |
self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None | |
) -> RecvPayload: | |
assert payload.method == "step", "Recv error callback({}) in step callback!".format(payload.method) | |
if remain_payloads is None: | |
remain_payloads = {} | |
if payload.err: | |
send_payloads = self._reset(payload.proc_id) | |
for p in send_payloads: | |
remain_payloads[p.req_id] = p | |
info = {"abnormal": True, "err": payload.err} | |
payload.data = tnp.array( | |
{ | |
'obs': None, | |
'reward': None, | |
'done': None, | |
'info': info, | |
'env_id': payload.proc_id | |
} | |
) | |
else: | |
obs, reward, done, info, *_ = payload.data | |
if done: | |
self._env_episode_count[payload.proc_id] += 1 | |
if self._env_episode_count[payload.proc_id] < self._episode_num and self._auto_reset: | |
send_payloads = self._reset(payload.proc_id) | |
for p in send_payloads: | |
remain_payloads[p.req_id] = p | |
# make the type and content of key as similar as identifier, | |
# in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info | |
info = make_key_as_identifier(info) | |
payload.data = tnp.array( | |
{ | |
'obs': obs, | |
'reward': reward, | |
'done': done, | |
'info': info, | |
'env_id': payload.proc_id | |
} | |
) | |
self._ready_obs[payload.proc_id] = obs | |
return payload | |
def _reset(self, env_id: int, kw_param: Optional[Dict[str, Any]] = None) -> List[SendPayload]: | |
""" | |
Overview: | |
Reset an environment. This method does not wait for the result to be returned. | |
Arguments: | |
- env_id (:obj:`int`): Environment id. | |
- kw_param (:obj:`Optional[Dict[str, Any]]`): Reset parameters for the environment. | |
Returns: | |
- send_payloads (:obj:`List[SendPayload]`): The request payloads for seed and reset actions. | |
""" | |
assert not self.closed, "Env supervisor has closed." | |
send_payloads = [] | |
kw_param = kw_param or self._reset_param[env_id] | |
if self._env_replay_path is not None and self.env_states[env_id] == EnvState.RUN: | |
logging.warning("Please don't reset an unfinished env when you enable save replay, we just skip it") | |
return send_payloads | |
# Reset env | |
payload = SendPayload(proc_id=env_id, method="reset", kwargs=kw_param) | |
send_payloads.append(payload) | |
self.send(payload) | |
return send_payloads | |
def _send_seed(self, env_seed: Dict[int, int], env_dynamic_seed: Optional[bool] = None, block: bool = True) -> None: | |
send_payloads = [] | |
for env_id, seed in env_seed.items(): | |
if seed is None: | |
continue | |
args = [seed] | |
if env_dynamic_seed is not None: | |
args.append(env_dynamic_seed) | |
payload = SendPayload(proc_id=env_id, method="seed", args=args) | |
send_payloads.append(payload) | |
self.send(payload) | |
if not block or not send_payloads: | |
return | |
self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout) | |
def change_state(self, payload: RecvPayload): | |
self._last_called[payload.proc_id][payload.method] = math.inf # Have recevied | |
if payload.err: | |
self._env_states[payload.proc_id] = EnvState.ERROR | |
elif payload.method == "reset": | |
self._env_states[payload.proc_id] = EnvState.RUN | |
elif payload.method == "step": | |
if payload.data[2]: | |
self._env_states[payload.proc_id] = EnvState.DONE | |
def send(self, payload: SendPayload) -> None: | |
self._last_called[payload.proc_id][payload.method] = time() | |
return super().send(payload) | |
def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: Optional[bool] = None) -> None: | |
""" | |
Overview: | |
Set the seed for each environment. The seed function will not be called until supervisor.launch \ | |
was called. | |
Arguments: | |
- seed (:obj:`Union[Dict[int, int], List[int], int]`): List of seeds for each environment; \ | |
Or one seed for the first environment and other seeds are generated automatically. \ | |
Note that in threading mode, no matter how many seeds are given, only the last one will take effect. \ | |
Because the execution in the thread is asynchronous, the results of each experiment \ | |
are different even if a fixed seed is used. | |
- dynamic_seed (:obj:`Optional[bool]`): Dynamic seed is used in the training environment, \ | |
trying to make the random seed of each episode different, they are all generated in the reset \ | |
method by a random generator 100 * np.random.randint(1 , 1000) (but the seed of this random \ | |
number generator is fixed by the environmental seed method, guranteeing the reproducibility \ | |
of the experiment). You need not pass the dynamic_seed parameter in the seed method, or pass \ | |
the parameter as True. | |
""" | |
self._env_seed = {} | |
if isinstance(seed, numbers.Integral): | |
self._env_seed = {i: seed + i for i in range(self.env_num)} | |
elif isinstance(seed, list): | |
assert len(seed) == self.env_num, "len(seed) {:d} != env_num {:d}".format(len(seed), self.env_num) | |
self._env_seed = {i: _seed for i, _seed in enumerate(seed)} | |
elif isinstance(seed, dict): | |
self._env_seed = {env_id: s for env_id, s in seed.items()} | |
else: | |
raise TypeError("Invalid seed arguments type: {}".format(type(seed))) | |
self._env_dynamic_seed = dynamic_seed | |
def enable_save_replay(self, replay_path: Union[List[str], str]) -> None: | |
""" | |
Overview: | |
Set each env's replay save path. | |
Arguments: | |
- replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \ | |
Or one path for all environments. | |
""" | |
if isinstance(replay_path, str): | |
replay_path = [replay_path] * self.env_num | |
self._env_replay_path = replay_path | |
def _enable_env_replay(self): | |
if self._env_replay_path is None: | |
return | |
send_payloads = [] | |
for env_id, s in enumerate(self._env_replay_path): | |
payload = SendPayload(proc_id=env_id, method="enable_save_replay", args=[s]) | |
send_payloads.append(payload) | |
self.send(payload) | |
self.recv_all(send_payloads=send_payloads) | |
def __getattr__(self, key: str) -> List[Any]: | |
if not hasattr(self._env_ref, key): | |
raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(self._env_ref), key)) | |
return super().__getattr__(key) | |
def close(self, timeout: Optional[float] = None) -> None: | |
""" | |
In order to be compatible with BaseEnvManager, the new version can use `shutdown` directly. | |
""" | |
self.shutdown(timeout=timeout) | |
def shutdown(self, timeout: Optional[float] = None) -> None: | |
if self._running: | |
send_payloads = [] | |
for env_id in range(self.env_num): | |
payload = SendPayload(proc_id=env_id, method="close") | |
send_payloads.append(payload) | |
self.send(payload) | |
self.recv_all(send_payloads=send_payloads, ignore_err=True, timeout=timeout) | |
super().shutdown(timeout=timeout) | |
self._init_states() | |
def closed(self) -> bool: | |
return not self._running | |