import contextlib
import copy
import difflib
import importlib
import importlib.util
import re
import sys
import warnings
from dataclasses import dataclass, field
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    SupportsFloat,
    Tuple,
    Union,
    overload,
)

import numpy as np

from gym.wrappers import (
    AutoResetWrapper,
    HumanRendering,
    OrderEnforcing,
    RenderCollection,
    TimeLimit,
)
from gym.wrappers.compatibility import EnvCompatibility
from gym.wrappers.env_checker import PassiveEnvChecker

if sys.version_info < (3, 10):
    import importlib_metadata as metadata  # type: ignore
else:
    import importlib.metadata as metadata

if sys.version_info >= (3, 8):
    from typing import Literal
else:
    from typing_extensions import Literal

from gym import Env, error, logger

# Pattern for `[namespace/](env-name)[-v(version)]`.
# The named groups are read by `parse_env_id`, so their names must stay in sync with it.
ENV_ID_RE = re.compile(
    r"^(?:(?P<namespace>[\w:-]+)\/)?(?:(?P<name>[\w:.-]+?))(?:-v(?P<version>\d+))?$"
)


def load(name: str) -> Callable:
    """Import and return the object referenced by a ``module:attribute`` string.

    Args:
        name: The entry point in the form ``"package.module:attribute"``

    Returns:
        The attribute looked up on the imported module (normally the environment constructor)
    """
    mod_name, attr_name = name.split(":")
    mod = importlib.import_module(mod_name)
    fn = getattr(mod, attr_name)
    return fn


def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
    """Parse an environment ID string of the form ``[namespace/](env-name)-v(version)``.

    This format is true today, but it's *not* an official spec.
    Both the namespace and the version are optional.

    Args:
        id: The environment id to parse

    Returns:
        A tuple of environment namespace, environment name and version number

    Raises:
        Error: If the environment id does not match the environment id regex
    """
    match = ENV_ID_RE.fullmatch(id)
    if not match:
        raise error.Error(
            f"Malformed environment ID: {id}."
            f"(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
        )
    namespace, name, version = match.group("namespace", "name", "version")
    if version is not None:
        version = int(version)

    return namespace, name, version


def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
    """Get the full env ID given a name and (optional) version and namespace. Inverse of :meth:`parse_env_id`.

    Args:
        ns: The environment namespace
        name: The environment name
        version: The environment version

    Returns:
        The environment id
    """
    full_name = name
    if version is not None:
        full_name += f"-v{version}"
    if ns is not None:
        full_name = ns + "/" + full_name
    return full_name


@dataclass
class EnvSpec:
    """A specification for creating environments with `gym.make`.

    * id: The string used to create the environment with `gym.make`
    * entry_point: The location of the environment to create from
    * reward_threshold: The reward threshold for completing the environment.
    * nondeterministic: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions.
    * max_episode_steps: The max number of steps that the environment can take before truncation
    * order_enforce: If to enforce the order of `reset` before `step` and `render` functions
    * autoreset: If to automatically reset the environment on episode end
    * disable_env_checker: If to disable the environment checker wrapper in `gym.make`, by default False (runs the environment checker)
    * apply_api_compatibility: If to wrap the environment with the `EnvCompatibility` wrapper in `gym.make`, by default False
    * kwargs: Additional keyword arguments passed to the environments through `gym.make`

    The `namespace`, `name` and `version` attributes are not constructor
    arguments; they are derived from `id` after initialisation.
    """

    id: str
    entry_point: Union[Callable, str]

    # Environment attributes
    reward_threshold: Optional[float] = field(default=None)
    nondeterministic: bool = field(default=False)

    # Wrappers
    max_episode_steps: Optional[int] = field(default=None)
    order_enforce: bool = field(default=True)
    autoreset: bool = field(default=False)
    disable_env_checker: bool = field(default=False)
    apply_api_compatibility: bool = field(default=False)

    # Environment arguments
    kwargs: dict = field(default_factory=dict)

    # post-init attributes
    namespace: Optional[str] = field(init=False)
    name: str = field(init=False)
    version: Optional[int] = field(init=False)

    def __post_init__(self):
        """Fill in `namespace`, `name` and `version` by parsing `id`."""
        self.namespace, self.name, self.version = parse_env_id(self.id)

    def make(self, **kwargs) -> Env:
        """Create the environment described by this spec (kept for compatibility purposes)."""
        return make(self, **kwargs)


