Spaces:
Sleeping
Sleeping
from tabnanny import check | |
from typing import Any, Callable, List, Tuple | |
import numpy as np | |
from collections.abc import Sequence | |
from easydict import EasyDict | |
from ding.envs.env import BaseEnv, BaseEnvTimestep | |
from ding.envs.env.tests import DemoEnv | |
# from dizoo.atari.envs import AtariEnv | |
def check_space_dtype(env: BaseEnv) -> None:
    """Assert that the env's obs/act/rew spaces use the required dtypes.

    The env is reset once, then ``observation_space``, ``action_space`` and
    ``reward_space`` are inspected in that order. A dtype whose ``repr``
    contains ``float`` must equal ``np.float32``; one whose ``repr`` contains
    ``int`` must equal ``np.int64``. Other dtypes are not checked.

    Args:
        env: Environment exposing ``reset`` and the three space attributes.

    Raises:
        AssertionError: If a float/int space uses a dtype other than the
            required one.
    """
    print("== 0. Test obs/act/rew space's dtype")
    # NOTE(review): presumably the spaces are only available after the lazy
    # init performed by `reset` -- confirm against the env implementation.
    env.reset()
    named_spaces = (
        ('obs', env.observation_space),
        ('act', env.action_space),
        ('rew', env.reward_space),
    )
    for label, space in named_spaces:
        dtype_repr = repr(space.dtype)
        if 'float' in dtype_repr:
            assert space.dtype == np.float32, (
                "If float, then must be np.float32, but get {} for {} space".format(space.dtype, label)
            )
        if 'int' in dtype_repr:
            assert space.dtype == np.int64, (
                "If int, then must be np.int64, but get {} for {} space".format(space.dtype, label)
            )
# Util function | |
def check_array_space(ndarray, space, name) -> None:
    """Recursively assert that ``ndarray`` conforms to ``space``.

    A leaf ``np.ndarray`` must match the space's ``dtype`` and ``shape`` and
    lie element-wise within ``[space.low, space.high]``. Sequences and dicts
    are walked element by element, indexing ``space`` with the same
    index/key.

    Args:
        ndarray: An ``np.ndarray``, or a (possibly nested) sequence/dict of
            ``np.ndarray``.
        space: Space description matching the structure of ``ndarray``.
        name: Label used in the assertion messages (e.g. ``'obs'``).

    Raises:
        AssertionError: On a dtype, shape or value-range mismatch. The
            offending index/key is printed before the error propagates.
        TypeError: If ``ndarray`` (or a nested element) is not an
            ``np.ndarray``, sequence or dict. ``str``/``bytes`` are rejected
            here as well.
    """
    if isinstance(ndarray, np.ndarray):
        assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(
            name, ndarray.dtype, space.dtype
        )
        assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format(
            name, ndarray.shape, space.shape
        )
        assert (space.low <= ndarray).all() and (ndarray <= space.high).all(
        ), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high)
    # `str` and `bytes` are Sequences too, but indexing a `str` yields another
    # `str`, so walking them would never reach the TypeError below.
    elif isinstance(ndarray, Sequence) and not isinstance(ndarray, (str, bytes)):
        for i, item in enumerate(ndarray):
            try:
                check_array_space(item, space[i], name)
            except AssertionError:
                print("The following error happens at {}-th index".format(i))
                raise
    elif isinstance(ndarray, dict):
        for k, item in ndarray.items():
            try:
                check_array_space(item, space[k], name)
            except AssertionError:
                print("The following error happens at key {}".format(k))
                raise
    else:
        raise TypeError(
            "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(ndarray))
        )
def check_reset(env: BaseEnv) -> None:
    """Reset the env and validate the returned observation.

    The observation returned by ``env.reset()`` is checked against
    ``env.observation_space`` with ``check_array_space``.

    Args:
        env: Environment under test.
    """
    print('== 1. Test reset method')
    # `reset` runs first; the space is read only afterwards.
    check_array_space(env.reset(), env.observation_space, 'obs')
def check_step(env: BaseEnv) -> None:
    """Step the env through three episodes and validate every transition.

    Each step uses a freshly sampled action: ``env.random_action()`` when the
    env provides it, otherwise ``env.action_space.sample()``. The returned
    observation and reward are checked against ``env.observation_space`` and
    ``env.reward_space``. When an episode ends, ``info`` must contain the
    ``'eval_episode_return'`` key and the env is reset.

    Args:
        env: Environment under test.

    Raises:
        AssertionError: If an obs/rew violates its space, or a terminal
            ``info`` lacks ``'eval_episode_return'``.
    """
    print('== 2. Test step method')
    env.reset()
    done_times = 0
    while done_times < 3:
        # Sample per step (not once up front) so the check does not replay a
        # single action for every transition.
        if hasattr(env, "random_action"):
            action = env.random_action()
        else:
            action = env.action_space.sample()
        obs, rew, done, info = env.step(action)
        for ndarray, space, name in zip([obs, rew], [env.observation_space, env.reward_space], ['obs', 'rew']):
            check_array_space(ndarray, space, name)
        if done:
            assert 'eval_episode_return' in info, "info dict should have 'eval_episode_return' key."
            done_times += 1
            env.reset()
