Spaces:
Running
Running
File size: 2,527 Bytes
375a1cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
"""Wrapper for limiting the time steps of an environment."""
from typing import Optional
import gym
class TimeLimit(gym.Wrapper):
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
Example:
>>> from gym.envs.classic_control import CartPoleEnv
>>> from gym.wrappers import TimeLimit
>>> env = CartPoleEnv()
>>> env = TimeLimit(env, max_episode_steps=1000)
"""
def __init__(
self,
env: gym.Env,
max_episode_steps: Optional[int] = None,
):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
"""
super().__init__(env)
if max_episode_steps is None and self.env.spec is not None:
max_episode_steps = env.spec.max_episode_steps
if self.env.spec is not None:
self.env.spec.max_episode_steps = max_episode_steps
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None
def step(self, action):
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
Args:
action: The environment step action
Returns:
The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True`
if the number of steps elapsed >= max episode steps
"""
observation, reward, terminated, truncated, info = self.env.step(action)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
truncated = True
return observation, reward, terminated, truncated, info
def reset(self, **kwargs):
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
Args:
**kwargs: The kwargs to reset the environment with
Returns:
The reset environment
"""
self._elapsed_steps = 0
return self.env.reset(**kwargs)
|