Spaces:
Sleeping
Sleeping
import random | |
import time | |
from collections import namedtuple | |
import pytest | |
import torch | |
import numpy as np | |
from easydict import EasyDict | |
from functools import partial | |
import gym | |
from ding.envs.env.base_env import BaseEnvTimestep | |
from ding.envs.env_manager.base_env_manager import EnvState | |
from ding.envs.env_manager import BaseEnvManager, SyncSubprocessEnvManager, AsyncSubprocessEnvManager | |
from ding.torch_utils import to_tensor, to_ndarray, to_list | |
from ding.utils import deep_merge_dicts | |
class FakeEnv(object):
    """Fake environment used to exercise env managers.

    Episodes end once the accumulated simulated time reaches a randomly chosen
    target. Passing one of the magic strings ``'error'``, ``'error_once'``,
    ``'wait'``, ``'block'`` (to ``reset``) or ``'error'``, ``'catched_error'``,
    ``'wait'``, ``'block'`` (to ``step``) triggers a failure/latency scenario.
    """

    def __init__(self, cfg):
        """Create the env.

        Args:
            cfg: Mapping providing ``'scale'`` (multiplier applied to every
                simulated duration) and ``'name'``.
        """
        # Both keys are read by item access so that a plain ``dict`` works as
        # well as an attribute-style mapping.
        self._scale = cfg['scale']
        self._target_time = random.randint(3, 6) * self._scale
        self._current_time = 0
        self._name = cfg['name']
        self._id = time.time()  # creation timestamp, exposed through time_id()
        self._stat = None
        self._seed = 0
        self._data_count = 0
        self.timeout_flag = False  # becomes True after the first step()
        self._launched = False
        self._state = EnvState.INIT
        self._dead_once = False  # toggled by reset('error_once')

        obs_low = np.array([-1.0, -1.0, -8.0])
        obs_high = np.array([1.0, 1.0, 8.0])
        self.observation_space = gym.spaces.Box(low=obs_low, high=obs_high, shape=(3, ), dtype=np.float32)
        self.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1, ), dtype=np.float32)
        reward_low = -1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2)
        self.reward_space = gym.spaces.Box(low=reward_low, high=0.0, shape=(1, ), dtype=np.float32)

    def reset(self, stat=None):
        """Start a new episode and return a random 3-element observation.

        Args:
            stat: Optional scenario selector. ``'error'`` raises,
                ``'error_once'`` raises on every second such reset,
                ``'wait'`` sleeps 5 seconds if ``step`` was called before,
                ``'block'`` sleeps for a very long time. Any other value is
                only recorded.

        Raises:
            RuntimeError: For the ``'error'`` / ``'error_once'`` scenarios.
        """
        # Non-string values never match a scenario name.
        mode = stat if isinstance(stat, str) else None
        if mode == 'error':
            self.dead()
        elif mode == 'error_once':
            # Alternate: the first such reset arms the flag, the next one dies.
            if self._dead_once:
                self._dead_once = False
                self.dead()
            else:
                self._dead_once = True
        elif mode == "wait":
            # Only stall once step() has been called, i.e. not on the first reset.
            if self.timeout_flag:
                time.sleep(5)
        elif mode == "block":
            self.block()

        self._launched = True
        self._current_time = 0
        self._stat = stat
        self._state = EnvState.RUN
        return to_ndarray(torch.randn(3))

    def step(self, action):
        """Advance the env by one step.

        Args:
            action: Regular action (ignored), or one of the scenario strings
                ``'error'``, ``'catched_error'``, ``'wait'``, ``'block'``.

        Returns:
            A ``BaseEnvTimestep`` of (obs, reward, done, info). For
            ``'catched_error'`` obs and reward are ``None``, done is ``True``
            and info is ``{'abnormal': True}``.

        Raises:
            AssertionError: If the env was not reset or is in the error state.
            RuntimeError: For the ``'error'`` scenario.
        """
        assert self._launched
        assert self._state != EnvState.ERROR
        self.timeout_flag = True  # after one step, enable timeout flag

        cmd = action if isinstance(action, str) else None
        if cmd == 'error':
            self.dead()
        elif cmd == 'catched_error':
            return BaseEnvTimestep(None, None, True, {'abnormal': True})
        elif cmd == "wait":
            if self.timeout_flag:
                time.sleep(3)
        elif cmd == 'block':
            self.block()

        obs = to_ndarray(torch.randn(3))
        reward = to_ndarray(torch.randint(0, 2, size=[1]).numpy())
        # ``done`` and ``info['cur']`` reflect the time *before* this step's
        # simulated duration is added.
        done = self._current_time >= self._target_time
        if done:
            self._state = EnvState.DONE
        simulation_time = random.uniform(0.5, 1) * self._scale
        info = {'name': self._name, 'time': simulation_time, 'tgt': self._target_time, 'cur': self._current_time}
        time.sleep(simulation_time)
        self._current_time += simulation_time
        self._data_count += 1
        return BaseEnvTimestep(obs, reward, done, info)

    def dead(self):
        """Mark the env as errored and raise ``RuntimeError``."""
        self._state = EnvState.ERROR
        raise RuntimeError("env error, current time {}".format(self._current_time))

    def block(self):
        """Mark the env as errored and sleep for 1000 seconds."""
        self._state = EnvState.ERROR
        time.sleep(1000)

    def close(self):
        """Return the env to its un-launched, initial state."""
        self._launched = False
        self._state = EnvState.INIT

    def seed(self, seed):
        """Record ``seed``; it does not influence the random draws here."""
        self._seed = seed

    def name(self):
        """Return the configured env name."""
        return self._name

    def time_id(self):
        """Return the timestamp taken when the env was constructed."""
        return self._id

    def user_defined(self):
        """No-op placeholder method."""
        pass

    def __repr__(self):
        return self._name
