Spaces:
Running
Running
"""A compatibility wrapper converting an old-style environment into a valid environment.""" | |
import sys | |
from typing import Any, Dict, Optional, Tuple | |
import gym | |
from gym.core import ObsType | |
from gym.utils.step_api_compatibility import convert_to_terminated_truncated_step_api | |
if sys.version_info >= (3, 8): | |
from typing import Protocol, runtime_checkable | |
elif sys.version_info >= (3, 7): | |
from typing_extensions import Protocol, runtime_checkable | |
else: | |
Protocol = object | |
runtime_checkable = lambda x: x # noqa: E731 | |
@runtime_checkable
class LegacyEnv(Protocol):
    """A protocol for environments using the old step API.

    This is a structural type: an object does not need to inherit from it, it
    only needs to expose the attributes and methods declared below.

    The protocol is decorated with ``runtime_checkable`` so that
    ``isinstance(env, LegacyEnv)`` can be used to detect old-style environments.
    Such a check only verifies that the members exist, not their signatures.
    On interpreters where this module falls back to ``Protocol = object`` the
    decorator is the identity function and ``isinstance`` is an ordinary
    class check.
    """

    # Spaces describing valid observations and actions of the environment.
    observation_space: gym.Space
    action_space: gym.Space

    def reset(self) -> Any:
        """Reset the environment and return the initial observation."""
        ...

    def step(self, action: Any) -> Tuple[Any, float, bool, Dict]:
        """Run one timestep of the environment's dynamics.

        Args:
            action: the action to apply

        Returns:
            (observation, reward, done, info)
        """
        ...

    def render(self, mode: Optional[str] = "human") -> Any:
        """Render the environment.

        Args:
            mode: the render mode, ``"human"`` by default
        """
        ...

    def close(self):
        """Close the environment."""
        ...

    def seed(self, seed: Optional[int] = None):
        """Set the seed for this env's random number generator(s).

        Args:
            seed: the seed to use, or None
        """
        ...
class EnvCompatibility(gym.Env):
    r"""A wrapper which can transform an environment from the old API to the new API.

    Old step API refers to step() method returning (observation, reward, done, info), and reset() only returning the observation.
    New step API refers to step() method returning (observation, reward, terminated, truncated, info) and reset() returning (observation, info).
    (Refer to docs for details on the API change)

    Known limitations:
    - Environments that use `self.np_random` might not work as expected.
    """

    def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None):
        """A wrapper which converts old-style envs to valid modern envs.

        Some information may be lost in the conversion, so we recommend updating your environment.

        Args:
            old_env (LegacyEnv): the env to wrap, implemented with the old API
            render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render
        """
        # Mirror the optional attributes of the wrapped env, using a fallback
        # value when the old env does not define them.
        self.metadata = getattr(old_env, "metadata", {"render_modes": []})
        self.render_mode = render_mode
        self.reward_range = getattr(old_env, "reward_range", None)
        self.spec = getattr(old_env, "spec", None)

        self.env = old_env
        self.observation_space = old_env.observation_space
        self.action_space = old_env.action_space

    def reset(
        self, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[ObsType, dict]:
        """Resets the environment.

        Args:
            seed: the seed to reset the environment with; forwarded to the old env's ``seed()`` when not None
            options: the options to reset the environment with (ignored)

        Returns:
            (observation, info), where info is always an empty dict
        """
        if seed is not None:
            self.env.seed(seed)
        # Options are ignored

        # Reset before rendering so that the frame shown in human mode is the
        # new initial state, mirroring step() which renders after stepping.
        obs = self.env.reset()

        if self.render_mode == "human":
            self.render()

        return obs, {}

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
        """Steps through the environment.

        Args:
            action: action to step through the environment with

        Returns:
            (observation, reward, terminated, truncated, info)
        """
        obs, reward, done, info = self.env.step(action)

        if self.render_mode == "human":
            self.render()

        # The old 4-tuple is handed to the gym helper, which produces the
        # 5-tuple of the new step API.
        return convert_to_terminated_truncated_step_api((obs, reward, done, info))

    def render(self) -> Any:
        """Renders the environment.

        The wrapper's ``render_mode`` is passed to the old env as its ``mode`` argument.

        Returns:
            The rendering of the environment, depending on the render mode
        """
        return self.env.render(mode=self.render_mode)

    def close(self):
        """Closes the environment."""
        self.env.close()

    def __str__(self):
        """Returns the wrapper name and the unwrapped environment string."""
        return f"<{type(self).__name__}{self.env}>"

    def __repr__(self):
        """Returns the string representation of the wrapper."""
        return str(self)