def _check_namespace_exists(ns: Optional[str]):
    """Check if a namespace exists. If it doesn't, raise an error with a helpful message.

    Args:
        ns: The namespace to look for; ``None`` (the root namespace) always passes

    Raises:
        NamespaceNotFound: If no registered environment uses the namespace
    """
    if ns is None:
        return
    namespaces = {
        spec_.namespace for spec_ in registry.values() if spec_.namespace is not None
    }
    if ns in namespaces:
        return

    suggestion = (
        difflib.get_close_matches(ns, namespaces, n=1) if len(namespaces) > 0 else None
    )
    suggestion_msg = (
        f"Did you mean: `{suggestion[0]}`?"
        if suggestion
        else f"Have you installed the proper package for {ns}?"
    )

    raise error.NamespaceNotFound(f"Namespace {ns} not found. {suggestion_msg}")


def _check_name_exists(ns: Optional[str], name: str):
    """Check if an env exists in a namespace. If it doesn't, raise an error with a helpful message.

    Args:
        ns: The environment namespace
        name: The environment name

    Raises:
        NamespaceNotFound: If the namespace itself is unknown
        NameNotFound: If no environment of that name is registered in the namespace
    """
    _check_namespace_exists(ns)
    names = {spec_.name for spec_ in registry.values() if spec_.namespace == ns}

    if name in names:
        return

    suggestion = difflib.get_close_matches(name, names, n=1)
    namespace_msg = f" in namespace {ns}" if ns else ""
    suggestion_msg = f"Did you mean: `{suggestion[0]}`?" if suggestion else ""

    raise error.NameNotFound(
        f"Environment {name} doesn't exist{namespace_msg}. {suggestion_msg}"
    )


def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
    """Check if an env version exists in a namespace. If it doesn't, raise an error with a helpful message.

    This is a complete test whether an environment identifier is valid, and will provide the best available hints.

    Args:
        ns: The environment namespace
        name: The environment name
        version: The environment version

    Raises:
        NamespaceNotFound: The namespace doesn't exist
        NameNotFound: The environment name doesn't exist in the namespace
        DeprecatedEnv: The requested version doesn't exist but an unversioned default does,
            or the requested version is older than the latest registered version
        VersionNotFound: The ``version`` used is newer than every registered version
    """
    if get_env_id(ns, name, version) in registry:
        return

    _check_name_exists(ns, name)
    if version is None:
        return

    message = f"Environment version `v{version}` for environment `{get_env_id(ns, name, None)}` doesn't exist."

    env_specs = [
        spec_
        for spec_ in registry.values()
        if spec_.namespace == ns and spec_.name == name
    ]

    default_spec = [spec_ for spec_ in env_specs if spec_.version is None]

    if default_spec:
        message += f" It provides the default version `{default_spec[0].id}`."
        if len(env_specs) == 1:
            raise error.DeprecatedEnv(message)

    # Process possible versioned environments.
    # Sorting on the version itself (not `version or -1`) keeps `v0` distinct from "unversioned".
    versioned_specs = sorted(
        (spec_ for spec_ in env_specs if spec_.version is not None),
        key=lambda spec_: spec_.version,  # type: ignore
    )
    latest_spec = versioned_specs[-1] if versioned_specs else None

    if latest_spec is not None and version > latest_spec.version:
        # Only versioned specs are listed, so an unversioned spec can never show up as `vNone`.
        version_list_msg = ", ".join(
            f"`v{spec_.version}`" for spec_ in versioned_specs
        )
        message += f" It provides versioned environments: [ {version_list_msg} ]."

        raise error.VersionNotFound(message)

    if latest_spec is not None and version < latest_spec.version:
        raise error.DeprecatedEnv(
            f"Environment version v{version} for `{get_env_id(ns, name, None)}` is deprecated. "
            f"Please use `{latest_spec.id}` instead."
        )