class FakeAsyncEnv(FakeEnv):
    """``FakeEnv`` whose ``reset`` takes an extra random amount of time."""

    def reset(self, stat=None):
        """Reset like the parent, sleep 1-3 scaled seconds, return a new observation."""
        # The observation produced by the parent reset is discarded; a fresh
        # one is drawn after the artificial delay.
        super().reset(stat)
        delay = random.randint(1, 3) * self._scale
        time.sleep(delay)
        fresh_obs = torch.randn(3)
        return to_ndarray(fresh_obs)
class FakeGymEnv(FakeEnv):
    """``FakeEnv`` variant with a ``metadata`` attribute and a 4-dim action space."""

    def __init__(self, cfg):
        """Initialise the parent env, then override ``metadata`` and ``action_space``."""
        super().__init__(cfg)
        self.metadata = "fake metadata"
        self.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(4, ), dtype=np.float32)

    def random_action(self) -> np.ndarray:
        """Sample an action from ``action_space`` and normalise it to ndarray form.

        Raises:
            TypeError: If the sample is not an ``np.ndarray``, ``int`` or ``dict``.
        """
        sample = self.action_space.sample()
        if isinstance(sample, np.ndarray):
            return sample
        if isinstance(sample, int):
            # Wrap the scalar in a one-element int64 array.
            return to_ndarray([sample], dtype=np.int64)
        if isinstance(sample, dict):
            return to_ndarray(sample)
        raise TypeError(
            '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format(
                type(sample), sample
            )
        )
class FakeModel(object):
    """Fake model whose ``forward`` answers for all, or a random prefix, of the inputs."""

    def forward(self, obs):
        """Return ``{key: []}`` for every key of ``obs`` or for a leading subset.

        Args:
            obs: Mapping keyed by env id.

        Returns:
            A dict mapping each selected key to a new empty list.
        """
        if random.random() > 0.5:
            selected = obs
        else:
            env_num = len(obs)
            # randint's upper bound is inclusive, so this may exceed env_num;
            # the slice below simply clamps to the available keys.
            exec_env = random.randint(1, env_num + 1)
            selected = list(obs.keys())[:exec_env]
        return {key: [] for key in selected}
def setup_model_type():
    """Return the ``FakeModel`` class (the type itself, not an instance)."""
    model_cls = FakeModel
    return model_cls
def get_base_manager_cfg(env_num=3):
    """Build the raw manager settings used for the base env manager.

    Args:
        env_num: Number of per-env config entries to generate.

    Returns:
        An ``EasyDict`` with ``env_cfg`` (one ``name``/``scale`` entry per env),
        ``episode_num``, ``reset_timeout``, ``step_timeout`` and ``max_retry``.
    """
    env_cfg = [{'name': 'name{}'.format(idx), 'scale': 1.0} for idx in range(env_num)]
    return EasyDict(
        {
            'env_cfg': env_cfg,
            'episode_num': 2,
            'reset_timeout': 10,
            'step_timeout': 8,
            'max_retry': 5,
        }
    )
def get_subprecess_manager_cfg(env_num=3):
    """Build the raw manager settings used for the subprocess env managers.

    Args:
        env_num: Number of per-env config entries to generate.

    Returns:
        An ``EasyDict`` with ``env_cfg`` (one ``name``/``scale`` entry per env),
        ``episode_num``, ``connect_timeout``, ``step_timeout`` and ``max_retry``.
    """
    env_cfg = [{'name': 'name{}'.format(idx), 'scale': 1.0} for idx in range(env_num)]
    # No 'reset_timeout' is set here; 'step_timeout' is 5 (an earlier variant
    # used step_timeout=8 / reset_timeout=10).
    return EasyDict(
        {
            'env_cfg': env_cfg,
            'episode_num': 2,
            'connect_timeout': 8,
            'step_timeout': 5,
            'max_retry': 2,
        }
    )
