# Kano001's picture
# Upload 919 files
# 375a1cf verified
"""A compatibility wrapper converting an old-style environment into a valid environment."""
import sys
from typing import Any, Dict, Optional, Tuple
import gym
from gym.core import ObsType
from gym.utils.step_api_compatibility import convert_to_terminated_truncated_step_api
# Pick a ``Protocol`` base and ``runtime_checkable`` decorator for this
# interpreter: stdlib ``typing`` from 3.8 on, ``typing_extensions`` on 3.7,
# and on anything older a plain ``object`` base plus a pass-through decorator
# so that the module still imports (structural isinstance checks are then
# unavailable).
if sys.version_info < (3, 7):
    Protocol = object

    def runtime_checkable(x):
        """Fallback no-op decorator: return the decorated class unchanged."""
        return x

elif sys.version_info < (3, 8):
    from typing_extensions import Protocol, runtime_checkable
else:
    from typing import Protocol, runtime_checkable
@runtime_checkable
class LegacyEnv(Protocol):
    """Structural description of an environment written for the old gym API.

    No inheritance is required: where a real ``Protocol`` is available, any
    object exposing these attributes and methods satisfies ``isinstance``
    checks (which test member presence only, not signatures).
    """

    # Spaces describing valid observations and valid actions.
    observation_space: gym.Space
    action_space: gym.Space

    def reset(self) -> Any:
        """Old-style reset: start a new episode and return only its first observation."""

    def step(self, action: Any) -> Tuple[Any, float, bool, Dict]:
        """Old-style step: advance one timestep and return ``(obs, reward, done, info)``."""

    def render(self, mode: Optional[str] = "human") -> Any:
        """Draw or return a rendering of the environment in the requested ``mode``."""

    def close(self):
        """Release whatever resources the environment holds."""

    def seed(self, seed: Optional[int] = None):
        """Seed the environment's random number generator(s)."""
class EnvCompatibility(gym.Env):
    r"""A wrapper which can transform an environment from the old API to the new API.

    Old step API refers to step() method returning (observation, reward, done, info),
    and reset() only returning the observation.
    New step API refers to step() method returning (observation, reward, terminated,
    truncated, info) and reset() returning (observation, info).
    (Refer to docs for details on the API change)

    Known limitations:
    - Environments that use `self.np_random` might not work as expected.
    """

    def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None):
        """A wrapper which converts old-style envs to valid modern envs.

        Some information may be lost in the conversion, so we recommend updating your environment.

        Args:
            old_env (LegacyEnv): the env to wrap, implemented with the old API
            render_mode (str): the render mode to use when rendering the environment,
                passed automatically to env.render
        """
        # Copy descriptive attributes from the legacy env when it has them;
        # otherwise fall back to the defaults below.
        self.metadata = getattr(old_env, "metadata", {"render_modes": []})
        self.render_mode = render_mode
        self.reward_range = getattr(old_env, "reward_range", None)
        self.spec = getattr(old_env, "spec", None)
        self.env = old_env

        self.observation_space = old_env.observation_space
        self.action_space = old_env.action_space

    def reset(
        self, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[ObsType, dict]:
        """Resets the environment.

        Args:
            seed: the seed to reset the environment with; forwarded to the legacy
                ``seed()`` method before resetting when not ``None``
            options: the options to reset the environment with (ignored, since
                the legacy ``reset()`` accepts no arguments)

        Returns:
            (observation, info), where info is always an empty dict because the
            legacy ``reset()`` returns only an observation
        """
        if seed is not None:
            self.env.seed(seed)
        # Options are ignored

        obs = self.env.reset()

        # Render only after the reset so the displayed frame is the new
        # episode's initial state; rendering first would show whatever state
        # the env was in before the reset. This matches step(), which also
        # renders after advancing the environment.
        if self.render_mode == "human":
            self.render()

        return obs, {}

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
        """Steps through the environment.

        Args:
            action: action to step through the environment with

        Returns:
            (observation, reward, terminated, truncated, info); the legacy
            ``done`` flag is split into terminated/truncated by
            ``convert_to_terminated_truncated_step_api``
        """
        obs, reward, done, info = self.env.step(action)

        if self.render_mode == "human":
            self.render()

        return convert_to_terminated_truncated_step_api((obs, reward, done, info))

    def render(self) -> Any:
        """Renders the environment.

        The wrapper's ``render_mode`` (which may be ``None``) is passed to the
        legacy ``render`` as its ``mode`` keyword argument.

        Returns:
            The rendering of the environment, depending on the render mode
        """
        return self.env.render(mode=self.render_mode)

    def close(self):
        """Closes the environment."""
        self.env.close()

    def __str__(self):
        """Returns the wrapper name and the unwrapped environment string."""
        return f"<{type(self).__name__}{self.env}>"

    def __repr__(self):
        """Returns the string representation of the wrapper."""
        return str(self)