def find_highest_version(ns: Optional[str], name: str) -> Optional[int]:
    """Return the highest registered version of an environment.

    Args:
        ns: The environment namespace
        name: The environment name

    Returns:
        The largest version number registered for ``ns``/``name``, or ``None``
        if no versioned environment with that namespace and name is registered
    """
    version: List[int] = [
        spec_.version
        for spec_ in registry.values()
        if spec_.namespace == ns and spec_.name == name and spec_.version is not None
    ]
    return max(version, default=None)


def load_env_plugins(entry_point: str = "gym.envs") -> None:
    """Load third-party environments advertised through package entry points.

    Every entry point in the ``entry_point`` group is loaded and called inside a
    `namespace` context named after the entry point, except for the magic names
    ``__root__`` and ``__internal__`` which are called without a namespace.
    Exceptions raised by the plugin function itself are logged as warnings.

    Args:
        entry_point: The entry point group to scan for plugins

    Raises:
        Error: If a plugin does not name a function to execute
    """
    for plugin in metadata.entry_points(group=entry_point):
        # Python 3.8 doesn't support plugin.module, plugin.attr
        # So we'll have to try and parse this ourselves
        module, attr = None, None
        try:
            module, attr = plugin.module, plugin.attr  # type: ignore  ## error: Cannot access member "attr" for type "EntryPoint"
        except AttributeError:
            if ":" in plugin.value:
                module, attr = plugin.value.split(":", maxsplit=1)
            else:
                module, attr = plugin.value, None
        except Exception as e:
            warnings.warn(
                f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}"
            )
            module, attr = None, None
        finally:
            # Runs on every path: a plugin without a callable attribute is rejected.
            if attr is None:
                raise error.Error(
                    f"Gym environment plugin `{module}` must specify a function to execute, not a root module"
                )

        context = namespace(plugin.name)
        if plugin.name.startswith("__") and plugin.name.endswith("__"):
            # `__internal__` is an artifact of the plugin system when
            # the root namespace had an allow-list. The allow-list is now
            # removed and plugins can register environments in the root
            # namespace with the `__root__` magic key.
            if plugin.name in ("__root__", "__internal__"):
                context = contextlib.nullcontext()
            else:
                logger.warn(
                    f"The environment namespace magic key `{plugin.name}` is unsupported. "
                    "To register an environment at the root namespace you should specify the `__root__` namespace."
                )

        with context:
            fn = plugin.load()
            try:
                fn()
            except Exception as e:
                logger.warn(str(e))


# fmt: off
@overload
def make(id: str, **kwargs) -> Env: ...
@overload
def make(id: EnvSpec, **kwargs) -> Env: ...