def get_gym_vector_manager_cfg(env_num=3):
    """Build the raw manager settings for a gym-vector style env manager.

    Args:
        env_num: Number of per-env config entries to generate.

    Returns:
        An ``EasyDict`` with ``env_cfg`` (entries carry only ``name``),
        ``episode_num``, ``connect_timeout``, ``step_timeout``, ``max_retry``
        and ``share_memory``.
    """
    # Unlike the other builders, entries carry no 'scale' key.
    env_cfg = [{'name': 'name{}'.format(idx)} for idx in range(env_num)]
    # NOTE(review): the key is 'share_memory' here while the setup_* helpers
    # set 'shared_memory' -- confirm which spelling the consumer expects.
    return EasyDict(
        {
            'env_cfg': env_cfg,
            'episode_num': 2,
            'connect_timeout': 8,
            'step_timeout': 5,
            'max_retry': 2,
            'share_memory': True
        }
    )
def setup_base_manager_cfg():
    """Return the ``BaseEnvManager`` default config deep-merged with 3 ``FakeEnv`` factories.

    The per-env configs are removed from the settings and replaced by an
    ``env_fn`` list of ``partial(FakeEnv, cfg=...)`` callables.
    """
    overrides = get_base_manager_cfg(3)
    per_env_cfg = overrides.pop('env_cfg')
    overrides['env_fn'] = [partial(FakeEnv, cfg=one_cfg) for one_cfg in per_env_cfg]
    defaults = BaseEnvManager.default_config()
    return deep_merge_dicts(defaults, EasyDict(overrides))
def setup_fast_base_manager_cfg():
    """Same as the base manager config, but every env runs with ``scale`` 0.1.

    Returns:
        ``BaseEnvManager.default_config()`` deep-merged with the overrides,
        whose ``env_fn`` holds ``partial(FakeEnv, cfg=...)`` callables.
    """
    overrides = get_base_manager_cfg(3)
    per_env_cfg = overrides.pop('env_cfg')
    # Shrink every simulated duration before the configs are bound to the factories.
    for one_cfg in per_env_cfg:
        one_cfg['scale'] = 0.1
    overrides['env_fn'] = [partial(FakeEnv, cfg=one_cfg) for one_cfg in per_env_cfg]
    defaults = BaseEnvManager.default_config()
    return deep_merge_dicts(defaults, EasyDict(overrides))
def setup_sync_manager_cfg():
    """Return the ``SyncSubprocessEnvManager`` default config merged with 3 ``FakeEnv`` factories.

    ``shared_memory`` is forced to ``False`` in the overrides.
    """
    overrides = get_subprecess_manager_cfg(3)
    per_env_cfg = overrides.pop('env_cfg')
    # TODO(nyz) test fail when shared_memory = True
    overrides['shared_memory'] = False
    overrides['env_fn'] = [partial(FakeEnv, cfg=one_cfg) for one_cfg in per_env_cfg]
    defaults = SyncSubprocessEnvManager.default_config()
    return deep_merge_dicts(defaults, EasyDict(overrides))
def setup_async_manager_cfg():
    """Return the ``AsyncSubprocessEnvManager`` default config merged with 3 ``FakeAsyncEnv`` factories.

    ``shared_memory`` is forced to ``False`` in the overrides.
    """
    overrides = get_subprecess_manager_cfg(3)
    per_env_cfg = overrides.pop('env_cfg')
    overrides['env_fn'] = [partial(FakeAsyncEnv, cfg=one_cfg) for one_cfg in per_env_cfg]
    overrides['shared_memory'] = False
    defaults = AsyncSubprocessEnvManager.default_config()
    return deep_merge_dicts(defaults, EasyDict(overrides))
def setup_gym_vector_manager_cfg():
    """Return subprocess-manager settings with 3 ``FakeGymEnv`` factories.

    No manager default config is merged in; the result is a new ``EasyDict``
    built from the overrides, with ``shared_memory`` set to ``False``.
    """
    # Built from the subprocess settings, whose env entries include 'scale'.
    overrides = get_subprecess_manager_cfg(3)
    per_env_cfg = overrides.pop('env_cfg')
    overrides['env_fn'] = [partial(FakeGymEnv, cfg=one_cfg) for one_cfg in per_env_cfg]
    overrides['shared_memory'] = False
    return EasyDict(overrides)