"""A set of functions for passively checking environment implementations.""" import inspect from functools import partial from typing import Callable import numpy as np from gym import Space, error, logger, spaces def _check_box_observation_space(observation_space: spaces.Box): """Checks that a :class:`Box` observation space is defined in a sensible way. Args: observation_space: A box observation space """ # Check if the box is an image if len(observation_space.shape) == 3: if observation_space.dtype != np.uint8: logger.warn( f"It seems a Box observation space is an image but the `dtype` is not `np.uint8`, actual type: {observation_space.dtype}. " "If the Box observation space is not an image, we recommend flattening the observation to have only a 1D vector." ) if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): logger.warn( "It seems a Box observation space is an image but the upper and lower bounds are not in [0, 255]. " "Generally, CNN policies assume observations are within that range, so you may encounter an issue if the observation values are not." ) if len(observation_space.shape) not in [1, 3]: logger.warn( "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). " "We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. " f"Actual observation shape: {observation_space.shape}" ) assert ( observation_space.low.shape == observation_space.shape ), f"The Box observation space shape and low shape have different shapes, low shape: {observation_space.low.shape}, box shape: {observation_space.shape}" assert ( observation_space.high.shape == observation_space.shape ), f"The Box observation space shape and high shape have have different shapes, high shape: {observation_space.high.shape}, box shape: {observation_space.shape}" if np.any(observation_space.low == observation_space.high): logger.warn("A Box observation space maximum and minimum values are equal.") elif np.any(observation_space.high < observation_space.low): logger.warn("A Box observation space low value is greater than a high value.") def _check_box_action_space(action_space: spaces.Box): """Checks that a :class:`Box` action space is defined in a sensible way. Args: action_space: A box action space """ assert ( action_space.low.shape == action_space.shape ), f"The Box action space shape and low shape have have different shapes, low shape: {action_space.low.shape}, box shape: {action_space.shape}" assert ( action_space.high.shape == action_space.shape ), f"The Box action space shape and high shape have different shapes, high shape: {action_space.high.shape}, box shape: {action_space.shape}" if np.any(action_space.low == action_space.high): logger.warn("A Box action space maximum and minimum values are equal.") elif np.any(action_space.high < action_space.low): logger.warn("A Box action space low value is greater than a high value.") def check_space( space: Space, space_type: str, check_box_space_fn: Callable[[spaces.Box], None] ): """A passive check of the environment action space that should not affect the environment.""" if not isinstance(space, spaces.Space): raise AssertionError( f"{space_type} space does not inherit from `gym.spaces.Space`, actual type: {type(space)}" ) elif isinstance(space, spaces.Box): check_box_space_fn(space) elif isinstance(space, spaces.Discrete): assert ( 0 < space.n ), f"Discrete {space_type} space's number of elements must be positive, actual number of elements: {space.n}" assert ( space.shape == () ), f"Discrete {space_type} space's shape should be empty, actual shape: {space.shape}" elif isinstance(space, spaces.MultiDiscrete): assert ( space.shape == space.nvec.shape ), f"Multi-discrete {space_type} space's shape must be equal to the nvec shape, space shape: {space.shape}, nvec shape: {space.nvec.shape}" assert np.all( 0 < space.nvec ), f"Multi-discrete {space_type} space's all nvec elements must be greater than 0, actual nvec: {space.nvec}" elif isinstance(space, spaces.MultiBinary): assert np.all( 0 < np.asarray(space.shape) ), f"Multi-binary {space_type} space's all shape elements must be greater than 0, actual shape: {space.shape}" elif isinstance(space, spaces.Tuple): assert 0 < len( space.spaces ), f"An empty Tuple {space_type} space is not allowed." for subspace in space.spaces: check_space(subspace, space_type, check_box_space_fn) elif isinstance(space, spaces.Dict): assert 0 < len( space.spaces.keys() ), f"An empty Dict {space_type} space is not allowed." for subspace in space.values(): check_space(subspace, space_type, check_box_space_fn) check_observation_space = partial( check_space, space_type="observation", check_box_space_fn=_check_box_observation_space, ) check_action_space = partial( check_space, space_type="action", check_box_space_fn=_check_box_action_space ) def check_obs(obs, observation_space: spaces.Space, method_name: str): """Check that the observation returned by the environment correspond to the declared one. Args: obs: The observation to check observation_space: The observation space of the observation method_name: The method name that generated the observation """ pre = f"The obs returned by the `{method_name}()` method" if isinstance(observation_space, spaces.Discrete): if not isinstance(obs, (np.int64, int)): logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}") elif isinstance(observation_space, spaces.Box): if observation_space.shape != (): if not isinstance(obs, np.ndarray): logger.warn( f"{pre} was expecting a numpy array, actual type: {type(obs)}" ) elif obs.dtype != observation_space.dtype: logger.warn( f"{pre} was expecting numpy array dtype to be {observation_space.dtype}, actual type: {obs.dtype}" ) elif isinstance(observation_space, (spaces.MultiBinary, spaces.MultiDiscrete)): if not isinstance(obs, np.ndarray): logger.warn(f"{pre} was expecting a numpy array, actual type: {type(obs)}") elif isinstance(observation_space, spaces.Tuple): if not isinstance(obs, tuple): logger.warn(f"{pre} was expecting a tuple, actual type: {type(obs)}") assert len(obs) == len( observation_space.spaces ), f"{pre} length is not same as the observation space length, obs length: {len(obs)}, space length: {len(observation_space.spaces)}" for sub_obs, sub_space in zip(obs, observation_space.spaces): check_obs(sub_obs, sub_space, method_name) elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), f"{pre} must be a dict, actual type: {type(obs)}" assert ( obs.keys() == observation_space.spaces.keys() ), f"{pre} observation keys is not same as the observation space keys, obs keys: {list(obs.keys())}, space keys: {list(observation_space.spaces.keys())}" for space_key in observation_space.spaces.keys(): check_obs(obs[space_key], observation_space[space_key], method_name) try: if obs not in observation_space: logger.warn(f"{pre} is not within the observation space.") except Exception as e: logger.warn(f"{pre} is not within the observation space with exception: {e}") def env_reset_passive_checker(env, **kwargs): """A passive check of the `Env.reset` function investigating the returning reset information and returning the data unchanged.""" signature = inspect.signature(env.reset) if "seed" not in signature.parameters and "kwargs" not in signature.parameters: logger.warn( "Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator." ) else: seed_param = signature.parameters.get("seed") # Check the default value is None if seed_param is not None and seed_param.default is not None: logger.warn( "The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. " f"Actual default: {seed_param}" ) if "options" not in signature.parameters and "kwargs" not in signature.parameters: logger.warn( "Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information." ) # Checks the result of env.reset with kwargs result = env.reset(**kwargs) if not isinstance(result, tuple): logger.warn( f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`" ) elif len(result) != 2: logger.warn( "The result returned by `env.reset()` should be `(obs, info)` by default, , where `obs` is a observation and `info` is a dictionary containing additional information." ) else: obs, info = result check_obs(obs, env.observation_space, "reset") assert isinstance( info, dict ), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}" return result def env_step_passive_checker(env, action): """A passive check for the environment step, investigating the returning data then returning the data unchanged.""" # We don't check the action as for some environments then out-of-bounds values can be given result = env.step(action) assert isinstance( result, tuple ), f"Expects step result to be a tuple, actual type: {type(result)}" if len(result) == 4: logger.deprecation( "Core environment is written in old step API which returns one bool instead of two. " "It is recommended to rewrite the environment with new step API. " ) obs, reward, done, info = result if not isinstance(done, (bool, np.bool8)): logger.warn( f"Expects `done` signal to be a boolean, actual type: {type(done)}" ) elif len(result) == 5: obs, reward, terminated, truncated, info = result # np.bool is actual python bool not np boolean type, therefore bool_ or bool8 if not isinstance(terminated, (bool, np.bool8)): logger.warn( f"Expects `terminated` signal to be a boolean, actual type: {type(terminated)}" ) if not isinstance(truncated, (bool, np.bool8)): logger.warn( f"Expects `truncated` signal to be a boolean, actual type: {type(truncated)}" ) else: raise error.Error( f"Expected `Env.step` to return a four or five element tuple, actual number of elements returned: {len(result)}." ) check_obs(obs, env.observation_space, "step") if not ( np.issubdtype(type(reward), np.integer) or np.issubdtype(type(reward), np.floating) ): logger.warn( f"The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: {type(reward)}" ) else: if np.isnan(reward): logger.warn("The reward is a NaN value.") if np.isinf(reward): logger.warn("The reward is an inf value.") assert isinstance( info, dict ), f"The `info` returned by `step()` must be a python dictionary, actual type: {type(info)}" return result def env_render_passive_checker(env, *args, **kwargs): """A passive check of the `Env.render` that the declared render modes/fps in the metadata of the environment is declared.""" render_modes = env.metadata.get("render_modes") if render_modes is None: logger.warn( "No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`." ) else: if not isinstance(render_modes, (list, tuple)): logger.warn( f"Expects the render_modes to be a sequence (i.e. list, tuple), actual type: {type(render_modes)}" ) elif not all(isinstance(mode, str) for mode in render_modes): logger.warn( f"Expects all render modes to be strings, actual types: {[type(mode) for mode in render_modes]}" ) render_fps = env.metadata.get("render_fps") # We only require `render_fps` if rendering is actually implemented if len(render_modes) > 0: if render_fps is None: logger.warn( "No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps." ) else: if not ( np.issubdtype(type(render_fps), np.integer) or np.issubdtype(type(render_fps), np.floating) ): logger.warn( f"Expects the `env.metadata['render_fps']` to be an integer or a float, actual type: {type(render_fps)}" ) else: assert ( render_fps > 0 ), f"Expects the `env.metadata['render_fps']` to be greater than zero, actual value: {render_fps}" # env.render is now an attribute with default None if len(render_modes) == 0: assert ( env.render_mode is None ), f"With no render_modes, expects the Env.render_mode to be None, actual value: {env.render_mode}" else: assert env.render_mode is None or env.render_mode in render_modes, ( "The environment was initialized successfully however with an unsupported render mode. " f"Render mode: {env.render_mode}, modes: {render_modes}" ) result = env.render(*args, **kwargs) # TODO: Check that the result is correct return result