Spaces:
Running
Running
"""Wrapper for rescaling actions to within a max and min action.""" | |
from typing import Union | |
import numpy as np | |
import gym | |
from gym import spaces | |
class RescaleAction(gym.ActionWrapper):
    """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].

    The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
    or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.
    The base action space's ``low``/``high`` bounds must be finite for the affine map to be well defined.

    Example:
        >>> import gym
        >>> env = gym.make('BipedalWalker-v3')
        >>> env.action_space
        Box(-1.0, 1.0, (4,), float32)
        >>> min_action = -0.5
        >>> max_action = np.array([0.0, 0.5, 1.0, 0.75])
        >>> env = RescaleAction(env, min_action=min_action, max_action=max_action)
        >>> env.action_space
        Box(-0.5, [0. 0.5 1. 0.75], (4,), float32)
        >>> RescaleAction(env, min_action, max_action).action_space == gym.spaces.Box(min_action, max_action)
        True
    """

    def __init__(
        self,
        env: gym.Env,
        min_action: Union[float, int, np.ndarray],
        max_action: Union[float, int, np.ndarray],
    ):
        """Initializes the :class:`RescaleAction` wrapper.

        Args:
            env (Env): The environment to apply the wrapper
            min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
            max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.

        Raises:
            AssertionError: If the base action space is not a :class:`spaces.Box`, or if any element
                of ``min_action`` is greater than the corresponding element of ``max_action``.
        """
        assert isinstance(
            env.action_space, spaces.Box
        ), f"expected Box action space, got {type(env.action_space)}"
        assert np.less_equal(min_action, max_action).all(), (min_action, max_action)

        super().__init__(env)
        # Adding the bound to a zero array of the base space's shape broadcasts scalar
        # bounds into full per-dimension arrays, so `action` can work elementwise.
        self.min_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
        )
        self.max_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action
        )
        self.action_space = spaces.Box(
            low=min_action,
            high=max_action,
            shape=env.action_space.shape,
            dtype=env.action_space.dtype,
        )

    def action(self, action):
        """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`.

        A dimension with ``min_action == max_action`` (accepted by the constructor) has no
        well-defined affine map. Such a dimension is sent to the midpoint of the base
        environment's range for that dimension, instead of producing NaN from a 0/0 division.

        Args:
            action: The action to rescale; must lie within [:attr:`min_action`, :attr:`max_action`].

        Returns:
            The rescaled action, clipped to the base action space's ``[low, high]``. When the
            base action space has a floating-point dtype, the result is cast to that dtype.

        Raises:
            AssertionError: If ``action`` lies outside [:attr:`min_action`, :attr:`max_action`].
        """
        assert np.all(np.greater_equal(action, self.min_action)), (
            action,
            self.min_action,
        )
        assert np.all(np.less_equal(action, self.max_action)), (action, self.max_action)
        low = self.env.action_space.low
        high = self.env.action_space.high

        span = self.max_action - self.min_action
        # Degenerate dimensions (span == 0) would otherwise yield 0/0 = NaN; silence the
        # warning here and replace those entries with the midpoint fraction below.
        with np.errstate(divide="ignore", invalid="ignore"):
            fraction = (action - self.min_action) / span
        fraction = np.where(span == 0, 0.5, fraction)

        rescaled = low + (high - low) * fraction
        rescaled = np.clip(rescaled, low, high)

        dtype = self.env.action_space.dtype
        if np.issubdtype(dtype, np.floating):
            # float64 bounds or inputs up-cast the result. Box.contains rejects dtypes that
            # cannot be safely cast to the space's dtype, so return the base env's dtype.
            # The value was clipped to representable bounds, so the cast stays in range.
            rescaled = np.asarray(rescaled, dtype=dtype)
        return rescaled