# Util function | |
def check_different_memory(array1, array2, step_times) -> None:
    """Recursively assert that two observations are distinct array objects.

    Both inputs must have the same type and structure. Every pair of leaf
    ``np.ndarray`` objects at matching positions must be different Python
    objects (an identity check; overlapping views are not detected).

    Args:
        array1: Observation from the previous frame: an ``np.ndarray`` or a
            (possibly nested) sequence/dict of ``np.ndarray``.
        array2: Observation from the current frame, with the same structure.
        step_times: Step counter, used only in the error messages.

    Raises:
        AssertionError: If the types, lengths or dict keys differ, or a leaf
            array is the very same object in both inputs. The offending
            index/key is printed before the error propagates.
        TypeError: If the inputs are not ``np.ndarray``, sequence or dict.
            ``str``/``bytes`` are rejected here as well.
    """
    assert type(array1) is type(
        array2
    ), "In step times {}, obs_last_frame({}) and obs_this_frame({}) are not of the same type".format(
        step_times, type(array1), type(array2)
    )
    if isinstance(array1, np.ndarray):
        assert array1 is not array2, (
            "In step times {}, obs_last_frame and obs_this_frame are the same np.ndarray".format(step_times)
        )
    # `str` and `bytes` are Sequences too, but indexing a `str` yields another
    # `str`, so walking them would recurse without end.
    elif isinstance(array1, Sequence) and not isinstance(array1, (str, bytes)):
        assert len(array1) == len(
            array2
        ), "In step times {}, obs_last_frame({}) and obs_this_frame({}) have different sequence lengths".format(
            step_times, len(array1), len(array2)
        )
        for i, (last_item, this_item) in enumerate(zip(array1, array2)):
            try:
                check_different_memory(last_item, this_item, step_times)
            except AssertionError:
                print("The following error happens at {}-th index".format(i))
                raise
    elif isinstance(array1, dict):
        # Adjacent literals keep source indentation out of the message.
        assert array1.keys() == array2.keys(), (
            "In step times {}, obs_last_frame({}) and obs_this_frame({}) have "
            "different dict keys".format(step_times, array1.keys(), array2.keys())
        )
        for k in array1:
            try:
                check_different_memory(array1[k], array2[k], step_times)
            except AssertionError:
                print("The following error happens at key {}".format(k))
                raise
    else:
        raise TypeError(
            "Input array should be np.ndarray or list/dict of np.ndarray, but found {} and {}".format(
                type(array1), type(array2)
            )
        )
def check_obs_deepcopy(env: BaseEnv) -> None:
    """Run one episode and check that consecutive observations are distinct objects.

    Starting from the observation returned by ``env.reset()``, every new
    observation returned by ``env.step`` is compared with the previous one
    via ``check_different_memory``. Each step uses a freshly sampled action:
    ``env.random_action()`` when the env provides it, otherwise
    ``env.action_space.sample()``. The loop stops at the first ``done``.

    Args:
        env: Environment under test.
    """
    print('== 3. Test observation deepcopy')
    last_obs = env.reset()
    step_times = 0
    done = False
    while not done:
        step_times += 1
        # Sample per step (not once up front) so the episode is not driven by
        # a single replayed action.
        if hasattr(env, "random_action"):
            action = env.random_action()
        else:
            action = env.action_space.sample()
        this_obs, _, done, _ = env.step(action)
        check_different_memory(last_obs, this_obs, step_times)
        last_obs = this_obs
def check_all(env: BaseEnv) -> None:
    """Run every env check on ``env``, in order.

    The order is: space dtypes, ``reset``, ``step``, observation deepcopy.
    The first failing check propagates its exception and stops the run.

    Args:
        env: Environment under test.
    """
    all_checks = (check_space_dtype, check_reset, check_step, check_obs_deepcopy)
    for run_check in all_checks:
        run_check(env)
def demonstrate_correct_procedure(env_fn: Callable) -> None:
    """Walk through the expected env life cycle for three episodes.

    The env is built with an empty config, seeded with ``4``, reset, and then
    stepped with ``env.random_action()`` until three episodes have finished.
    The expected contract is asserted along the way.

    Args:
        env_fn: Callable that builds the env from a config; called as
            ``env_fn({})``.

    Raises:
        AssertionError: If the env already has ``_env`` before ``reset``, if
            ``_seed`` is not ``4`` after seeding or after a later reset, or if
            a terminal ``info`` lacks ``'eval_episode_return'``.
    """
    print('== 4. Demonstrate the correct procudures')
    # Build the env; the real env must not exist yet (lazy init).
    env = env_fn({})
    assert not hasattr(env, "_env")
    # The seed has to be set before the first `reset`.
    env.seed(4)
    assert env._seed == 4
    # The first `reset` is where the real env gets initialized.
    env.reset()
    finished_episodes = 0
    while finished_episodes != 3:
        # A policy would map obs to an action here; a random action stands in.
        _, _, done, info = env.step(env.random_action())
        if not done:
            continue
        assert 'eval_episode_return' in info
        finished_episodes += 1
        env.reset()
        # Resetting must keep the seed unless `seed` is called again.
        assert env._seed == 4
if __name__ == "__main__":
    # The bare string below is a disabled example kept for reference; it is
    # never executed.
    '''
    # Moethods `check_*` are for user to check whether their implemented env obeys DI-engine's rules.
    # You can replace `AtariEnv` with your own env.
    atari_env = AtariEnv(EasyDict(env_id='PongNoFrameskip-v4', frame_stack=4, is_train=False))
    check_reset(atari_env)
    check_step(atari_env)
    check_obs_deepcopy(atari_env)
    '''
    # `demonstrate_correct_procedure` shows the correct procedure for using an
    # env to generate trajectories.
    # You can check whether your env's design is similar to `DemoEnv`.
    demonstrate_correct_procedure(DemoEnv)