# Classic control
# ----------------------------------------
@overload
def make(id: Literal["CartPole-v0", "CartPole-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["MountainCar-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["MountainCarContinuous-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
@overload
def make(id: Literal["Pendulum-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
@overload
def make(id: Literal["Acrobot-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...


# Box2d
# ----------------------------------------
@overload
def make(id: Literal["LunarLander-v2", "LunarLanderContinuous-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["BipedalWalker-v3", "BipedalWalkerHardcore-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...
@overload
def make(id: Literal["CarRacing-v2"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, Sequence[SupportsFloat]]]: ...


# Toy Text
# ----------------------------------------
@overload
def make(id: Literal["Blackjack-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["FrozenLake-v1", "FrozenLake8x8-v1"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["CliffWalking-v0"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...
@overload
def make(id: Literal["Taxi-v3"], **kwargs) -> Env[np.ndarray, Union[np.ndarray, int]]: ...


# Mujoco
# ----------------------------------------
@overload
def make(id: Literal[
    "Reacher-v2", "Reacher-v4",
    "Pusher-v2", "Pusher-v4",
    "InvertedPendulum-v2", "InvertedPendulum-v4",
    "InvertedDoublePendulum-v2", "InvertedDoublePendulum-v4",
    "HalfCheetah-v2", "HalfCheetah-v3", "HalfCheetah-v4",
    "Hopper-v2", "Hopper-v3", "Hopper-v4",
    "Swimmer-v2", "Swimmer-v3", "Swimmer-v4",
    "Walker2d-v2", "Walker2d-v3", "Walker2d-v4",
    "Ant-v2", "Ant-v3", "Ant-v4",
    "HumanoidStandup-v2", "HumanoidStandup-v4",
    "Humanoid-v2", "Humanoid-v3", "Humanoid-v4",
], **kwargs) -> Env[np.ndarray, np.ndarray]: ...
# fmt: on


# Global registry of environments. Meant to be accessed through `register` and `make`
registry: Dict[str, EnvSpec] = {}
# Namespace forced onto `register` calls while inside a `namespace(...)` context
current_namespace: Optional[str] = None


def _check_spec_register(spec: EnvSpec):
    """Checks whether the spec is valid to be registered. Helper function for `register`.

    A versioned and an unversioned environment of the same namespace and name
    may not coexist in the registry.

    Args:
        spec: The spec that is about to be registered

    Raises:
        RegistrationError: If registering ``spec`` would mix versioned and
            unversioned environments of the same name
    """
    same_env_specs = [
        spec_
        for spec_ in registry.values()
        if spec_.namespace == spec.namespace and spec_.name == spec.name
    ]
    latest_versioned_spec = max(
        (spec_ for spec_ in same_env_specs if spec_.version is not None),
        key=lambda spec_: int(spec_.version),  # type: ignore
        default=None,
    )
    unversioned_spec = next(
        (spec_ for spec_ in same_env_specs if spec_.version is None),
        None,
    )

    if unversioned_spec is not None and spec.version is not None:
        raise error.RegistrationError(
            "Can't register the versioned environment "
            f"`{spec.id}` when the unversioned environment "
            f"`{unversioned_spec.id}` of the same name already exists."
        )
    elif latest_versioned_spec is not None and spec.version is None:
        raise error.RegistrationError(
            "Can't register the unversioned environment "
            f"`{spec.id}` when the versioned environment "
            f"`{latest_versioned_spec.id}` of the same name "
            f"already exists. Note: the default behavior is "
            f"that `gym.make` with the unversioned environment "
            f"will return the latest versioned environment"
        )


# Public API


@contextlib.contextmanager
def namespace(ns: str):
    """Context manager that sets the namespace applied by `register` inside the block.

    The previous namespace is restored on exit, including when the body raises.

    Args:
        ns: The namespace to use for environments registered inside the context
    """
    global current_namespace
    old_namespace = current_namespace
    current_namespace = ns
    try:
        yield
    finally:
        # Restore even on error so a failing plugin cannot leak its namespace
        # into later `register` calls.
        current_namespace = old_namespace


def register(
    id: str,
    entry_point: Union[Callable, str],
    reward_threshold: Optional[float] = None,
    nondeterministic: bool = False,
    max_episode_steps: Optional[int] = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    **kwargs,
):
    """Register an environment with gym.

    The `id` parameter corresponds to the name of the environment, with the syntax as follows:
    `(namespace)/(env_name)-v(version)` where `namespace` is optional.

    It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
    If the id is already registered, a warning is logged and the old spec is replaced.

    Args:
        id: The environment id
        entry_point: The entry point for creating the environment
        reward_threshold: The reward threshold considered to have learnt an environment
        nondeterministic: If the environment is nondeterministic (even with knowledge of the initial seed and all actions)
        max_episode_steps: The maximum number of episodes steps before truncation. Used by the Time Limit wrapper.
        order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order
        autoreset: If to add the autoreset wrapper such that reset does not need to be called.
        disable_env_checker: If to disable the environment checker for the environment. Recommended to False.
        apply_api_compatibility: If to apply the `EnvCompatibility` wrapper.
        **kwargs: arbitrary keyword arguments which are forwarded to the `EnvSpec` constructor
            (for example ``kwargs={...}`` holding the environment constructor arguments)

    Raises:
        Error: If ``id`` is malformed
        RegistrationError: If the spec conflicts with an already registered
            versioned/unversioned environment of the same name
    """
    ns, name, version = parse_env_id(id)

    if current_namespace is not None:
        custom_namespace = kwargs.get("namespace")
        if custom_namespace is not None and custom_namespace != current_namespace:
            logger.warn(
                f"Custom namespace `{custom_namespace}` is being overridden by namespace `{current_namespace}`. "
                f"If you are developing a plugin you shouldn't specify a namespace in `register` calls. "
                "The namespace is specified through the entry point package metadata."
            )
        # The active `namespace(...)` context always wins over the namespace in `id`.
        ns_id = current_namespace
    else:
        ns_id = ns

    full_id = get_env_id(ns_id, name, version)

    new_spec = EnvSpec(
        id=full_id,
        entry_point=entry_point,
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
        order_enforce=order_enforce,
        autoreset=autoreset,
        disable_env_checker=disable_env_checker,
        apply_api_compatibility=apply_api_compatibility,
        **kwargs,
    )
    _check_spec_register(new_spec)
    if new_spec.id in registry:
        logger.warn(f"Overriding environment {new_spec.id} already in registry.")
    registry[new_spec.id] = new_spec


def make(
    id: Union[str, EnvSpec],
    max_episode_steps: Optional[int] = None,
    autoreset: bool = False,
    apply_api_compatibility: Optional[bool] = None,
    disable_env_checker: Optional[bool] = None,
    **kwargs,
) -> Env:
    """Create an environment according to the given ID.

    To find all available environments use `gym.envs.registry.keys()` for all valid ids.

    Wrappers are applied in this order (innermost first): `EnvCompatibility`,
    `PassiveEnvChecker`, `OrderEnforcing`, `TimeLimit`, `AutoResetWrapper`, then
    `HumanRendering` or `RenderCollection`.

    Args:
        id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
        max_episode_steps: Maximum length of an episode (TimeLimit wrapper). If None, the
            environment specification `max_episode_steps` is used.
        autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
            The wrapper is also applied when the environment specification was registered with `autoreset=True`.
        apply_api_compatibility: Whether to wrap the environment with the `EnvCompatibility` wrapper.
            By default, the argument is None to which the environment specification `apply_api_compatibility` is used
            which defaults to False. Otherwise, the value of `apply_api_compatibility` is used.
            If `True`, the wrapper is applied otherwise, the wrapper is not applied.
        disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
            (which is by default False, running the environment checker),
            otherwise will run according to this parameter (`True` = not run, `False` = run)
        kwargs: Additional arguments to pass to the environment constructor.

    Returns:
        An instance of the environment.

    Raises:
        Error: If the ``id`` doesn't exist then an error is raised
        ModuleNotFoundError: If the module prefix of ``id`` cannot be imported
    """
    if isinstance(id, EnvSpec):
        spec_ = id
    else:
        module, id = (None, id) if ":" not in id else id.split(":")
        if module is not None:
            # Importing the module is expected to run its `register` calls.
            try:
                importlib.import_module(module)
            except ModuleNotFoundError as e:
                raise ModuleNotFoundError(
                    f"{e}. Environment registration via importing a module failed. "
                    f"Check whether '{module}' contains env registration and can be imported."
                ) from e
        spec_ = registry.get(id)

        ns, name, version = parse_env_id(id)
        latest_version = find_highest_version(ns, name)
        if (
            version is not None
            and latest_version is not None
            and latest_version > version
        ):
            logger.warn(
                f"The environment {id} is out of date. You should consider "
                f"upgrading to version `v{latest_version}`."
            )
        if version is None and latest_version is not None:
            # An unversioned id resolves to the latest registered version.
            version = latest_version
            new_env_id = get_env_id(ns, name, version)
            spec_ = registry.get(new_env_id)
            logger.warn(
                f"Using the latest versioned environment `{new_env_id}` "
                f"instead of the unversioned environment `{id}`."
            )

        if spec_ is None:
            # Raises the most specific error it can; the generic error below is the fallback.
            _check_version_exists(ns, name, version)
            raise error.Error(f"No registered env with id: {id}")

    # Call-site kwargs override the kwargs stored on the spec.
    _kwargs = spec_.kwargs.copy()
    _kwargs.update(kwargs)

    if spec_.entry_point is None:
        raise error.Error(f"{spec_.id} registered but entry_point is not specified")
    elif callable(spec_.entry_point):
        env_creator = spec_.entry_point
    else:
        # Assume it's a string
        env_creator = load(spec_.entry_point)

    mode = _kwargs.get("render_mode")
    apply_human_rendering = False
    apply_render_collection = False

    # If we have access to metadata we check that "render_mode" is valid and see if the HumanRendering wrapper needs to be applied
    if mode is not None and hasattr(env_creator, "metadata"):
        assert isinstance(
            env_creator.metadata, dict
        ), f"Expect the environment creator ({env_creator}) metadata to be dict, actual type: {type(env_creator.metadata)}"

        if "render_modes" in env_creator.metadata:
            render_modes = env_creator.metadata["render_modes"]
            if not isinstance(render_modes, Sequence):
                logger.warn(
                    f"Expects the environment metadata render_modes to be a Sequence (tuple or list), actual type: {type(render_modes)}"
                )

            # Apply the `HumanRendering` wrapper, if the mode=="human" but "human" not in render_modes
            if (
                mode == "human"
                and "human" not in render_modes
                and ("rgb_array" in render_modes or "rgb_array_list" in render_modes)
            ):
                logger.warn(
                    "You are trying to use 'human' rendering for an environment that doesn't natively support it. "
                    "The HumanRendering wrapper is being applied to your environment."
                )
                apply_human_rendering = True
                if "rgb_array" in render_modes:
                    _kwargs["render_mode"] = "rgb_array"
                else:
                    _kwargs["render_mode"] = "rgb_array_list"
            elif (
                mode not in render_modes
                and mode.endswith("_list")
                and mode[: -len("_list")] in render_modes
            ):
                # Emulate `<mode>_list` by collecting frames rendered with `<mode>`.
                _kwargs["render_mode"] = mode[: -len("_list")]
                apply_render_collection = True
            elif mode not in render_modes:
                logger.warn(
                    f"The environment is being initialised with mode ({mode}) that is not in the possible render_modes ({render_modes})."
                )
        else:
            logger.warn(
                f"The environment creator metadata doesn't include `render_modes`, contains: {list(env_creator.metadata.keys())}"
            )

    use_api_compatibility = apply_api_compatibility is True or (
        apply_api_compatibility is None and spec_.apply_api_compatibility is True
    )
    if use_api_compatibility:
        # If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
        render_mode = _kwargs.pop("render_mode", None)
    else:
        render_mode = None

    try:
        env = env_creator(**_kwargs)
    except TypeError as e:
        if (
            "got an unexpected keyword argument 'render_mode'" in str(e)
            and apply_human_rendering
        ):
            raise error.Error(
                f"You passed render_mode='human' although {id} doesn't implement human-rendering natively. "
                "Gym tried to apply the HumanRendering wrapper but it looks like your environment is using the old "
                "rendering API, which is not supported by the HumanRendering wrapper."
            )
        else:
            raise

    # Copies the environment creation specification and kwargs to add to the environment specification details
    spec_ = copy.deepcopy(spec_)
    spec_.kwargs = _kwargs
    env.unwrapped.spec = spec_

    # Add step API wrapper
    if use_api_compatibility:
        env = EnvCompatibility(env, render_mode)

    # Run the environment checker as the lowest level wrapper
    if disable_env_checker is False or (
        disable_env_checker is None and spec_.disable_env_checker is False
    ):
        env = PassiveEnvChecker(env)

    # Add the order enforcing wrapper
    if spec_.order_enforce:
        env = OrderEnforcing(env)

    # Add the time limit wrapper
    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps)
    elif spec_.max_episode_steps is not None:
        env = TimeLimit(env, spec_.max_episode_steps)

    # Add the autoreset wrapper, honouring both the argument and the registered spec
    if autoreset or spec_.autoreset:
        env = AutoResetWrapper(env)

    # Add human rendering wrapper
    if apply_human_rendering:
        env = HumanRendering(env)
    elif apply_render_collection:
        env = RenderCollection(env)

    return env


def spec(env_id: str) -> EnvSpec:
    """Retrieve the spec for the given environment from the global registry.

    Args:
        env_id: The full environment id to look up

    Returns:
        The registered `EnvSpec`

    Raises:
        Error: If ``env_id`` is malformed or not registered (more specific
            subclasses are raised when the namespace, name or version can be
            identified as the missing part)
    """
    spec_ = registry.get(env_id)
    if spec_ is None:
        ns, name, version = parse_env_id(env_id)
        _check_version_exists(ns, name, version)
        raise error.Error(f"No registered env with id: {env_id}")
    else:
        assert isinstance(spec_, EnvSpec)
        return spec_