"""Wrapper for rescaling actions to within a max and min action."""
from typing import Union

import numpy as np

import gym
from gym import spaces


class RescaleAction(gym.ActionWrapper):
    """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].

    The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
    or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.

    A dimension where ``min_action == max_action`` is accepted. Its only valid action is mapped to the lower
    bound of the base environment's action space for that dimension.

    Example:
        >>> import gym
        >>> env = gym.make('BipedalWalker-v3')
        >>> env.action_space
        Box(-1.0, 1.0, (4,), float32)
        >>> min_action = -0.5
        >>> max_action = np.array([0.0, 0.5, 1.0, 0.75])
        >>> env = RescaleAction(env, min_action=min_action, max_action=max_action)
        >>> env.action_space
        Box(-0.5, [0.   0.5  1.   0.75], (4,), float32)
        >>> RescaleAction(env, min_action, max_action).action_space == gym.spaces.Box(min_action, max_action)
        True
    """

    def __init__(
        self,
        env: gym.Env,
        min_action: Union[float, int, np.ndarray],
        max_action: Union[float, int, np.ndarray],
    ):
        """Initializes the :class:`RescaleAction` wrapper.

        Args:
            env (Env): The environment to apply the wrapper
            min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
            max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.

        Raises:
            AssertionError: If the action space of ``env`` is not a :class:`spaces.Box`, or if any element of
                ``min_action`` is greater than the corresponding element of ``max_action``.
        """
        assert isinstance(
            env.action_space, spaces.Box
        ), f"expected Box action space, got {type(env.action_space)}"
        assert np.less_equal(min_action, max_action).all(), (min_action, max_action)

        super().__init__(env)
        # Adding to a zero array broadcasts scalar bounds to the shape and dtype
        # of the base action space.
        self.min_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
        )
        self.max_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action
        )
        self.action_space = spaces.Box(
            low=min_action,
            high=max_action,
            shape=env.action_space.shape,
            dtype=env.action_space.dtype,
        )

    def action(self, action):
        """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`.

        Args:
            action: The action to rescale

        Returns:
            The rescaled action, clipped to the bounds of the base environment's action space

        Raises:
            AssertionError: If any element of ``action`` lies outside [:attr:`min_action`, :attr:`max_action`].
        """
        assert np.all(np.greater_equal(action, self.min_action)), (
            action,
            self.min_action,
        )
        assert np.all(np.less_equal(action, self.max_action)), (action, self.max_action)
        low = self.env.action_space.low
        high = self.env.action_space.high

        span = self.max_action - self.min_action
        # ``__init__`` allows min_action == max_action, which would give 0 / 0 = NaN
        # here. Dividing by 1 on those dimensions leaves the zero numerator
        # (action == min_action) as a fraction of 0, i.e. the base ``low`` bound.
        # ``ones_like`` keeps the dtype of ``span``, so non-degenerate dimensions
        # are computed exactly as before.
        safe_span = np.where(span == 0, np.ones_like(span), span)
        fraction = (action - self.min_action) / safe_span

        action = low + (high - low) * fraction
        action = np.clip(action, low, high)
        return action